From 8d95fe9eebcfcb3617580c28e8f49dd9b62b743e Mon Sep 17 00:00:00 2001
From: zjgarvey <47986913+zjgarvey@users.noreply.github.com>
Date: Wed, 7 Aug 2024 09:55:27 -0700
Subject: [PATCH] [TorchToArith] Add a lowering for `torch.add.float_int`
 (#3594)

---
 lib/Conversion/TorchToArith/TorchToArith.cpp  | 31 +++++++++++++++++--
 projects/pt1/e2e_testing/xfail_sets.py        |  7 +++++
 .../torch_mlir_e2e_test/test_suite/scalar.py  | 24 ++++++++++++++
 3 files changed, 59 insertions(+), 3 deletions(-)

diff --git a/lib/Conversion/TorchToArith/TorchToArith.cpp b/lib/Conversion/TorchToArith/TorchToArith.cpp
index ec7963a14..a1af190e4 100644
--- a/lib/Conversion/TorchToArith/TorchToArith.cpp
+++ b/lib/Conversion/TorchToArith/TorchToArith.cpp
@@ -72,8 +72,11 @@ public:
   matchAndRewrite(AtenOp op,
                   typename OpConversionPattern<AtenOp>::OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.template replaceOpWithNewOp<BinOp>(op, adaptor.getA(),
-                                                adaptor.getB());
+    Value a = adaptor.getA();
+    Value b = adaptor.getB();
+    if (llvm::is_one_of<AtenOp, AtenAddFloatIntOp>::value)
+      b = convertScalarToDtype(rewriter, op.getLoc(), b, a.getType());
+    rewriter.template replaceOpWithNewOp<BinOp>(op, a, b);
     return success();
   }
 };
@@ -255,6 +258,25 @@ public:
 };
 } // namespace
 
+namespace {
+template <typename AtenOp>
+class ConvertAtenScalarArithOp : public OpConversionPattern<AtenOp> {
+public:
+  using OpConversionPattern<AtenOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenOp op,
+                  typename OpConversionPattern<AtenOp>::OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type resultType =
+        this->getTypeConverter()->convertType(op->getResult(0).getType());
+    Value result =
+        convertScalarToDtype(rewriter, op.getLoc(), adaptor.getA(), resultType);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertAtenAddOp : public OpConversionPattern<AtenAddOp> {
 public:
@@ -444,9 +466,12 @@ public:
 
     target.addIllegalOp<AtenAddOp>();
     patterns.add<ConvertAtenAddOp>(typeConverter, context);
-    target.addIllegalOp<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp>();
+    target.addIllegalOp<AtenAddIntOp, AtenAddFloatIntOp, AtenSubIntOp,
+                        AtenMulIntOp>();
     patterns.add<ConvertAtenBinaryOp<AtenAddIntOp, arith::AddIOp>>(
         typeConverter, context);
+    patterns.add<ConvertAtenBinaryOp<AtenAddFloatIntOp, arith::AddFOp>>(
+        typeConverter, context);
     patterns.add<ConvertAtenBinaryOp<AtenSubIntOp, arith::SubIOp>>(
         typeConverter, context);
     patterns.add<ConvertAtenBinaryOp<AtenMulIntOp, arith::MulIOp>>(
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 7276d4435..171630ff9 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -117,6 +117,7 @@ TORCHDYNAMO_XFAIL_SET = {
     # END tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {}
     # START tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {}
     "AddIntModule_basic",
+    "AddFloatIntModule_basic",
     "AtenIntTensorCharDtypeModule_basic",
     "BoolIntFalseModule_basic",
     "BoolIntTrueModule_basic",
@@ -339,6 +340,7 @@ TORCHDYNAMO_CRASHING_SET = {
 
 FX_IMPORTER_XFAIL_SET = {
     "ReduceAnyDimFloatModule_basic",
+    "AddFloatIntModule_basic",
     "AllBoolFalseModule_basic",
     "AllBoolTrueModule_basic",
     "AnyBoolFalseModule_basic",
@@ -855,6 +857,7 @@ STABLEHLO_PASS_SET = {
     "AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
     "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
     "AddIntModule_basic",
+    "AddFloatIntModule_basic",
     "AliasModule_basic",
     "TrueFalseOrBoolOpModule_basic",
     "AllBoolFalseModule_basic",
@@ -2100,6 +2103,7 @@ LTC_XFAIL_SET = {
     "_ConvolutionDeprecated2DDeterministicModule_basic",
     "MaxPool3dEmptyStrideStaticModule_basic",
     "AddIntModule_basic",
+    "AddFloatIntModule_basic",
     "ArangeStartOutViewModule_basic",
     "AtenIntBoolOpModule_basic",
     "BernoulliTensorModule_basic",
@@ -2288,6 +2292,7 @@ ONNX_XFAIL_SET = {
     "AdaptiveMaxPool3dStatic_basic",
     "AddCDivModule_basic",
     "AddIntModule_basic",
+    "AddFloatIntModule_basic",
     "Add_Module_basic",
     "AllBoolFalseModule_basic",
     "AllBoolTrueModule_basic",
@@ -2840,6 +2845,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "AdaptiveMaxPool3dStaticWithIndices_basic",
     "AdaptiveMaxPool3dStatic_basic",
     "AddIntModule_basic",
+    "AddFloatIntModule_basic",
     "Add_MixPModule_basic",
     "AllBoolFalseModule_basic",
     "AllBoolTrueModule_basic",
@@ -3609,6 +3615,7 @@ ONNX_TOSA_XFAIL_SET = {
     "AdaptiveMaxPool3dStatic_basic",
     "AddCDivModule_basic",
     "AddIntModule_basic",
+    "AddFloatIntModule_basic",
     "AddSizeIntModule_basic",
     "AddSizeIntNegDimModule_basic",
     "Add_MixPModule_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py
index 3dacb9872..28b3a6f36 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py
@@ -36,6 +36,30 @@ def AddIntModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class AddFloatIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([], torch.float32, True),
+            ([], torch.int64, True),
+        ]
+    )
+    def forward(self, lhs, rhs):
+        return float(lhs) + int(rhs)
+
+
+@register_test_case(module_factory=lambda: AddFloatIntModule())
+def AddFloatIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(), tu.randint(low=-100, high=100))
+
+
+# ==============================================================================
+
+
 class SubIntModule(torch.nn.Module):
     def __init__(self):
         super().__init__()