mirror of https://github.com/llvm/torch-mlir
[TorchToArith] Add a lowering for `torch.add.float_int` (#3594)
parent a51b4e014a
commit 8d95fe9eeb
@@ -72,8 +72,11 @@ public:
   matchAndRewrite(AtenOp op,
                   typename OpConversionPattern<AtenOp>::OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.template replaceOpWithNewOp<BinOp>(op, adaptor.getA(),
-                                                adaptor.getB());
+    Value a = adaptor.getA();
+    Value b = adaptor.getB();
+    if (llvm::is_one_of<AtenOp, AtenAddFloatIntOp>::value)
+      b = convertScalarToDtype(rewriter, op.getLoc(), b, a.getType());
+    rewriter.template replaceOpWithNewOp<BinOp>(op, a, b);
     return success();
   }
 };
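The hunk above extends the generic ConvertAtenBinaryOp pattern so torch.aten.add.float_int can reuse it: when the pattern is instantiated for AtenAddFloatIntOp, the integer operand `b` is first cast to `a`'s type via convertScalarToDtype (presumably an arith.sitofp for this int-to-float case) before the binary op is emitted. A minimal Python sketch of the scalar semantics being implemented; the function name is illustrative, not a torch-mlir API:

def add_float_int(a: float, b: int) -> float:
    b = float(b)  # stands in for the int-to-float cast convertScalarToDtype inserts
    return a + b  # stands in for the arith.addf the pattern emits

assert add_float_int(2.5, 3) == 5.5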
@@ -255,6 +258,25 @@ public:
 };
 } // namespace
 
+namespace {
+template <typename AtenOp>
+class ConvertAtenScalarArithOp : public OpConversionPattern<AtenOp> {
+public:
+  using OpConversionPattern<AtenOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenOp op,
+                  typename OpConversionPattern<AtenOp>::OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type resultType =
+        this->getTypeConverter()->convertType(op->getResult(0).getType());
+    Value result =
+        convertScalarToDtype(rewriter, op.getLoc(), adaptor.getA(), resultType);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertAtenAddOp : public OpConversionPattern<AtenAddOp> {
 public:
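The new ConvertAtenScalarArithOp pattern covers ops whose lowering is nothing more than a dtype conversion of a single scalar operand: it asks the type converter for the converted result type and casts the operand to it with convertScalarToDtype. A rough Python analogue, assuming this serves cast-like ops such as aten.Float.Scalar / aten.Int.Scalar (their registration is not shown in this hunk):

def scalar_arith(a, result_type):
    # convertScalarToDtype picks the matching cast (sitofp, fptosi, ...) here
    return result_type(a)

assert scalar_arith(3, float) == 3.0
assert scalar_arith(2.7, int) == 2  # fptosi-style truncation toward zero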
@@ -444,9 +466,12 @@ public:
   target.addIllegalOp<AtenAddOp>();
   patterns.add<ConvertAtenAddOp>(typeConverter, context);
 
-  target.addIllegalOp<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp>();
+  target.addIllegalOp<AtenAddIntOp, AtenAddFloatIntOp, AtenSubIntOp,
+                      AtenMulIntOp>();
   patterns.add<ConvertAtenBinaryOp<AtenAddIntOp, arith::AddIOp>>(
       typeConverter, context);
+  patterns.add<ConvertAtenBinaryOp<AtenAddFloatIntOp, arith::AddFOp>>(
+      typeConverter, context);
   patterns.add<ConvertAtenBinaryOp<AtenSubIntOp, arith::SubIOp>>(
       typeConverter, context);
   patterns.add<ConvertAtenBinaryOp<AtenMulIntOp, arith::MulIOp>>(
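The registration marks each scalar op illegal and pairs it with an instantiation of ConvertAtenBinaryOp. The resulting op mapping, summarized as an illustrative Python table (not part of the patch):

SCALAR_LOWERINGS = {
    "torch.aten.add.int":       "arith.addi",
    "torch.aten.add.float_int": "arith.addf",  # int operand cast to float first
    "torch.aten.sub.int":       "arith.subi",
    "torch.aten.mul.int":       "arith.muli",
}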
@@ -117,6 +117,7 @@ TORCHDYNAMO_XFAIL_SET = {
     # END tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {}
     # START tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {}
     "AddIntModule_basic",
+    "AddFloatIntModule_basic",
     "AtenIntTensorCharDtypeModule_basic",
     "BoolIntFalseModule_basic",
     "BoolIntTrueModule_basic",
@@ -339,6 +340,7 @@ TORCHDYNAMO_CRASHING_SET = {
 
 FX_IMPORTER_XFAIL_SET = {
     "ReduceAnyDimFloatModule_basic",
+    "AddFloatIntModule_basic",
     "AllBoolFalseModule_basic",
     "AllBoolTrueModule_basic",
     "AnyBoolFalseModule_basic",
@@ -855,6 +857,7 @@ STABLEHLO_PASS_SET = {
     "AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
     "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
     "AddIntModule_basic",
+    "AddFloatIntModule_basic",
     "AliasModule_basic",
     "TrueFalseOrBoolOpModule_basic",
     "AllBoolFalseModule_basic",
@@ -2100,6 +2103,7 @@ LTC_XFAIL_SET = {
     "_ConvolutionDeprecated2DDeterministicModule_basic",
     "MaxPool3dEmptyStrideStaticModule_basic",
     "AddIntModule_basic",
+    "AddFloatIntModule_basic",
     "ArangeStartOutViewModule_basic",
     "AtenIntBoolOpModule_basic",
     "BernoulliTensorModule_basic",
@@ -2288,6 +2292,7 @@ ONNX_XFAIL_SET = {
     "AdaptiveMaxPool3dStatic_basic",
     "AddCDivModule_basic",
     "AddIntModule_basic",
+    "AddFloatIntModule_basic",
     "Add_Module_basic",
     "AllBoolFalseModule_basic",
     "AllBoolTrueModule_basic",
@@ -2840,6 +2845,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "AdaptiveMaxPool3dStaticWithIndices_basic",
     "AdaptiveMaxPool3dStatic_basic",
     "AddIntModule_basic",
+    "AddFloatIntModule_basic",
     "Add_MixPModule_basic",
     "AllBoolFalseModule_basic",
     "AllBoolTrueModule_basic",
@@ -3609,6 +3615,7 @@ ONNX_TOSA_XFAIL_SET = {
     "AdaptiveMaxPool3dStatic_basic",
     "AddCDivModule_basic",
     "AddIntModule_basic",
+    "AddFloatIntModule_basic",
     "AddSizeIntModule_basic",
     "AddSizeIntNegDimModule_basic",
     "Add_MixPModule_basic",
@@ -36,6 +36,30 @@ def AddIntModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class AddFloatIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([], torch.float32, True),
+            ([], torch.int64, True),
+        ]
+    )
+    def forward(self, lhs, rhs):
+        return float(lhs) + int(rhs)
+
+
+@register_test_case(module_factory=lambda: AddFloatIntModule())
+def AddFloatIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(), tu.randint(low=-100, high=100))
+
+
+# ==============================================================================
+
+
 class SubIntModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
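For reference, a standalone check that mirrors what the e2e harness does with the new test module, using eager PyTorch as the oracle (a hand-rolled sketch, not harness code):

import torch

class AddFloatIntModule(torch.nn.Module):
    def forward(self, lhs, rhs):
        return float(lhs) + int(rhs)

m = AddFloatIntModule()
lhs = torch.rand(())                # 0-d float32 tensor, like tu.rand()
rhs = torch.randint(-100, 100, ())  # 0-d int64 tensor, like tu.randint(low=-100, high=100)
assert m(lhs, rhs) == float(lhs) + int(rhs)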