mirror of https://github.com/llvm/torch-mlir
[StableHlo] Fix AtenWhereSelfOp convert rule (#2093)
* fix whereself convert rule * use int to test promotion * add dynamo failing test --------- Co-authored-by: zhekun.zhang <zhekun.zhang@bytedance.com>pull/2064/head
parent
eaaaeb6ff1
commit
fc62b8e9ab
|
@ -54,6 +54,8 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
"ElementwiseWhereScalarModule_basic",
|
||||
"ElementwiseWhereScalarOtherModule_basic",
|
||||
"ElementwiseWhereScalarSelfModule_basic",
|
||||
"ElementwiseWhereScalarOtherStaticModule_basic",
|
||||
"ElementwiseWhereScalarSelfStaticModule_basic",
|
||||
# %7 = torch.operator "aten._index_put_impl_.hacked_twin"(%1, %6, %5, %true, %false) : (!torch.tensor<*,f32>, !torch.list<tensor>, !torch.tensor<*,f32>, !torch.bool, !torch.bool) -> !torch.tensor
|
||||
"IndexPutImpl1DFloatAccumulateModule_basic",
|
||||
"IndexPutImpl1DFloatNonAccumulateModule_basic",
|
||||
|
@ -267,6 +269,8 @@ STABLEHLO_PASS_SET = {
|
|||
"ElementwiseAtenLogicalOrOpPromoteBroadcastStaticShapeModule_basic",
|
||||
"ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic",
|
||||
"ElementwiseAtenWhereSelfModule_basic",
|
||||
"ElementwiseWhereScalarOtherStaticModule_basic",
|
||||
"ElementwiseWhereScalarSelfStaticModule_basic",
|
||||
"ElementwiseBitwiseAndStaticShapeModule_basic",
|
||||
"ElementwiseBitwiseNotInt64Module_basic",
|
||||
"ElementwiseBitwiseNotInt32Module_basic",
|
||||
|
|
|
@ -606,6 +606,12 @@ LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
|
|||
Value cond = adaptor.getCondition();
|
||||
Value other = adaptor.getOther();
|
||||
|
||||
auto outType =
|
||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||
// promote self and other types
|
||||
self = hlo::promoteType(rewriter, self, outType);
|
||||
other = hlo::promoteType(rewriter, other, outType);
|
||||
|
||||
if (failed(
|
||||
broadcastRanks(rewriter, op, self, cond, options.dimSizeIndexBits)))
|
||||
return op.emitError("failed broadcast self and condition ranks");
|
||||
|
|
|
@ -227,6 +227,29 @@ def ElementwiseWhereScalarOtherModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseWhereScalarOtherStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([3, 4, 5], torch.float64, True),
|
||||
([4, 5], torch.float64, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.where(a > 0.5, b, 8)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseWhereScalarOtherStaticModule())
|
||||
def ElementwiseWhereScalarOtherStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5).double(), tu.rand(4, 5).double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseWhereScalarSelfModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -246,6 +269,28 @@ class ElementwiseWhereScalarSelfModule(torch.nn.Module):
|
|||
def ElementwiseWhereScalarSelfModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5).double(), tu.rand(4, 5).double())
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseWhereScalarSelfStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([3, 4, 5], torch.float64, True),
|
||||
([4, 5], torch.float64, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.where(a > 0.5, 4.0, b)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseWhereScalarSelfStaticModule())
|
||||
def ElementwiseWhereScalarSelfStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5).double(), tu.rand(4, 5).double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
|
Loading…
Reference in New Issue