diff --git a/e2e_testing/torchscript/main.py b/e2e_testing/torchscript/main.py
index 2f365f40d..3a270b9ed 100644
--- a/e2e_testing/torchscript/main.py
+++ b/e2e_testing/torchscript/main.py
@@ -40,6 +40,7 @@ from . import reduction
 from . import argmax
 from . import matmul
 from . import view
+from . import scalar
 
 def _get_argparse():
     config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
diff --git a/e2e_testing/torchscript/scalar.py b/e2e_testing/torchscript/scalar.py
new file mode 100644
index 000000000..3b82b5ee6
--- /dev/null
+++ b/e2e_testing/torchscript/scalar.py
@@ -0,0 +1,29 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+import torch
+
+from torch_mlir_e2e_test.torchscript.framework import TestUtils
+from torch_mlir_e2e_test.torchscript.registry import register_test_case
+from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+
+
+class AddIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([], torch.int64, True),
+        ([], torch.int64, True),
+    ])
+    def forward(self, lhs, rhs):
+        return int(lhs) + int(rhs)
+
+
+@register_test_case(module_factory=lambda: AddIntModule())
+def AddIntModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
diff --git a/lib/Conversion/TorchToStd/TorchToStd.cpp b/lib/Conversion/TorchToStd/TorchToStd.cpp
index b1b990cba..8a8798d64 100644
--- a/lib/Conversion/TorchToStd/TorchToStd.cpp
+++ b/lib/Conversion/TorchToStd/TorchToStd.cpp
@@ -44,6 +44,19 @@ public:
 };
 } // namespace
 
+namespace {
+class ConvertAtenAddIntOp : public OpConversionPattern<AtenAddIntOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenAddIntOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<arith::AddIOp>(op, adaptor.a(), adaptor.b());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertAtenNeIntOp : public OpConversionPattern<AtenNeIntOp> {
 public:
@@ -129,6 +142,8 @@ public:
 
     target.addIllegalOp<AtenNeIntOp>();
     patterns.add<ConvertAtenNeIntOp>(typeConverter, context);
+    target.addIllegalOp<AtenAddIntOp>();
+    patterns.add<ConvertAtenAddIntOp>(typeConverter, context);
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       return signalPassFailure();
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 8036ff01d..0697a174b 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -232,7 +232,8 @@ public:
             AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op, AtenCumsumOp,
             AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
             AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp,
-            AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp>(op)) {
+            AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp>(
+            op)) {
       return getLatticeElement(op->getResult(0)).join(*operands[0]);
     }
 
@@ -426,6 +427,8 @@ public:
       return visitNumToTensorOp(numToTensorOp);
     } else if (isa<AtenAddcmulOp, AtenAddcdivOp>(op)) {
      return visitAtenAddCLikeOp(op, operands);
+    } else if (auto scalarOp = dyn_cast<AtenAddIntOp>(op)) {
+      return visitBinaryScalarOp(scalarOp);
     }
 
     // Otherwise, this is an unknown operation. Just mark all results as
@@ -528,6 +531,8 @@ private:
   ChangeResult
   visitAtenEmbeddingOp(AtenEmbeddingOp op,
                        ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+  template <typename OpTy> ChangeResult visitBinaryScalarOp(OpTy op);
+
   ChangeResult
   visitAtenBmmOp(AtenBmmOp op,
                  ArrayRef<LatticeElement<ValueKnowledge> *> operands);
@@ -587,6 +592,13 @@ static ResultTypeState updateResultTypeState(ValueKnowledge *tensor,
   return new_state;
 }
 
+static Type getPromotedResultType(ArrayRef<Type> scalarTypes) {
+  ResultTypeState state = {};
+  for (const Type &scalarType : scalarTypes)
+    state = updateResultTypeState(scalarType, state);
+  return getTypeForScalarType(scalarTypes[0].getContext(), result_type(state));
+}
+
 // Returns most generic type Type() if the tensor dtype is unknown.
 static Type getPromotedResultType(ValueKnowledge *tensor, Type scalarType) {
   if (!tensor->dtype)
@@ -1086,6 +1098,14 @@ ChangeResult TypeAnalyzer::visitScalarToTensorConversionOp(OpTy op) {
   return getLatticeElement(op.getResult()).join(knowledge);
 }
 
+template <typename OpTy>
+ChangeResult TypeAnalyzer::visitBinaryScalarOp(OpTy op) {
+  auto knowledge =
+      ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
+  knowledge.dtype = getPromotedResultType({op.a().getType(), op.b().getType()});
+  return getLatticeElement(op.getResult()).join(knowledge);
+}
+
 // `torch.aten.tensor` get a tensor from a list. Each layer of the list
 // corresponds to one dim of the tensor.
 ChangeResult TypeAnalyzer::visitAtenTensorOp(AtenTensorOp op) {
diff --git a/test/Conversion/TorchToStd/basic.mlir b/test/Conversion/TorchToStd/basic.mlir
index eac71ff36..9b1ca6021 100644
--- a/test/Conversion/TorchToStd/basic.mlir
+++ b/test/Conversion/TorchToStd/basic.mlir
@@ -74,3 +74,16 @@ func @torch.constant.int() -> !torch.int {
   %int1 = torch.constant.int 1
   return %int1 : !torch.int
 }
+
+// CHECK-LABEL:   func @torch.aten.add.int(
+// CHECK-SAME:        %[[LHS:.*]]: !torch.int,
+// CHECK-SAME:        %[[RHS:.*]]: !torch.int) -> !torch.int {
+// CHECK:           %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
+// CHECK:           %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]]
+// CHECK:           %[[ADD:.*]] = arith.addi %[[LHS_I64]], %[[RHS_I64]] : i64
+// CHECK:           %[[OUT:.*]] = torch_c.from_i64 %[[ADD]]
+// CHECK:           return %[[OUT]] : !torch.int
+func @torch.aten.add.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
+  %0 = torch.aten.add.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int
+  return %0 : !torch.int
+}