mirror of https://github.com/llvm/torch-mlir
[torch] Supporting `torch.aten.mul.float` lowering to `arith` (#2833)
Simple missing scalar operation for multiply floats was missing.pull/2870/head
parent
e3faef5224
commit
041a54ae0c
|
@ -443,9 +443,11 @@ public:
|
|||
typeConverter, context);
|
||||
patterns.add<ConvertAtenBinaryOp<AtenMulIntOp, arith::MulIOp>>(
|
||||
typeConverter, context);
|
||||
target.addIllegalOp<AtenSubFloatOp>();
|
||||
target.addIllegalOp<AtenSubFloatOp, AtenMulFloatOp>();
|
||||
patterns.add<ConvertAtenBinaryOp<AtenSubFloatOp, arith::SubFOp>>(
|
||||
typeConverter, context);
|
||||
patterns.add<ConvertAtenBinaryOp<AtenMulFloatOp, arith::MulFOp>>(
|
||||
typeConverter, context);
|
||||
target.addIllegalOp<AtenDivIntOp>();
|
||||
patterns.add<ConvertAtenDivIntOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenDivFloatOp>();
|
||||
|
|
|
@ -100,6 +100,7 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
|
||||
# START tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {}
|
||||
'AtenSubFloatModule_basic',
|
||||
'AtenMulFloatModule_basic',
|
||||
'BoolFloatFalseModule_basic',
|
||||
'BoolFloatTrueModule_basic',
|
||||
'CeilFloatModule_basic',
|
||||
|
@ -109,6 +110,7 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
'GtFloatIntModule_basic',
|
||||
'NeFloatIntModule_basic',
|
||||
'SubFloatModule_basic',
|
||||
'MulFloatModule_basic',
|
||||
'TensorToFloatZeroRank_basic',
|
||||
'TensorToFloat_basic',
|
||||
# END tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {}
|
||||
|
@ -1489,6 +1491,7 @@ LTC_XFAIL_SET = {
|
|||
"SliceStartEqEndModule_basic",
|
||||
"SqrtIntModule_basic",
|
||||
"SubFloatModule_basic",
|
||||
"MulFloatModule_basic",
|
||||
"SubIntModule_basic",
|
||||
"TensorsStackPromoteDTypeModule_basic",
|
||||
"TensorToBoolZeroRank_basic",
|
||||
|
|
|
@ -78,6 +78,28 @@ def SubFloatModule_basic(module, tu: TestUtils):
|
|||
module.forward(tu.rand().double(), tu.rand().double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class MulFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([], torch.float64, True),
|
||||
([], torch.float64, True),
|
||||
])
|
||||
def forward(self, lhs, rhs):
|
||||
return float(lhs) * float(rhs)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MulFloatModule())
|
||||
def MulFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand().double(), tu.rand().double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue