add aten.add.int lowering in TorchToStd

pull/436/head
dan 2021-11-24 00:16:32 +00:00 committed by Yi Zhang
parent 7616d28ce1
commit 03fdf56f21
5 changed files with 77 additions and 1 deletions

View File

@ -40,6 +40,7 @@ from . import reduction
from . import argmax from . import argmax
from . import matmul from . import matmul
from . import view from . import view
from . import scalar
def _get_argparse(): def _get_argparse():
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external'] config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']

View File

@ -0,0 +1,29 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import torch
from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
class AddIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([], torch.int64, True),
([], torch.int64, True),
])
def forward(self, lhs, rhs):
return int(lhs)+int(rhs)
@register_test_case(module_factory=lambda: AddIntModule())
def AddIntModule_basic(module, tu: TestUtils):
module.forward(torch.randint(-100, 100,()), torch.randint(-100, 100,()))

View File

@ -44,6 +44,19 @@ public:
}; };
} // namespace } // namespace
namespace {
class ConvertAtenAddIntOp : public OpConversionPattern<AtenAddIntOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenAddIntOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<arith::AddIOp>(op, adaptor.a(), adaptor.b());
return success();
}
};
} // namespace
namespace { namespace {
class ConvertAtenNeIntOp : public OpConversionPattern<AtenNeIntOp> { class ConvertAtenNeIntOp : public OpConversionPattern<AtenNeIntOp> {
public: public:
@ -129,6 +142,8 @@ public:
target.addIllegalOp<Torch::ConstantIntOp>(); target.addIllegalOp<Torch::ConstantIntOp>();
patterns.add<ConvertTorchConstantOp<Torch::ConstantIntOp>>(typeConverter, patterns.add<ConvertTorchConstantOp<Torch::ConstantIntOp>>(typeConverter,
context); context);
target.addIllegalOp<AtenAddIntOp>();
patterns.add<ConvertAtenAddIntOp>(typeConverter, context);
if (failed(applyPartialConversion(getOperation(), target, if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) std::move(patterns))))
return signalPassFailure(); return signalPassFailure();

View File

@ -232,7 +232,8 @@ public:
AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op, AtenCumsumOp, AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op, AtenCumsumOp,
AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenSqrtOp, AtenFloorOp, AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp,
AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp>(op)) { AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp>(
op)) {
return getLatticeElement(op->getResult(0)).join(*operands[0]); return getLatticeElement(op->getResult(0)).join(*operands[0]);
} }
@ -426,6 +427,8 @@ public:
return visitNumToTensorOp(numToTensorOp); return visitNumToTensorOp(numToTensorOp);
} else if (isa<AtenAddCMulOp, AtenAddCDivOp>(op)) { } else if (isa<AtenAddCMulOp, AtenAddCDivOp>(op)) {
return visitAtenAddCLikeOp(op, operands); return visitAtenAddCLikeOp(op, operands);
} else if (auto scalarOp = dyn_cast<AtenAddIntOp>(op)) {
return visitBinaryScalarOp(scalarOp);
} }
// Otherwise, this is an unknown operation. Just mark all results as // Otherwise, this is an unknown operation. Just mark all results as
@ -528,6 +531,8 @@ private:
ChangeResult ChangeResult
visitAtenEmbeddingOp(AtenEmbeddingOp op, visitAtenEmbeddingOp(AtenEmbeddingOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands); ArrayRef<LatticeElement<ValueKnowledge> *> operands);
template <typename OpTy> ChangeResult visitBinaryScalarOp(OpTy op);
ChangeResult ChangeResult
visitAtenBmmOp(AtenBmmOp op, visitAtenBmmOp(AtenBmmOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands); ArrayRef<LatticeElement<ValueKnowledge> *> operands);
@ -587,6 +592,13 @@ static ResultTypeState updateResultTypeState(ValueKnowledge *tensor,
return new_state; return new_state;
} }
static Type getPromotedResultType(ArrayRef<Type> scalarTypes) {
ResultTypeState state = {};
for (const Type &scalarType : scalarTypes)
state = updateResultTypeState(scalarType, state);
return getTypeForScalarType(scalarTypes[0].getContext(), result_type(state));
}
// Returns most generic type Type() if the tensor dtype is unknown. // Returns most generic type Type() if the tensor dtype is unknown.
static Type getPromotedResultType(ValueKnowledge *tensor, Type scalarType) { static Type getPromotedResultType(ValueKnowledge *tensor, Type scalarType) {
if (!tensor->dtype) if (!tensor->dtype)
@ -1086,6 +1098,14 @@ ChangeResult TypeAnalyzer::visitScalarToTensorConversionOp(OpTy op) {
return getLatticeElement(op.getResult()).join(knowledge); return getLatticeElement(op.getResult()).join(knowledge);
} }
template <typename OpTy>
ChangeResult TypeAnalyzer::visitBinaryScalarOp(OpTy op) {
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
knowledge.dtype = getPromotedResultType({op.a().getType(), op.b().getType()});
return getLatticeElement(op.getResult()).join(knowledge);
}
// `torch.aten.tensor` get a tensor from a list. Each layer of the list // `torch.aten.tensor` get a tensor from a list. Each layer of the list
// corresponds to one dim of the tensor. // corresponds to one dim of the tensor.
ChangeResult TypeAnalyzer::visitAtenTensorOp(AtenTensorOp op) { ChangeResult TypeAnalyzer::visitAtenTensorOp(AtenTensorOp op) {

View File

@ -74,3 +74,14 @@ func @torch.constant.int() -> !torch.int {
%int1 = torch.constant.int 1 %int1 = torch.constant.int 1
return %int1 : !torch.int return %int1 : !torch.int
} }
// CHECK-LABEL: func @torch.aten.add.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]]
// CHECK: %[[INT:.*]] = arith.addi %[[LHS_I64:.*]], [[RHS_I64:.*]] : i64
// CHECK: %[[INT:.*]] = torch_c.from_i64 %[[INT:.*]]
// CHECK: return %[[INT:.*]] : !torch.int
func @torch.aten.add.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
%0 = torch.aten.add.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int
return %0 : !torch.int
}