[torch] Supporting `torch.aten.mul.float` lowering to `arith` (#2833)

The scalar multiply operation for floats was simply missing from the lowering to `arith`; this adds it.
Rob Suderman 2024-02-05 16:23:04 -08:00 committed by GitHub
parent e3faef5224
commit 041a54ae0c
3 changed files with 28 additions and 1 deletion


@@ -443,9 +443,11 @@ public:
         typeConverter, context);
     patterns.add<ConvertAtenBinaryOp<AtenMulIntOp, arith::MulIOp>>(
         typeConverter, context);
-    target.addIllegalOp<AtenSubFloatOp>();
+    target.addIllegalOp<AtenSubFloatOp, AtenMulFloatOp>();
     patterns.add<ConvertAtenBinaryOp<AtenSubFloatOp, arith::SubFOp>>(
         typeConverter, context);
+    patterns.add<ConvertAtenBinaryOp<AtenMulFloatOp, arith::MulFOp>>(
+        typeConverter, context);
     target.addIllegalOp<AtenDivIntOp>();
     patterns.add<ConvertAtenDivIntOp>(typeConverter, context);
     target.addIllegalOp<AtenDivFloatOp>();
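
The change reuses the existing templated `ConvertAtenBinaryOp` helper, so registering it for `AtenMulFloatOp`/`arith::MulFOp` and marking the op illegal is the entire lowering. As a rough sketch only (not the verbatim torch-mlir source; the includes and operand accessors shown are assumptions), such a conversion pattern looks like this:

```cpp
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

using namespace mlir;
using namespace mlir::torch::Torch;

// Sketch of a Torch-scalar-to-arith binary conversion pattern. By the time it
// runs, the type converter has already rewritten !torch.float operands into
// f64 values, so the op can be replaced 1:1 with its arith equivalent.
template <typename AtenOp, typename BinOp>
class ConvertAtenBinaryOp : public OpConversionPattern<AtenOp> {
public:
  using OpConversionPattern<AtenOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenOp op,
                  typename OpConversionPattern<AtenOp>::OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // AtenMulFloatOp's operands are named a and b, so the adaptor exposes
    // the already-converted values as getA()/getB().
    rewriter.replaceOpWithNewOp<BinOp>(op, adaptor.getA(), adaptor.getB());
    return success();
  }
};
```

Instantiated as `ConvertAtenBinaryOp<AtenMulFloatOp, arith::MulFOp>`, this rewrites `torch.aten.mul.float` into `arith.mulf` on `f64` values, mirroring the existing `AtenSubFloatOp`/`arith::SubFOp` registration one line above.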


@@ -100,6 +100,7 @@ TORCHDYNAMO_XFAIL_SET = {
     # START tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {}
     'AtenSubFloatModule_basic',
+    'AtenMulFloatModule_basic',
     'BoolFloatFalseModule_basic',
     'BoolFloatTrueModule_basic',
     'CeilFloatModule_basic',
@@ -109,6 +110,7 @@ TORCHDYNAMO_XFAIL_SET = {
     'GtFloatIntModule_basic',
     'NeFloatIntModule_basic',
     'SubFloatModule_basic',
+    'MulFloatModule_basic',
     'TensorToFloatZeroRank_basic',
     'TensorToFloat_basic',
     # END tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {}
@@ -1489,6 +1491,7 @@ LTC_XFAIL_SET = {
     "SliceStartEqEndModule_basic",
     "SqrtIntModule_basic",
     "SubFloatModule_basic",
+    "MulFloatModule_basic",
     "SubIntModule_basic",
     "TensorsStackPromoteDTypeModule_basic",
     "TensorToBoolZeroRank_basic",


@@ -78,6 +78,28 @@ def SubFloatModule_basic(module, tu: TestUtils):
     module.forward(tu.rand().double(), tu.rand().double())
 
 
 # ==============================================================================
 
 
+class MulFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([], torch.float64, True),
+        ([], torch.float64, True),
+    ])
+    def forward(self, lhs, rhs):
+        return float(lhs) * float(rhs)
+
+
+@register_test_case(module_factory=lambda: MulFloatModule())
+def MulFloatModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand().double(), tu.rand().double())
+
+
+# ==============================================================================