mirror of https://github.com/llvm/torch-mlir
Add non-default dtype support for a few elementwise math ops. (#687)
* fix type inference * fix Torch2Linalg conversion * add test casespull/641/head
parent
fe8ac57e6d
commit
f7c7bb800c
|
@ -37,6 +37,25 @@ def ElementwiseUnaryModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ElementwiseUnaryIntModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.tanh(a)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseUnaryIntModule())
|
||||||
|
def ElementwiseUnaryIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class ElementwiseBinaryModule(torch.nn.Module):
|
class ElementwiseBinaryModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -282,6 +301,25 @@ def ElementwiseSigmoidModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ElementwiseSigmoidIntModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int32, True),
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.sigmoid(x)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseSigmoidIntModule())
|
||||||
|
def ElementwiseSigmoidIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(1, 10, (3, 5), dtype=torch.int32))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class ElementwiseMinimumModule(torch.nn.Module):
|
class ElementwiseMinimumModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -545,6 +583,25 @@ def ElementwiseLogModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ElementwiseLogIntModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.log(a)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseLogIntModule())
|
||||||
|
def ElementwiseLogIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class ElementwiseErfModule(torch.nn.Module):
|
class ElementwiseErfModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -564,6 +621,25 @@ def ElementwiseErfModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ElementwiseErfIntModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.ops.aten.erf(a)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseErfIntModule())
|
||||||
|
def ElementwiseErfIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseSqrtModule(torch.nn.Module):
|
class ElementwiseSqrtModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -585,6 +661,26 @@ def ElementwiseSqrtModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ElementwiseSqrtIntModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int32, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.sqrt(a)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseSqrtIntModule())
|
||||||
|
def ElementwiseSqrtIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class ElementwiseFloorModule(torch.nn.Module):
|
class ElementwiseFloorModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -699,6 +795,25 @@ def ElementwiseLog2Module_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ElementwiseLog2IntModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.log2(a)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseLog2IntModule())
|
||||||
|
def ElementwiseLog2IntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class ElementwiseRsqrtModule(torch.nn.Module):
|
class ElementwiseRsqrtModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -719,6 +834,26 @@ def ElementwiseRsqrtModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ElementwiseRsqrtIntModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int32, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.rsqrt(a)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseRsqrtIntModule())
|
||||||
|
def ElementwiseRsqrtIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class ElementwiseAbsModule(torch.nn.Module):
|
class ElementwiseAbsModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -757,6 +892,25 @@ def ElementwiseReciprocalModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ElementwiseReciprocalIntModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1], torch.int32, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.reciprocal(a)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseReciprocalIntModule())
|
||||||
|
def ElementwiseReciprocalIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(1, 10, (4,), dtype=torch.int32))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class ElementwiseDivScalarModule(torch.nn.Module):
|
class ElementwiseDivScalarModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -949,3 +1103,123 @@ class ElementwiseCloneContiguousModule(torch.nn.Module):
|
||||||
def ElementwiseCloneContiguousModule_basic(module, tu: TestUtils):
|
def ElementwiseCloneContiguousModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(2, 3, 4))
|
module.forward(tu.rand(2, 3, 4))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ElementwiseExpModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.exp(a)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseExpModule())
|
||||||
|
def ElementwiseExpModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 4))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ElementwiseExpIntModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int32, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.exp(a)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseExpIntModule())
|
||||||
|
def ElementwiseExpIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ElementwiseSinModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.sin(a)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseSinModule())
|
||||||
|
def ElementwiseSinModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 4))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ElementwiseSinIntModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int32, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.sin(a)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseSinIntModule())
|
||||||
|
def ElementwiseSinIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ElementwiseCosModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.cos(a)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseCosModule())
|
||||||
|
def ElementwiseCosModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 4))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ElementwiseCosIntModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int32, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.cos(a)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseCosIntModule())
|
||||||
|
def ElementwiseCosIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||||
|
|
|
@ -34,6 +34,7 @@ TOSA_PASS_SET = {
|
||||||
"ElementwiseUnaryModule_basic",
|
"ElementwiseUnaryModule_basic",
|
||||||
"ElementwiseBinaryModule_basic",
|
"ElementwiseBinaryModule_basic",
|
||||||
"ElementwiseSigmoidModule_basic",
|
"ElementwiseSigmoidModule_basic",
|
||||||
|
"ElementwiseExpModule_basic",
|
||||||
"ElementwiseReluModule_basic",
|
"ElementwiseReluModule_basic",
|
||||||
"ElementwiseFloorModule_basic",
|
"ElementwiseFloorModule_basic",
|
||||||
"ElementwiseLogModule_basic",
|
"ElementwiseLogModule_basic",
|
||||||
|
|
|
@ -189,25 +189,60 @@ static Value buildUnitNormalCdf(OpBuilder &b, Location &loc, Value x) {
|
||||||
return buildNormalCdf(b, loc, x, zero, one);
|
return buildNormalCdf(b, loc, x, zero, one);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename MathOpTy>
|
||||||
|
static Value createCalculationForMathOpWithDtypeConversion(
|
||||||
|
OpBuilder &b, TypeConverter *converter, Value payloadArg, Operation *op) {
|
||||||
|
Type dtype = converter->convertType(op->getResult(0).getType())
|
||||||
|
.template cast<RankedTensorType>()
|
||||||
|
.getElementType();
|
||||||
|
Location loc = op->getLoc();
|
||||||
|
Value arg = convertScalarToDtype(b, loc, payloadArg, dtype);
|
||||||
|
return b.create<MathOpTy>(loc, arg);
|
||||||
|
}
|
||||||
|
|
||||||
static Value createLinalgPayloadCalculationForElementwiseOp(
|
static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
OpBuilder &b, Location loc, TypeConverter *converter,
|
OpBuilder &b, Location loc, TypeConverter *converter,
|
||||||
ValueRange payloadArgs, Operation *op, ArrayRef<Value> operands) {
|
ValueRange payloadArgs, Operation *op, ArrayRef<Value> operands) {
|
||||||
if (isa<AtenTanhOp>(op))
|
|
||||||
return b.create<math::TanhOp>(loc, payloadArgs[0]);
|
|
||||||
if (isa<AtenExpOp>(op))
|
|
||||||
return b.create<math::ExpOp>(loc, payloadArgs[0]);
|
|
||||||
if (isa<AtenFloorOp>(op))
|
if (isa<AtenFloorOp>(op))
|
||||||
return b.create<math::FloorOp>(loc, payloadArgs[0]);
|
return b.create<math::FloorOp>(loc, payloadArgs[0]);
|
||||||
if (isa<AtenCeilOp>(op))
|
if (isa<AtenCeilOp>(op))
|
||||||
return b.create<math::CeilOp>(loc, payloadArgs[0]);
|
return b.create<math::CeilOp>(loc, payloadArgs[0]);
|
||||||
if (isa<AtenLogOp>(op))
|
if (isa<AtenTanhOp>(op)) {
|
||||||
return b.create<math::LogOp>(loc, payloadArgs[0]);
|
return createCalculationForMathOpWithDtypeConversion<math::TanhOp>(
|
||||||
if (isa<AtenErfOp>(op))
|
b, converter, payloadArgs[0], op);
|
||||||
return b.create<math::ErfOp>(loc, payloadArgs[0]);
|
}
|
||||||
if (isa<AtenSqrtOp>(op))
|
if (isa<AtenExpOp>(op)) {
|
||||||
return b.create<math::SqrtOp>(loc, payloadArgs[0]);
|
return createCalculationForMathOpWithDtypeConversion<math::ExpOp>(
|
||||||
if (isa<AtenRsqrtOp>(op))
|
b, converter, payloadArgs[0], op);
|
||||||
return b.create<math::RsqrtOp>(loc, payloadArgs[0]);
|
}
|
||||||
|
if (isa<AtenLogOp>(op)) {
|
||||||
|
return createCalculationForMathOpWithDtypeConversion<math::LogOp>(
|
||||||
|
b, converter, payloadArgs[0], op);
|
||||||
|
}
|
||||||
|
if (isa<AtenLog2Op>(op)) {
|
||||||
|
return createCalculationForMathOpWithDtypeConversion<math::Log2Op>(
|
||||||
|
b, converter, payloadArgs[0], op);
|
||||||
|
}
|
||||||
|
if (isa<AtenErfOp>(op)) {
|
||||||
|
return createCalculationForMathOpWithDtypeConversion<math::ErfOp>(
|
||||||
|
b, converter, payloadArgs[0], op);
|
||||||
|
}
|
||||||
|
if (isa<AtenSqrtOp>(op)) {
|
||||||
|
return createCalculationForMathOpWithDtypeConversion<math::SqrtOp>(
|
||||||
|
b, converter, payloadArgs[0], op);
|
||||||
|
}
|
||||||
|
if (isa<AtenRsqrtOp>(op)) {
|
||||||
|
return createCalculationForMathOpWithDtypeConversion<math::RsqrtOp>(
|
||||||
|
b, converter, payloadArgs[0], op);
|
||||||
|
}
|
||||||
|
if (isa<AtenSinOp>(op)) {
|
||||||
|
return createCalculationForMathOpWithDtypeConversion<math::SinOp>(
|
||||||
|
b, converter, payloadArgs[0], op);
|
||||||
|
}
|
||||||
|
if (isa<AtenCosOp>(op)) {
|
||||||
|
return createCalculationForMathOpWithDtypeConversion<math::CosOp>(
|
||||||
|
b, converter, payloadArgs[0], op);
|
||||||
|
}
|
||||||
if (auto clone = dyn_cast<AtenCloneOp>(op)) {
|
if (auto clone = dyn_cast<AtenCloneOp>(op)) {
|
||||||
int64_t memoryFormat;
|
int64_t memoryFormat;
|
||||||
if (!clone.memory_format().getType().isa<Torch::NoneType>() &&
|
if (!clone.memory_format().getType().isa<Torch::NoneType>() &&
|
||||||
|
@ -235,14 +270,13 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||||
return b.create<arith::AndIOp>(loc, lhs, rhs);
|
return b.create<arith::AndIOp>(loc, lhs, rhs);
|
||||||
}
|
}
|
||||||
if (isa<AtenLog2Op>(op))
|
|
||||||
return b.create<math::Log2Op>(loc, payloadArgs[0]);
|
|
||||||
if (isa<AtenAbsOp>(op))
|
if (isa<AtenAbsOp>(op))
|
||||||
return b.create<math::AbsOp>(loc, payloadArgs[0]);
|
return b.create<math::AbsOp>(loc, payloadArgs[0]);
|
||||||
if (isa<AtenSigmoidOp>(op)) {
|
if (isa<AtenSigmoidOp>(op)) {
|
||||||
Type elementType = payloadArgs[0].getType();
|
auto negate = createCalculationForMathOpWithDtypeConversion<arith::NegFOp>(
|
||||||
auto one = b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1));
|
b, converter, payloadArgs[0], op);
|
||||||
auto negate = b.create<arith::NegFOp>(loc, payloadArgs[0]);
|
auto one =
|
||||||
|
b.create<arith::ConstantOp>(loc, FloatAttr::get(negate.getType(), 1));
|
||||||
auto exp = b.create<math::ExpOp>(loc, negate);
|
auto exp = b.create<math::ExpOp>(loc, negate);
|
||||||
auto added = b.create<arith::AddFOp>(loc, exp, one);
|
auto added = b.create<arith::AddFOp>(loc, exp, one);
|
||||||
return b.create<arith::DivFOp>(loc, one, added);
|
return b.create<arith::DivFOp>(loc, one, added);
|
||||||
|
@ -763,26 +797,22 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
return b.create<arith::DivFOp>(loc, self, other);
|
return b.create<arith::DivFOp>(loc, self, other);
|
||||||
}
|
}
|
||||||
if (auto reciprocal = dyn_cast<AtenReciprocalOp>(op)) {
|
if (auto reciprocal = dyn_cast<AtenReciprocalOp>(op)) {
|
||||||
if (!reciprocal.getType()
|
Type dtype = converter->convertType(reciprocal.getType())
|
||||||
.cast<ValueTensorType>()
|
.cast<RankedTensorType>()
|
||||||
.getDtype()
|
.getElementType();
|
||||||
.isa<mlir::FloatType>()) {
|
Value arg = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||||
reciprocal.emitError("unimplemented: non-floating point dtype");
|
Type elementType = arg.getType();
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
Type elementType = payloadArgs[0].getType();
|
|
||||||
// assert(element != 0)
|
// assert(element != 0)
|
||||||
auto zero =
|
auto zero =
|
||||||
b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.0));
|
b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.0));
|
||||||
auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ONE,
|
auto pred =
|
||||||
payloadArgs[0], zero);
|
b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ONE, arg, zero);
|
||||||
b.create<cf::AssertOp>(
|
b.create<cf::AssertOp>(
|
||||||
loc, pred, b.getStringAttr("unimplemented: tensor with zero element"));
|
loc, pred, b.getStringAttr("unimplemented: tensor with zero element"));
|
||||||
|
|
||||||
auto one =
|
auto one =
|
||||||
b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1.0));
|
b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1.0));
|
||||||
return b.create<arith::DivFOp>(loc, one, payloadArgs[0]);
|
return b.create<arith::DivFOp>(loc, one, arg);
|
||||||
}
|
}
|
||||||
if (auto thresholdOp = dyn_cast<AtenThresholdOp>(op)) {
|
if (auto thresholdOp = dyn_cast<AtenThresholdOp>(op)) {
|
||||||
// The approach used here is as follows:
|
// The approach used here is as follows:
|
||||||
|
@ -871,7 +901,7 @@ public:
|
||||||
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
|
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
|
||||||
AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp,
|
AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp,
|
||||||
AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
|
AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
|
||||||
AtenThresholdBackwardOp, AtenCloneOp>(op))
|
AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp>(op))
|
||||||
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
||||||
|
|
||||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||||
|
@ -1545,7 +1575,8 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
||||||
AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
|
AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
|
||||||
AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
|
AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
|
||||||
AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp,
|
AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp,
|
||||||
AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp>();
|
AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp,
|
||||||
|
AtenSinOp, AtenCosOp>();
|
||||||
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
||||||
target.addIllegalOp<AtenNllLossForwardOp>();
|
target.addIllegalOp<AtenNllLossForwardOp>();
|
||||||
patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
|
patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
|
||||||
|
|
|
@ -62,19 +62,16 @@ ParseResult Torch::parseDefaultTorchOp(OpAsmParser &parser,
|
||||||
|
|
||||||
void Torch::printDefaultTorchOp(OpAsmPrinter &p, Operation *op, int numOperands,
|
void Torch::printDefaultTorchOp(OpAsmPrinter &p, Operation *op, int numOperands,
|
||||||
int numResults) {
|
int numResults) {
|
||||||
p << ' ';
|
|
||||||
llvm::interleaveComma(op->getOperands(), p);
|
|
||||||
p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{});
|
|
||||||
p << " : ";
|
|
||||||
if (numOperands > 0) {
|
if (numOperands > 0) {
|
||||||
p << ' ';
|
p << ' ';
|
||||||
|
llvm::interleaveComma(op->getOperands(), p);
|
||||||
|
}
|
||||||
|
p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{});
|
||||||
|
p << " : ";
|
||||||
|
if (numOperands > 0)
|
||||||
llvm::interleaveComma(op->getOperandTypes(), p);
|
llvm::interleaveComma(op->getOperandTypes(), p);
|
||||||
}
|
if (numOperands > 0 && numResults > 0)
|
||||||
if (numOperands > 0 && numResults > 0) {
|
|
||||||
p << " -> ";
|
p << " -> ";
|
||||||
}
|
if (numResults > 0)
|
||||||
if (numResults > 0) {
|
|
||||||
p << ' ';
|
|
||||||
llvm::interleaveComma(op->getResultTypes(), p);
|
llvm::interleaveComma(op->getResultTypes(), p);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -496,28 +496,25 @@ ChangeResult TypeAnalyzer::visitOperation(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Take dtype from first operand.
|
// Take dtype from first operand.
|
||||||
if (isa<CopyToValueTensorOp, CopyToNonValueTensorOp, AtenTanhOp,
|
if (isa<CopyToValueTensorOp, CopyToNonValueTensorOp, AtenBatchNormOp,
|
||||||
AtenBatchNormOp, AtenReluOp, AtenGeluOp, AtenCeilOp,
|
AtenReluOp, AtenGeluOp, AtenCeilOp, AtenGeluBackwardOp,
|
||||||
AtenGeluBackwardOp, AtenBitwiseNotOp, AtenExpOp, AtenSinOp, AtenCosOp,
|
AtenBitwiseNotOp, AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp,
|
||||||
AtenSigmoidOp, AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp,
|
AtenFill_ScalarOp, AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopy_Op,
|
||||||
AtenFill_ScalarOp, AtenDetachOp, AtenReciprocalOp,
|
AtenCumsumOp, AtenLayerNormOp, AtenClampOp, AtenNegOp, AtenFloorOp,
|
||||||
AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenCumsumOp, AtenLayerNormOp,
|
Aten_SoftmaxBackwardDataOp, AtenDropoutOp, AtenTanhBackwardOp,
|
||||||
AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp, AtenFloorOp,
|
Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp, AtenAbsOp,
|
||||||
AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
|
AtenThresholdOp, AtenSquareOp, ValsemVariantAtenUniformOp,
|
||||||
AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
|
|
||||||
AtenAbsOp, AtenThresholdOp, AtenSquareOp, ValsemVariantAtenUniformOp,
|
|
||||||
AtenBernoulliOp, AtenBernoulli_FloatOp, AtenBernoulli_TensorOp,
|
AtenBernoulliOp, AtenBernoulli_FloatOp, AtenBernoulli_TensorOp,
|
||||||
ValsemVariantAtenBernoulliFloatOp, ValsemVariantAtenBernoulliTensorOp,
|
ValsemVariantAtenBernoulliFloatOp, ValsemVariantAtenBernoulliTensorOp,
|
||||||
ValsemVariantAtenFillScalarOp, AtenHardsigmoidOp, AtenCloneOp,
|
ValsemVariantAtenFillScalarOp, AtenHardsigmoidOp, AtenCloneOp,
|
||||||
AtenHardswishOp, AtenErfOp, AtenSiluOp, AtenHardtanhOp,
|
AtenHardswishOp, AtenSiluOp, AtenHardtanhOp, AtenMaskedSelectOp,
|
||||||
AtenMaskedSelectOp, AtenMaxPool2dOp, AtenAdaptiveAvgPool2dOp,
|
AtenMaxPool2dOp, AtenAdaptiveAvgPool2dOp, AtenFlattenUsingIntsOp,
|
||||||
AtenFlattenUsingIntsOp, AtenSqueezeOp, AtenSqueezeDimOp,
|
AtenSqueezeOp, AtenSqueezeDimOp, AtenUnsqueezeOp, AtenViewOp,
|
||||||
AtenUnsqueezeOp, AtenViewOp, Aten_UnsafeViewOp, AtenReshapeOp,
|
Aten_UnsafeViewOp, AtenReshapeOp, AtenResize_Op, AtenTransposeIntOp,
|
||||||
AtenResize_Op, AtenTransposeIntOp, AtenTOp, AtenPermuteOp,
|
AtenTOp, AtenPermuteOp, AtenIndexSelectOp, AtenSelectIntOp,
|
||||||
AtenIndexSelectOp, AtenSelectIntOp, AtenSliceTensorOp, AtenGatherOp,
|
AtenSliceTensorOp, AtenGatherOp, AtenExpandOp, AtenExpandAsOp,
|
||||||
AtenExpandOp, AtenExpandAsOp, AtenBroadcastToOp, AtenRepeatOp,
|
AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp,
|
||||||
AtenConstantPadNdOp, AtenIndexTensorOp,
|
AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp,
|
||||||
ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp,
|
|
||||||
ValsemVariantAtenCopyOp>(op)) {
|
ValsemVariantAtenCopyOp>(op)) {
|
||||||
ValueKnowledge knowledge =
|
ValueKnowledge knowledge =
|
||||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||||
|
@ -525,6 +522,21 @@ ChangeResult TypeAnalyzer::visitOperation(
|
||||||
return incorporateKnowledge(op->getResult(0), knowledge);
|
return incorporateKnowledge(op->getResult(0), knowledge);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Dtype is always float32, except for float64 and nullptr.
|
||||||
|
if (isa<AtenTanhOp, AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp,
|
||||||
|
AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op, AtenRsqrtOp,
|
||||||
|
AtenErfOp>(op)) {
|
||||||
|
ValueKnowledge knowledge =
|
||||||
|
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||||
|
Type dtype = operands[0]->getValue().dtype;
|
||||||
|
if (dtype) {
|
||||||
|
knowledge.dtype = Float32Type::get(op->getContext());
|
||||||
|
if (dtype.isa<Float64Type>())
|
||||||
|
knowledge.dtype = dtype;
|
||||||
|
}
|
||||||
|
return incorporateKnowledge(op->getResult(0), knowledge);
|
||||||
|
}
|
||||||
|
|
||||||
// Take dtype from second operand.
|
// Take dtype from second operand.
|
||||||
if (isa<AtenNllLossBackwardOp>(op)) {
|
if (isa<AtenNllLossBackwardOp>(op)) {
|
||||||
auto self = operands[1]->getValue();
|
auto self = operands[1]->getValue();
|
||||||
|
|
|
@ -72,6 +72,18 @@ module {
|
||||||
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
||||||
return %0 : !torch.list<int>
|
return %0 : !torch.list<int>
|
||||||
}
|
}
|
||||||
|
func @"__torch_mlir_shape_fn.aten.exp"(%arg0: !torch.list<int>) -> !torch.list<int> {
|
||||||
|
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
||||||
|
return %0 : !torch.list<int>
|
||||||
|
}
|
||||||
|
func @"__torch_mlir_shape_fn.aten.sin"(%arg0: !torch.list<int>) -> !torch.list<int> {
|
||||||
|
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
||||||
|
return %0 : !torch.list<int>
|
||||||
|
}
|
||||||
|
func @"__torch_mlir_shape_fn.aten.cos"(%arg0: !torch.list<int>) -> !torch.list<int> {
|
||||||
|
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
||||||
|
return %0 : !torch.list<int>
|
||||||
|
}
|
||||||
func @"__torch_mlir_shape_fn.aten.hardtanh"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list<int> {
|
func @"__torch_mlir_shape_fn.aten.hardtanh"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list<int> {
|
||||||
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
||||||
return %0 : !torch.list<int>
|
return %0 : !torch.list<int>
|
||||||
|
@ -1859,7 +1871,7 @@ module {
|
||||||
}
|
}
|
||||||
torch.prim.If.yield
|
torch.prim.If.yield
|
||||||
}
|
}
|
||||||
%2 = torch.operator "aten.sub.float"(%arg1, %arg0) : (!torch.float, !torch.float) -> !torch.float
|
%2 = torch.aten.sub.float %arg1, %arg0 : !torch.float, !torch.float -> !torch.float
|
||||||
%3 = torch.operator "aten.div.float"(%2, %arg2) : (!torch.float, !torch.float) -> !torch.float
|
%3 = torch.operator "aten.div.float"(%2, %arg2) : (!torch.float, !torch.float) -> !torch.float
|
||||||
%4 = torch.operator "aten.ceil.float"(%3) : (!torch.float) -> !torch.int
|
%4 = torch.operator "aten.ceil.float"(%3) : (!torch.float) -> !torch.int
|
||||||
%5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list<int>
|
%5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list<int>
|
||||||
|
@ -1891,7 +1903,7 @@ module {
|
||||||
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
|
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
|
||||||
torch.prim.If.yield
|
torch.prim.If.yield
|
||||||
}
|
}
|
||||||
%2 = torch.operator "aten.sub.float"(%arg1, %arg0) : (!torch.float, !torch.float) -> !torch.float
|
%2 = torch.aten.sub.float %arg1, %arg0 : !torch.float, !torch.float -> !torch.float
|
||||||
%3 = torch.operator "aten.ceil.float"(%2) : (!torch.float) -> !torch.int
|
%3 = torch.operator "aten.ceil.float"(%2) : (!torch.float) -> !torch.int
|
||||||
%4 = torch.prim.ListConstruct %3 : (!torch.int) -> !torch.list<int>
|
%4 = torch.prim.ListConstruct %3 : (!torch.int) -> !torch.list<int>
|
||||||
return %4 : !torch.list<int>
|
return %4 : !torch.list<int>
|
||||||
|
|
|
@ -303,6 +303,15 @@ def aten〇hardswish(self: List[int]) -> List[int]:
|
||||||
def aten〇silu(self: List[int]) -> List[int]:
|
def aten〇silu(self: List[int]) -> List[int]:
|
||||||
return upstream_shape_helpers.unary(self)
|
return upstream_shape_helpers.unary(self)
|
||||||
|
|
||||||
|
def aten〇exp(self: List[int]) -> List[int]:
|
||||||
|
return upstream_shape_helpers.unary(self)
|
||||||
|
|
||||||
|
def aten〇sin(self: List[int]) -> List[int]:
|
||||||
|
return upstream_shape_helpers.unary(self)
|
||||||
|
|
||||||
|
def aten〇cos(self: List[int]) -> List[int]:
|
||||||
|
return upstream_shape_helpers.unary(self)
|
||||||
|
|
||||||
def aten〇hardtanh(self: List[int], min_val: float = -1, max_val: float = 1) -> List[int]:
|
def aten〇hardtanh(self: List[int], min_val: float = -1, max_val: float = 1) -> List[int]:
|
||||||
return upstream_shape_helpers.unary(self)
|
return upstream_shape_helpers.unary(self)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue