From fc62b8e9abe2bfa8ec51f1b5edbdfa02a6414353 Mon Sep 17 00:00:00 2001 From: Zhekun Zhang <32320144+zhekunz2@users.noreply.github.com> Date: Fri, 5 May 2023 15:21:55 -0700 Subject: [PATCH] [StableHlo] Fix AtenWhereSelfOp convert rule (#2093) * fix whereself convert rule * use int to test promotion * add dynamo failing test --------- Co-authored-by: zhekun.zhang <zhekun.zhang@bytedance.com> --- e2e_testing/xfail_sets.py | 4 ++ lib/Conversion/TorchToStablehlo/Basic.cpp | 6 +++ .../test_suite/elementwise.py | 45 +++++++++++++++++++ 3 files changed, 55 insertions(+) diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 552c5e4c1..2974882c4 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -54,6 +54,8 @@ TORCHDYNAMO_XFAIL_SET = { "ElementwiseWhereScalarModule_basic", "ElementwiseWhereScalarOtherModule_basic", "ElementwiseWhereScalarSelfModule_basic", + "ElementwiseWhereScalarOtherStaticModule_basic", + "ElementwiseWhereScalarSelfStaticModule_basic", # %7 = torch.operator "aten._index_put_impl_.hacked_twin"(%1, %6, %5, %true, %false) : (!torch.tensor<*,f32>, !torch.list<tensor>, !torch.tensor<*,f32>, !torch.bool, !torch.bool) -> !torch.tensor "IndexPutImpl1DFloatAccumulateModule_basic", "IndexPutImpl1DFloatNonAccumulateModule_basic", @@ -267,6 +269,8 @@ STABLEHLO_PASS_SET = { "ElementwiseAtenLogicalOrOpPromoteBroadcastStaticShapeModule_basic", "ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic", "ElementwiseAtenWhereSelfModule_basic", + "ElementwiseWhereScalarOtherStaticModule_basic", + "ElementwiseWhereScalarSelfStaticModule_basic", "ElementwiseBitwiseAndStaticShapeModule_basic", "ElementwiseBitwiseNotInt64Module_basic", "ElementwiseBitwiseNotInt32Module_basic", diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index bc25c7e64..929ba7323 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -606,6 +606,12 @@ LogicalResult 
ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite( Value cond = adaptor.getCondition(); Value other = adaptor.getOther(); + auto outType = + getTypeConverter()->convertType(op.getType()).cast<TensorType>(); + // promote self and other types + self = hlo::promoteType(rewriter, self, outType); + other = hlo::promoteType(rewriter, other, outType); + if (failed( broadcastRanks(rewriter, op, self, cond, options.dimSizeIndexBits))) return op.emitError("failed broadcast self and condition ranks"); diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py index 4c732317a..b881ddab7 100644 --- a/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -227,6 +227,29 @@ def ElementwiseWhereScalarOtherModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseWhereScalarOtherStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 4, 5], torch.float64, True), + ([4, 5], torch.float64, True), + ]) + def forward(self, a, b): + return torch.where(a > 0.5, b, 8) + + +@register_test_case(module_factory=lambda: ElementwiseWhereScalarOtherStaticModule()) +def ElementwiseWhereScalarOtherStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5).double(), tu.rand(4, 5).double()) + + +# ============================================================================== + + class ElementwiseWhereScalarSelfModule(torch.nn.Module): def __init__(self): @@ -246,6 +269,28 @@ class ElementwiseWhereScalarSelfModule(torch.nn.Module): def ElementwiseWhereScalarSelfModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5).double(), tu.rand(4, 5).double()) +# ============================================================================== + + +class ElementwiseWhereScalarSelfStaticModule(torch.nn.Module): + + def __init__(self): + 
super().__init__() + + @export + @annotate_args([ + None, + ([3, 4, 5], torch.float64, True), + ([4, 5], torch.float64, True), + ]) + def forward(self, a, b): + return torch.where(a > 0.5, 4.0, b) + + +@register_test_case(module_factory=lambda: ElementwiseWhereScalarSelfStaticModule()) +def ElementwiseWhereScalarSelfStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5).double(), tu.rand(4, 5).double()) + # ==============================================================================