diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index a37f31376..1d3d4b36a 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -774,13 +774,17 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Type dtype = converter->convertType(divTensorMode.getType()) .cast() .getElementType(); - if (!dtype.isa()) { - divTensorMode.emitError("unimplemented: non-floating point dtype"); - return nullptr; - } Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); - Value div = b.create(loc, lhs, rhs); + Value div; + if (dtype.isa()) + div = b.create(loc, lhs, rhs); + else { + if (dtype.isUnsignedInteger()) + div = b.create(loc, lhs, rhs); + else + div = b.create(loc, lhs, rhs); + } if (divTensorMode.getRoundingMode().getType().isa()) return div; @@ -794,17 +798,32 @@ static Value createLinalgPayloadCalculationForElementwiseOp( if (roundingMode == "trunc") { // "trunc" - rounds the results of the division towards zero. Equivalent // to C-style integer division. - Value ceil = b.create(loc, div); - Value floor = b.create(loc, div); - Value cstZero = b.create(loc, b.getZeroAttr(dtype)); - Value pred = - b.create(loc, arith::CmpFPredicate::ULT, div, cstZero); - return b.create(loc, pred, ceil, floor); + if (dtype.isa()) { + Value ceil = b.create(loc, div); + Value floor = b.create(loc, div); + Value cstZero = b.create(loc, b.getZeroAttr(dtype)); + Value pred = b.create(loc, arith::CmpFPredicate::ULT, + div, cstZero); + return b.create(loc, pred, ceil, floor); + } else + return div; } if (roundingMode == "floor") { // "floor" - rounds the results of the division down. Equivalent to // floor division in Python (the // operator) - return b.create(loc, div); + if (dtype.isa()) + return b.create(loc, div); + else if (!dtype.isUnsignedInteger()) { + Type defaultIntToFloatType = b.getF64Type(); + lhs = convertScalarToDtype(b, loc, lhs, defaultIntToFloatType); + rhs = convertScalarToDtype(b, loc, rhs, defaultIntToFloatType); + div = b.create(loc, lhs, rhs); + Value floor = b.create(loc, div); + Value convert = convertScalarToDtype(b, loc, floor, dtype); + return convert; + } else { + return div; + } } divTensorMode.emitError("invalid rounding mode"); return nullptr; diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 22743e6a9..a55ca6a24 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -445,7 +445,8 @@ public: return rewriter.notifyMatchFailure( op, "only support constant str rounding mode"); - if (roundingMode == "trunc") { + // if trunc and int, do nothing + if (roundingMode == "trunc" && outElemTy.isa()) { // "trunc" - rounds the results of the division towards zero. Equivalent // to C-style integer division. auto sign = rewriter.create(loc, result); @@ -456,7 +457,20 @@ public: if (roundingMode == "floor") { // "floor" - rounds the results of the division down. Equivalent to // floor division in Python (the // operator) - result = rewriter.create(loc, result).getResult(); + if (outElemTy.isa()) + result = rewriter.create(loc, result).getResult(); + else if (!outElemTy.isUnsignedInteger()) { + TensorType defaultIntToFloatType = + outType.cloneWith(outType.getShape(), rewriter.getF64Type()); + lhs = + hlo::promoteType(rewriter, op.getLoc(), lhs, defaultIntToFloatType); + rhs = + hlo::promoteType(rewriter, op.getLoc(), rhs, defaultIntToFloatType); + result = rewriter.create(loc, defaultIntToFloatType, lhs, rhs, + bcastDimensions); + result = rewriter.create(loc, result).getResult(); + result = hlo::promoteType(rewriter, op.getLoc(), result, outType); + } } rewriter.replaceOp(op, result); return success(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 466f316ef..82342643a 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -245,6 +245,10 @@ TORCHDYNAMO_XFAIL_SET = { # ERROR: Exception: Unsupported: missing default value for argument 0 in schema for aten.div.Tensor_mode "ElementwiseDivRoundingModeFloorModule_basic", "ElementwiseDivRoundingModeTruncModule_basic", + "ElementwiseDivRoundingModeFloorStaticModule_basic", + "ElementwiseDivRoundingModeTruncStaticModule_basic", + "ElementwiseDivRoundingModeFloorIntStaticModule_basic", + "ElementwiseDivRoundingModeTruncIntStaticModule_basic", "AdaptiveAvgPool1dStaticLargerOutput_basic", "AdaptiveAvgPool1dGeneralDynamic_basic", "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", @@ -479,6 +483,10 @@ STABLEHLO_PASS_SET = { "ElementwiseCloneContiguousModule_basic", "ElementwiseCloneModule_basic", "ElementwiseCosModule_basic", + "ElementwiseDivRoundingModeFloorStaticModule_basic", + "ElementwiseDivRoundingModeTruncStaticModule_basic", + "ElementwiseDivRoundingModeFloorIntStaticModule_basic", + "ElementwiseDivRoundingModeTruncIntStaticModule_basic", "ElementwiseErfModule_basic", "ElementwiseExpModule_basic", "ElementwiseFloorIntModule_basic", @@ -2024,7 +2032,10 @@ ONNX_XFAIL_SET = { "ElementwiseAsinIntModule_basic", "ElementwiseAtanTensorIntModule_basic", "ElementwiseCosIntModule_basic", + "ElementwiseDivRoundingModeFloorIntStaticModule_basic", + "ElementwiseDivRoundingModeTruncIntStaticModule_basic", "ElementwiseDivRoundingModeTruncModule_basic", + "ElementwiseDivRoundingModeTruncStaticModule_basic", "ElementwiseErfIntModule_basic", "ElementwiseExpIntModule_basic", "ElementwiseLogIntModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index e9aa7571b..bfe7979f0 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -2789,6 +2789,88 @@ class ElementwiseDivRoundingModeFloorModule(torch.nn.Module): def ElementwiseDivRoundingModeFloorModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4), tu.rand(3, 4).type(torch.float64)) +class ElementwiseDivRoundingModeTruncStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([4], torch.float32, True), + ([4], torch.float64, True), + ]) + def forward(self, a, b): + return torch.div(a, b, rounding_mode="trunc") + + +@register_test_case( + module_factory=lambda: ElementwiseDivRoundingModeTruncStaticModule()) +def ElementwiseDivRoundingModeTruncStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(4), tu.rand(4).type(torch.float64)) + + +class ElementwiseDivRoundingModeFloorStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 4], torch.float32, True), + ([3, 4], torch.float64, True), + ]) + def forward(self, a, b): + return torch.div(a, b, rounding_mode="floor") + + +@register_test_case( + module_factory=lambda: ElementwiseDivRoundingModeFloorStaticModule()) +def ElementwiseDivRoundingModeFloorStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4), tu.rand(3, 4).type(torch.float64)) + +class ElementwiseDivRoundingModeTruncIntStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 4], torch.int32, True), + ([3, 4], torch.int64, True), + ]) + def forward(self, a, b): + return torch.div(a, b, rounding_mode="trunc") + + +@register_test_case( + module_factory=lambda: ElementwiseDivRoundingModeTruncIntStaticModule()) +def ElementwiseDivRoundingModeTruncIntStaticModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-10, high=10).type(torch.int32), tu.randint(3, 4, low=1, high=10).type(torch.int64)) + + +class ElementwiseDivRoundingModeFloorIntStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 4], torch.int32, True), + ([3, 4], torch.int64, True), + ]) + def forward(self, a, b): + return torch.div(a, b, rounding_mode="floor") + + +@register_test_case( + module_factory=lambda: ElementwiseDivRoundingModeFloorIntStaticModule()) +def ElementwiseDivRoundingModeFloorIntStaticModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-10, high=10).type(torch.int32), tu.randint(3, 4, low=1, high=10).type(torch.int64)) + # ==============================================================================