mirror of https://github.com/llvm/torch-mlir
add aten.add.int lowering in TorchToStd
parent
7616d28ce1
commit
03fdf56f21
|
@ -40,6 +40,7 @@ from . import reduction
|
|||
from . import argmax
|
||||
from . import matmul
|
||||
from . import view
|
||||
from . import scalar
|
||||
|
||||
def _get_argparse():
|
||||
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import torch
|
||||
|
||||
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
||||
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
||||
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||
|
||||
|
||||
class AddIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([], torch.int64, True),
|
||||
([], torch.int64, True),
|
||||
])
|
||||
def forward(self, lhs, rhs):
|
||||
return int(lhs)+int(rhs)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AddIntModule())
|
||||
def AddIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-100, 100,()), torch.randint(-100, 100,()))
|
|
@ -44,6 +44,19 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenAddIntOp : public OpConversionPattern<AtenAddIntOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenAddIntOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<arith::AddIOp>(op, adaptor.a(), adaptor.b());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenNeIntOp : public OpConversionPattern<AtenNeIntOp> {
|
||||
public:
|
||||
|
@ -129,6 +142,8 @@ public:
|
|||
target.addIllegalOp<Torch::ConstantIntOp>();
|
||||
patterns.add<ConvertTorchConstantOp<Torch::ConstantIntOp>>(typeConverter,
|
||||
context);
|
||||
target.addIllegalOp<AtenAddIntOp>();
|
||||
patterns.add<ConvertAtenAddIntOp>(typeConverter, context);
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
|
|
|
@ -232,7 +232,8 @@ public:
|
|||
AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op, AtenCumsumOp,
|
||||
AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
|
||||
AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp,
|
||||
AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp>(op)) {
|
||||
AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp>(
|
||||
op)) {
|
||||
return getLatticeElement(op->getResult(0)).join(*operands[0]);
|
||||
}
|
||||
|
||||
|
@ -426,6 +427,8 @@ public:
|
|||
return visitNumToTensorOp(numToTensorOp);
|
||||
} else if (isa<AtenAddCMulOp, AtenAddCDivOp>(op)) {
|
||||
return visitAtenAddCLikeOp(op, operands);
|
||||
} else if (auto scalarOp = dyn_cast<AtenAddIntOp>(op)) {
|
||||
return visitBinaryScalarOp(scalarOp);
|
||||
}
|
||||
|
||||
// Otherwise, this is an unknown operation. Just mark all results as
|
||||
|
@ -528,6 +531,8 @@ private:
|
|||
ChangeResult
|
||||
visitAtenEmbeddingOp(AtenEmbeddingOp op,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||
template <typename OpTy> ChangeResult visitBinaryScalarOp(OpTy op);
|
||||
|
||||
ChangeResult
|
||||
visitAtenBmmOp(AtenBmmOp op,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||
|
@ -587,6 +592,13 @@ static ResultTypeState updateResultTypeState(ValueKnowledge *tensor,
|
|||
return new_state;
|
||||
}
|
||||
|
||||
static Type getPromotedResultType(ArrayRef<Type> scalarTypes) {
|
||||
ResultTypeState state = {};
|
||||
for (const Type &scalarType : scalarTypes)
|
||||
state = updateResultTypeState(scalarType, state);
|
||||
return getTypeForScalarType(scalarTypes[0].getContext(), result_type(state));
|
||||
}
|
||||
|
||||
// Returns most generic type Type() if the tensor dtype is unknown.
|
||||
static Type getPromotedResultType(ValueKnowledge *tensor, Type scalarType) {
|
||||
if (!tensor->dtype)
|
||||
|
@ -1086,6 +1098,14 @@ ChangeResult TypeAnalyzer::visitScalarToTensorConversionOp(OpTy op) {
|
|||
return getLatticeElement(op.getResult()).join(knowledge);
|
||||
}
|
||||
|
||||
template <typename OpTy>
|
||||
ChangeResult TypeAnalyzer::visitBinaryScalarOp(OpTy op) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
|
||||
knowledge.dtype = getPromotedResultType({op.a().getType(), op.b().getType()});
|
||||
return getLatticeElement(op.getResult()).join(knowledge);
|
||||
}
|
||||
|
||||
// `torch.aten.tensor` get a tensor from a list. Each layer of the list
|
||||
// corresponds to one dim of the tensor.
|
||||
ChangeResult TypeAnalyzer::visitAtenTensorOp(AtenTensorOp op) {
|
||||
|
|
|
@ -74,3 +74,14 @@ func @torch.constant.int() -> !torch.int {
|
|||
%int1 = torch.constant.int 1
|
||||
return %int1 : !torch.int
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.add.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
|
||||
// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
|
||||
// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]]
|
||||
// CHECK: %[[INT:.*]] = arith.addi %[[LHS_I64:.*]], [[RHS_I64:.*]] : i64
|
||||
// CHECK: %[[INT:.*]] = torch_c.from_i64 %[[INT:.*]]
|
||||
// CHECK: return %[[INT:.*]] : !torch.int
|
||||
func @torch.aten.add.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
|
||||
%0 = torch.aten.add.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int
|
||||
return %0 : !torch.int
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue