From f7c7bb800c94faf0f0811fbe66b7147f3842bfbf Mon Sep 17 00:00:00 2001 From: Qiang Fu Date: Wed, 23 Mar 2022 16:35:43 -0400 Subject: [PATCH] Add non-default dtype support for a few elementwise math ops. (#687) * fix type inference * fix Torch2Linalg conversion * add test cases --- e2e_testing/torchscript/elementwise.py | 274 ++++++++++++++++++ e2e_testing/torchscript/xfail_sets.py | 1 + .../TorchToLinalg/Uncategorized.cpp | 93 ++++-- lib/Dialect/Torch/IR/UtilsForODSGenerated.cpp | 17 +- lib/Dialect/Torch/Transforms/RefineTypes.cpp | 50 ++-- lib/Dialect/Torch/Transforms/ShapeLibrary.cpp | 16 +- .../jit_ir/build_tools/shape_lib_gen.py | 9 + 7 files changed, 398 insertions(+), 62 deletions(-) diff --git a/e2e_testing/torchscript/elementwise.py b/e2e_testing/torchscript/elementwise.py index 1c6b701d6..94e18dbec 100644 --- a/e2e_testing/torchscript/elementwise.py +++ b/e2e_testing/torchscript/elementwise.py @@ -37,6 +37,25 @@ def ElementwiseUnaryModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseUnaryIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, a): + return torch.tanh(a) + + +@register_test_case(module_factory=lambda: ElementwiseUnaryIntModule()) +def ElementwiseUnaryIntModule_basic(module, tu: TestUtils): + module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32)) + +# ============================================================================== + class ElementwiseBinaryModule(torch.nn.Module): def __init__(self): super().__init__() @@ -282,6 +301,25 @@ def ElementwiseSigmoidModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseSigmoidIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, x): + return torch.sigmoid(x) + + +@register_test_case(module_factory=lambda: ElementwiseSigmoidIntModule()) +def ElementwiseSigmoidIntModule_basic(module, tu: TestUtils): + module.forward(torch.randint(1, 10, (3, 5), dtype=torch.int32)) + +# ============================================================================== + class ElementwiseMinimumModule(torch.nn.Module): def __init__(self): super().__init__() @@ -545,6 +583,25 @@ def ElementwiseLogModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseLogIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, a): + return torch.log(a) + + +@register_test_case(module_factory=lambda: ElementwiseLogIntModule()) +def ElementwiseLogIntModule_basic(module, tu: TestUtils): + module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32)) + +# ============================================================================== + class ElementwiseErfModule(torch.nn.Module): def __init__(self): super().__init__() @@ -564,6 +621,25 @@ def ElementwiseErfModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseErfIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, a): + return 
torch.ops.aten.erf(a) + + +@register_test_case(module_factory=lambda: ElementwiseErfIntModule()) +def ElementwiseErfIntModule_basic(module, tu: TestUtils): + module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32)) + +# ============================================================================== + class ElementwiseSqrtModule(torch.nn.Module): def __init__(self): @@ -585,6 +661,26 @@ def ElementwiseSqrtModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseSqrtIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + + def forward(self, a): + return torch.sqrt(a) + + +@register_test_case(module_factory=lambda: ElementwiseSqrtIntModule()) +def ElementwiseSqrtIntModule_basic(module, tu: TestUtils): + module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32)) + +# ============================================================================== + class ElementwiseFloorModule(torch.nn.Module): def __init__(self): super().__init__() @@ -699,6 +795,25 @@ def ElementwiseLog2Module_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseLog2IntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, a): + return torch.log2(a) + + +@register_test_case(module_factory=lambda: ElementwiseLog2IntModule()) +def ElementwiseLog2IntModule_basic(module, tu: TestUtils): + module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32)) + +# ============================================================================== + class ElementwiseRsqrtModule(torch.nn.Module): def __init__(self): super().__init__() @@ -719,6 +834,26 @@ def ElementwiseRsqrtModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseRsqrtIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + + def forward(self, a): + return torch.rsqrt(a) + + +@register_test_case(module_factory=lambda: ElementwiseRsqrtIntModule()) +def ElementwiseRsqrtIntModule_basic(module, tu: TestUtils): + module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32)) + +# ============================================================================== + class ElementwiseAbsModule(torch.nn.Module): def __init__(self): super().__init__() @@ -757,6 +892,25 @@ def ElementwiseReciprocalModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseReciprocalIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + @export + @annotate_args([ + None, + ([-1], torch.int32, True), + ]) + + def forward(self, a): + return torch.reciprocal(a) + + +@register_test_case(module_factory=lambda: ElementwiseReciprocalIntModule()) +def ElementwiseReciprocalIntModule_basic(module, tu: TestUtils): + module.forward(torch.randint(1, 10, (4,), dtype=torch.int32)) + +# ============================================================================== + class ElementwiseDivScalarModule(torch.nn.Module): def __init__(self): super().__init__() @@ -949,3 +1103,123 @@ class ElementwiseCloneContiguousModule(torch.nn.Module): def ElementwiseCloneContiguousModule_basic(module, tu: TestUtils): 
    module.forward(tu.rand(2, 3, 4))
 
+# ==============================================================================
+
+class ElementwiseExpModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return torch.exp(a)
+
+
+@register_test_case(module_factory=lambda: ElementwiseExpModule())
+def ElementwiseExpModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
+
+# ==============================================================================
+
+class ElementwiseExpIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int32, True),
+    ])
+
+    def forward(self, a):
+        return torch.exp(a)
+
+
+@register_test_case(module_factory=lambda: ElementwiseExpIntModule())
+def ElementwiseExpIntModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
+
+
+# ==============================================================================
+
+class ElementwiseSinModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return torch.sin(a)
+
+
+@register_test_case(module_factory=lambda: ElementwiseSinModule())
+def ElementwiseSinModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
+
+# ==============================================================================
+
+class ElementwiseSinIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int32, True),
+    ])
+
+    def forward(self, a):
+        return torch.sin(a)
+
+
+@register_test_case(module_factory=lambda: ElementwiseSinIntModule())
+def ElementwiseSinIntModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
+
+# ==============================================================================
+
+class ElementwiseCosModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return torch.cos(a)
+
+
+@register_test_case(module_factory=lambda: ElementwiseCosModule())
+def ElementwiseCosModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
+
+# ==============================================================================
+
+class ElementwiseCosIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int32, True),
+    ])
+
+    def forward(self, a):
+        return torch.cos(a)
+
+
+@register_test_case(module_factory=lambda: ElementwiseCosIntModule())
+def ElementwiseCosIntModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py
index 7220b8a4a..72fe66470 100644
--- a/e2e_testing/torchscript/xfail_sets.py
+++ b/e2e_testing/torchscript/xfail_sets.py
@@ -34,6 +34,7 @@ TOSA_PASS_SET = {
     "ElementwiseUnaryModule_basic",
     "ElementwiseBinaryModule_basic",
     "ElementwiseSigmoidModule_basic",
+    "ElementwiseExpModule_basic",
     "ElementwiseReluModule_basic",
     "ElementwiseFloorModule_basic",
     "ElementwiseLogModule_basic",
diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index d1fedf990..77d170d55 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -189,25 +189,60 @@ static Value buildUnitNormalCdf(OpBuilder &b, Location &loc, Value x) {
   return buildNormalCdf(b, loc, x, zero, one);
 }
 
+template <typename MathOpTy>
+static Value createCalculationForMathOpWithDtypeConversion(
+    OpBuilder &b, TypeConverter *converter, Value payloadArg, Operation *op) {
+  Type dtype = converter->convertType(op->getResult(0).getType())
+                   .template cast<RankedTensorType>()
+                   .getElementType();
+  Location loc = op->getLoc();
+  Value arg = convertScalarToDtype(b, loc, payloadArg, dtype);
+  return b.create<MathOpTy>(loc, arg);
+}
+
 static Value createLinalgPayloadCalculationForElementwiseOp(
     OpBuilder &b, Location loc, TypeConverter *converter,
     ValueRange payloadArgs, Operation *op, ArrayRef<Value> operands) {
-  if (isa<AtenTanhOp>(op))
-    return b.create<math::TanhOp>(loc, payloadArgs[0]);
-  if (isa<AtenExpOp>(op))
-    return b.create<math::ExpOp>(loc, payloadArgs[0]);
   if (isa<AtenFloorOp>(op))
     return b.create<math::FloorOp>(loc, payloadArgs[0]);
   if (isa<AtenCeilOp>(op))
     return b.create<math::CeilOp>(loc, payloadArgs[0]);
-  if (isa<AtenLogOp>(op))
-    return b.create<math::LogOp>(loc, payloadArgs[0]);
-  if (isa<AtenSqrtOp>(op))
-    return b.create<math::SqrtOp>(loc, payloadArgs[0]);
-  if (isa<AtenLog2Op>(op))
-    return b.create<math::Log2Op>(loc, payloadArgs[0]);
-  if (isa<AtenRsqrtOp>(op))
-    return b.create<math::RsqrtOp>(loc, payloadArgs[0]);
+  if (isa<AtenTanhOp>(op)) {
+    return createCalculationForMathOpWithDtypeConversion<math::TanhOp>(
+        b, converter, payloadArgs[0], op);
+  }
+  if (isa<AtenExpOp>(op)) {
+    return createCalculationForMathOpWithDtypeConversion<math::ExpOp>(
+        b, converter, payloadArgs[0], op);
+  }
+  if (isa<AtenSinOp>(op)) {
+    return createCalculationForMathOpWithDtypeConversion<math::SinOp>(
+        b, converter, payloadArgs[0], op);
+  }
+  if (isa<AtenCosOp>(op)) {
+    return createCalculationForMathOpWithDtypeConversion<math::CosOp>(
+        b, converter, payloadArgs[0], op);
+  }
+  if (isa<AtenLogOp>(op)) {
+    return createCalculationForMathOpWithDtypeConversion<math::LogOp>(
+        b, converter, payloadArgs[0], op);
+  }
+  if (isa<AtenSqrtOp>(op)) {
+    return createCalculationForMathOpWithDtypeConversion<math::SqrtOp>(
+        b, converter, payloadArgs[0], op);
+  }
+  if (isa<AtenLog2Op>(op)) {
+    return createCalculationForMathOpWithDtypeConversion<math::Log2Op>(
+        b, converter, payloadArgs[0], op);
+  }
+  if (isa<AtenRsqrtOp>(op)) {
+    return createCalculationForMathOpWithDtypeConversion<math::RsqrtOp>(
+        b, converter, payloadArgs[0], op);
+  }
+  if (isa<AtenErfOp>(op)) {
+    return createCalculationForMathOpWithDtypeConversion<math::ErfOp>(
+        b, converter, payloadArgs[0], op);
+  }
   if (auto clone = dyn_cast<AtenCloneOp>(op)) {
     int64_t memoryFormat;
     if (!clone.memory_format().getType().isa<Torch::NoneType>() &&
@@ -235,14 +270,13 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
     return b.create(loc, lhs, rhs);
   }
-  if (isa<AtenErfOp>(op))
-    return b.create<math::ErfOp>(loc, payloadArgs[0]);
   if (isa<AtenNegOp>(op))
    return b.create<arith::NegFOp>(loc, payloadArgs[0]);
   if (isa<AtenSigmoidOp>(op)) {
-    Type elementType = payloadArgs[0].getType();
-    auto one = b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1));
-    auto negate = b.create<arith::NegFOp>(loc, payloadArgs[0]);
+    auto negate = createCalculationForMathOpWithDtypeConversion<arith::NegFOp>(
+        b, converter, payloadArgs[0], op);
+    auto one =
+        b.create<arith::ConstantOp>(loc, FloatAttr::get(negate.getType(), 1));
     auto exp = b.create<math::ExpOp>(loc, negate);
     auto added = b.create<arith::AddFOp>(loc, exp, one);
     return b.create<arith::DivFOp>(loc, one, added);
@@ -763,26 +797,22 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     return b.create(loc, self, other);
   }
   if (auto reciprocal = dyn_cast<AtenReciprocalOp>(op)) {
-    if (!reciprocal.getType()
-             .cast<BaseTensorType>()
-             .getDtype()
-             .isa<mlir::FloatType>()) {
-      reciprocal.emitError("unimplemented: non-floating point dtype");
-      return nullptr;
-    }
-
-    Type elementType = payloadArgs[0].getType();
+    Type dtype = converter->convertType(reciprocal.getType())
+                     .cast<RankedTensorType>()
+                     .getElementType();
+    Value arg = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
+    Type elementType = arg.getType();
     // assert(element != 0)
     auto zero =
         b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.0));
-    auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ONE,
-                                        payloadArgs[0], zero);
+    auto pred =
+        b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ONE, arg, zero);
     b.create<cf::AssertOp>(
         loc, pred,
         b.getStringAttr("unimplemented: tensor with zero element"));
     auto one =
         b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1.0));
-    return b.create<arith::DivFOp>(loc, one, payloadArgs[0]);
+    return b.create<arith::DivFOp>(loc, one, arg);
   }
   if (auto thresholdOp = dyn_cast<AtenThresholdOp>(op)) {
     // The approach used here is as follows:
@@ -871,7 +901,7 @@ public:
           AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
           AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp,
           AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
-          AtenThresholdBackwardOp, AtenCloneOp>(op))
+          AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp>(op))
     return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
 
   if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -1545,7 +1575,8 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
       AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
       AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
       AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp,
-      AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp>();
+      AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp,
+      AtenSinOp, AtenCosOp>();
   patterns.add<ConvertElementwiseOp>(typeConverter, context);
   target.addIllegalOp();
   patterns.add(typeConverter, context);
diff --git a/lib/Dialect/Torch/IR/UtilsForODSGenerated.cpp b/lib/Dialect/Torch/IR/UtilsForODSGenerated.cpp
index 0cc558bfb..c76947bfd 100644
--- a/lib/Dialect/Torch/IR/UtilsForODSGenerated.cpp
+++ b/lib/Dialect/Torch/IR/UtilsForODSGenerated.cpp
@@ -62,19 +62,16 @@ ParseResult Torch::parseDefaultTorchOp(OpAsmParser &parser,
 
 void Torch::printDefaultTorchOp(OpAsmPrinter &p, Operation *op,
                                 int numOperands, int numResults) {
-  p << ' ';
-  llvm::interleaveComma(op->getOperands(), p);
-  p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{});
-  p << " : ";
   if (numOperands > 0) {
     p << ' ';
+    llvm::interleaveComma(op->getOperands(), p);
+  }
+  p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{});
+  p << " : ";
+  if (numOperands > 0)
     llvm::interleaveComma(op->getOperandTypes(), p);
-  }
-  if (numOperands > 0 && numResults > 0) {
+  if (numOperands > 0 && numResults > 0)
     p << " -> ";
-  }
-  if (numResults > 0) {
-    p << ' ';
+  if (numResults > 0)
     llvm::interleaveComma(op->getResultTypes(), p);
-  }
 }
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index b46ecc89b..a6606612b 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -496,28 +496,25 @@ ChangeResult TypeAnalyzer::visitOperation(
   }
 
   // Take dtype from first operand.
-  if (isa(op)) {
     ValueKnowledge knowledge =
         ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
@@ -525,6 +522,21 @@ ChangeResult TypeAnalyzer::visitOperation(
     return incorporateKnowledge(op->getResult(0), knowledge);
   }
 
+  // The result dtype is always float32, except when the input dtype is float64 or unknown (null).
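+  // For example, tanh of an si32 tensor produces an f32 result, while tanh
+  // of an f64 tensor stays f64.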
+  if (isa<AtenTanhOp, AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp,
+          AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op, AtenRsqrtOp,
+          AtenErfOp>(op)) {
+    ValueKnowledge knowledge =
+        ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+    Type dtype = operands[0]->getValue().dtype;
+    if (dtype) {
+      knowledge.dtype = Float32Type::get(op->getContext());
+      if (dtype.isa<Float64Type>())
+        knowledge.dtype = dtype;
+    }
+    return incorporateKnowledge(op->getResult(0), knowledge);
+  }
+
   // Take dtype from second operand.
   if (isa(op)) {
     auto self = operands[1]->getValue();
diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
index 4142dca80..c5a64e344 100644
--- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -72,6 +72,18 @@ module {
     %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
     return %0 : !torch.list<int>
   }
+  func @"__torch_mlir_shape_fn.aten.exp"(%arg0: !torch.list<int>) -> !torch.list<int> {
+    %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
+    return %0 : !torch.list<int>
+  }
+  func @"__torch_mlir_shape_fn.aten.sin"(%arg0: !torch.list<int>) -> !torch.list<int> {
+    %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
+    return %0 : !torch.list<int>
+  }
+  func @"__torch_mlir_shape_fn.aten.cos"(%arg0: !torch.list<int>) -> !torch.list<int> {
+    %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
+    return %0 : !torch.list<int>
+  }
   func @"__torch_mlir_shape_fn.aten.hardtanh"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list<int> {
     %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
     return %0 : !torch.list<int>
   }
@@ -1859,7 +1871,7 @@ module {
       }
       torch.prim.If.yield
     }
-    %2 = torch.operator "aten.sub.float"(%arg1, %arg0) : (!torch.float, !torch.float) -> !torch.float
+    %2 = torch.aten.sub.float %arg1, %arg0 : !torch.float, !torch.float -> !torch.float
     %3 = torch.operator "aten.div.float"(%2, %arg2) : (!torch.float, !torch.float) -> !torch.float
     %4 = torch.operator "aten.ceil.float"(%3) : (!torch.float) -> !torch.int
     %5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list<int>
@@ -1891,7 +1903,7 @@ module {
       torch.prim.RaiseException %str, %none : !torch.str, !torch.none
       torch.prim.If.yield
     }
-    %2 = torch.operator "aten.sub.float"(%arg1, %arg0) : (!torch.float, !torch.float) -> !torch.float
+    %2 = torch.aten.sub.float %arg1, %arg0 : !torch.float, !torch.float -> !torch.float
     %3 = torch.operator "aten.ceil.float"(%2) : (!torch.float) -> !torch.int
     %4 = torch.prim.ListConstruct %3 : (!torch.int) -> !torch.list<int>
     return %4 : !torch.list<int>
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
index 111146146..22da34c9c 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
@@ -303,6 +303,15 @@ def aten〇hardswish(self: List[int]) -> List[int]:
 def aten〇silu(self: List[int]) -> List[int]:
     return upstream_shape_helpers.unary(self)
 
+def aten〇exp(self: List[int]) -> List[int]:
+    return upstream_shape_helpers.unary(self)
+
+def aten〇sin(self: List[int]) -> List[int]:
+    return upstream_shape_helpers.unary(self)
+
+def aten〇cos(self: List[int]) -> List[int]:
+    return upstream_shape_helpers.unary(self)
+
 def aten〇hardtanh(self: List[int], min_val: float = -1, max_val: float = 1) -> List[int]:
     return upstream_shape_helpers.unary(self)
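
Note: the RefineTypes rule and the new *IntModule e2e tests above rely on PyTorch's
eager-mode type promotion for these unary math ops: integer inputs are computed in
(and returned as) float32, while float64 inputs keep their dtype. A quick standalone
sanity check in plain PyTorch (illustrative only, not part of the patch):

    import torch

    i32 = torch.randint(1, 10, (3, 4), dtype=torch.int32)
    f64 = torch.rand(3, 4, dtype=torch.float64)

    assert torch.sin(i32).dtype == torch.float32   # integer input promotes to float32
    assert torch.sin(f64).dtype == torch.float64   # float64 is preserved
    assert torch.reciprocal(i32).dtype == torch.float32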