[StableHlo] Fix AtenWhereSelfOp convert rule (#2093)

* fix whereself convert rule

* use int to test promotion

* add dynamo failing test

---------

Co-authored-by: zhekun.zhang <zhekun.zhang@bytedance.com>
pull/2064/head
Zhekun Zhang 2023-05-05 15:21:55 -07:00 committed by GitHub
parent eaaaeb6ff1
commit fc62b8e9ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 55 additions and 0 deletions

View File

@@ -54,6 +54,8 @@ TORCHDYNAMO_XFAIL_SET = {
"ElementwiseWhereScalarModule_basic",
"ElementwiseWhereScalarOtherModule_basic",
"ElementwiseWhereScalarSelfModule_basic",
"ElementwiseWhereScalarOtherStaticModule_basic",
"ElementwiseWhereScalarSelfStaticModule_basic",
# %7 = torch.operator "aten._index_put_impl_.hacked_twin"(%1, %6, %5, %true, %false) : (!torch.tensor<*,f32>, !torch.list<tensor>, !torch.tensor<*,f32>, !torch.bool, !torch.bool) -> !torch.tensor
"IndexPutImpl1DFloatAccumulateModule_basic",
"IndexPutImpl1DFloatNonAccumulateModule_basic",
@@ -267,6 +269,8 @@ STABLEHLO_PASS_SET = {
"ElementwiseAtenLogicalOrOpPromoteBroadcastStaticShapeModule_basic",
"ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic",
"ElementwiseAtenWhereSelfModule_basic",
"ElementwiseWhereScalarOtherStaticModule_basic",
"ElementwiseWhereScalarSelfStaticModule_basic",
"ElementwiseBitwiseAndStaticShapeModule_basic",
"ElementwiseBitwiseNotInt64Module_basic",
"ElementwiseBitwiseNotInt32Module_basic",

View File

@@ -606,6 +606,12 @@ LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
Value cond = adaptor.getCondition();
Value other = adaptor.getOther();
// The converted result type drives operand promotion below.
auto outType =
    getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
// Promote self and other element types to the result type before emitting
// stablehlo ops. NOTE(review): per the commit intent this is what makes
// mixed-dtype inputs (e.g. an int scalar selected against float tensors)
// lower correctly — confirm against the StaticModule e2e tests added in
// this change.
self = hlo::promoteType(rewriter, self, outType);
other = hlo::promoteType(rewriter, other, outType);
// Rank-align `self` with the condition tensor; dimSizeIndexBits controls
// the index width used for dynamic-dim size computation. A failure here
// aborts the whole pattern rewrite.
if (failed(
        broadcastRanks(rewriter, op, self, cond, options.dimSizeIndexBits)))
  return op.emitError("failed broadcast self and condition ranks");

View File

@@ -227,6 +227,29 @@ def ElementwiseWhereScalarOtherModule_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseWhereScalarOtherStaticModule(torch.nn.Module):
    """Static-shape torch.where with a scalar `other` argument.

    The condition is `a > 0.5`; `b` is selected where true and the int
    scalar 8 where false. Shapes are fully static ([3, 4, 5] and [4, 5])
    so the StableHLO path can lower it (see STABLEHLO_PASS_SET).
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 4, 5], torch.float64, True),
        ([4, 5], torch.float64, True),
    ])
    def forward(self, a, b):
        # `other` is an int scalar against float64 tensors — exercises
        # scalar-to-tensor dtype promotion in the conversion.
        return torch.where(a > 0.5, b, 8)


@register_test_case(module_factory=lambda: ElementwiseWhereScalarOtherStaticModule())
def ElementwiseWhereScalarOtherStaticModule_basic(module, tu: TestUtils):
    # Inputs match the static shapes declared in annotate_args.
    module.forward(tu.rand(3, 4, 5).double(), tu.rand(4, 5).double())
# ==============================================================================
class ElementwiseWhereScalarSelfModule(torch.nn.Module):
def __init__(self):
@@ -246,6 +269,28 @@ class ElementwiseWhereScalarSelfModule(torch.nn.Module):
def ElementwiseWhereScalarSelfModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5).double(), tu.rand(4, 5).double())
# ==============================================================================
class ElementwiseWhereScalarSelfStaticModule(torch.nn.Module):
    """Static-shape torch.where with a scalar `self` argument.

    The condition is `a > 0.5`; the float scalar 4.0 is selected where
    true and `b` where false. Shapes are fully static ([3, 4, 5] and
    [4, 5]) so the StableHLO path can lower it (see STABLEHLO_PASS_SET).
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 4, 5], torch.float64, True),
        ([4, 5], torch.float64, True),
    ])
    def forward(self, a, b):
        # Scalar in the `self` (true-branch) position, mirroring the
        # ...ScalarOtherStaticModule which places the scalar in `other`.
        return torch.where(a > 0.5, 4.0, b)


@register_test_case(module_factory=lambda: ElementwiseWhereScalarSelfStaticModule())
def ElementwiseWhereScalarSelfStaticModule_basic(module, tu: TestUtils):
    # Inputs match the static shapes declared in annotate_args.
    module.forward(tu.rand(3, 4, 5).double(), tu.rand(4, 5).double())
# ==============================================================================