[TorchToArith] Add a lowering for `torch.add.float_int` (#3594)

pull/3607/head
zjgarvey 2024-08-07 09:55:27 -07:00 committed by GitHub
parent a51b4e014a
commit 8d95fe9eeb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 59 additions and 3 deletions

View File

@ -72,8 +72,11 @@ public:
matchAndRewrite(AtenOp op,
typename OpConversionPattern<AtenOp>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.template replaceOpWithNewOp<BinOp>(op, adaptor.getA(),
adaptor.getB());
Value a = adaptor.getA();
Value b = adaptor.getB();
if (llvm::is_one_of<AtenOp, AtenAddFloatIntOp>::value)
b = convertScalarToDtype(rewriter, op.getLoc(), b, a.getType());
rewriter.template replaceOpWithNewOp<BinOp>(op, a, b);
return success();
}
};
@ -255,6 +258,25 @@ public:
};
} // namespace
namespace {
template <typename AtenOp>
class ConvertAtenScalarArithOp : public OpConversionPattern<AtenOp> {
public:
using OpConversionPattern<AtenOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenOp op,
typename OpConversionPattern<AtenOp>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type resultType =
this->getTypeConverter()->convertType(op->getResult(0).getType());
Value result =
convertScalarToDtype(rewriter, op.getLoc(), adaptor.getA(), resultType);
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace
namespace {
class ConvertAtenAddOp : public OpConversionPattern<AtenAddOp> {
public:
@ -444,9 +466,12 @@ public:
target.addIllegalOp<AtenAddOp>();
patterns.add<ConvertAtenAddOp>(typeConverter, context);
target.addIllegalOp<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp>();
target.addIllegalOp<AtenAddIntOp, AtenAddFloatIntOp, AtenSubIntOp,
AtenMulIntOp>();
patterns.add<ConvertAtenBinaryOp<AtenAddIntOp, arith::AddIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenAddFloatIntOp, arith::AddFOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenSubIntOp, arith::SubIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenMulIntOp, arith::MulIOp>>(

View File

@ -117,6 +117,7 @@ TORCHDYNAMO_XFAIL_SET = {
# END tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {}
# START tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {}
"AddIntModule_basic",
"AddFloatIntModule_basic",
"AtenIntTensorCharDtypeModule_basic",
"BoolIntFalseModule_basic",
"BoolIntTrueModule_basic",
@ -339,6 +340,7 @@ TORCHDYNAMO_CRASHING_SET = {
FX_IMPORTER_XFAIL_SET = {
"ReduceAnyDimFloatModule_basic",
"AddFloatIntModule_basic",
"AllBoolFalseModule_basic",
"AllBoolTrueModule_basic",
"AnyBoolFalseModule_basic",
@ -855,6 +857,7 @@ STABLEHLO_PASS_SET = {
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
"AddIntModule_basic",
"AddFloatIntModule_basic",
"AliasModule_basic",
"TrueFalseOrBoolOpModule_basic",
"AllBoolFalseModule_basic",
@ -2100,6 +2103,7 @@ LTC_XFAIL_SET = {
"_ConvolutionDeprecated2DDeterministicModule_basic",
"MaxPool3dEmptyStrideStaticModule_basic",
"AddIntModule_basic",
"AddFloatIntModule_basic",
"ArangeStartOutViewModule_basic",
"AtenIntBoolOpModule_basic",
"BernoulliTensorModule_basic",
@ -2288,6 +2292,7 @@ ONNX_XFAIL_SET = {
"AdaptiveMaxPool3dStatic_basic",
"AddCDivModule_basic",
"AddIntModule_basic",
"AddFloatIntModule_basic",
"Add_Module_basic",
"AllBoolFalseModule_basic",
"AllBoolTrueModule_basic",
@ -2840,6 +2845,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"AdaptiveMaxPool3dStaticWithIndices_basic",
"AdaptiveMaxPool3dStatic_basic",
"AddIntModule_basic",
"AddFloatIntModule_basic",
"Add_MixPModule_basic",
"AllBoolFalseModule_basic",
"AllBoolTrueModule_basic",
@ -3609,6 +3615,7 @@ ONNX_TOSA_XFAIL_SET = {
"AdaptiveMaxPool3dStatic_basic",
"AddCDivModule_basic",
"AddIntModule_basic",
"AddFloatIntModule_basic",
"AddSizeIntModule_basic",
"AddSizeIntNegDimModule_basic",
"Add_MixPModule_basic",

View File

@ -36,6 +36,30 @@ def AddIntModule_basic(module, tu: TestUtils):
# ==============================================================================
class AddFloatIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([], torch.float32, True),
([], torch.int64, True),
]
)
def forward(self, lhs, rhs):
return float(lhs) + int(rhs)
@register_test_case(module_factory=lambda: AddFloatIntModule())
def AddFloatIntModule_basic(module, tu: TestUtils):
module.forward(tu.rand(), tu.randint(low=-100, high=100))
# ==============================================================================
class SubIntModule(torch.nn.Module):
def __init__(self):
super().__init__()