mirror of https://github.com/llvm/torch-mlir
add aten.add.int lowering in TorchToStd
parent 7616d28ce1
commit 03fdf56f21
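This commit adds a lowering of `torch.aten.add.int` to `arith.addi` in the TorchToStd conversion, along with the supporting pieces: a new `scalar` e2e test module, a `ConvertAtenAddIntOp` conversion pattern, type-analysis support for binary scalar ops, and a lit test. The first hunk registers the new test module with the TorchScript e2e test driver: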
@@ -40,6 +40,7 @@ from . import reduction
 from . import argmax
 from . import matmul
 from . import view
+from . import scalar

 def _get_argparse():
     config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
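The test module is a new file defining `AddIntModule`, whose `forward` adds two scalar `int64` inputs: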
@@ -0,0 +1,29 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+import torch
+
+from torch_mlir_e2e_test.torchscript.framework import TestUtils
+from torch_mlir_e2e_test.torchscript.registry import register_test_case
+from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+
+
+class AddIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([], torch.int64, True),
+        ([], torch.int64, True),
+    ])
+    def forward(self, lhs, rhs):
+        return int(lhs) + int(rhs)
+
+
+@register_test_case(module_factory=lambda: AddIntModule())
+def AddIntModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
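The `([], torch.int64, True)` entries annotate each argument as a rank-0 `int64` value, and `int(lhs) + int(rhs)` performs the addition on Python ints, so the compiled program should contain `aten.add.int` rather than a tensor add. The lowering pattern for that op is a one-line rewrite to `arith::AddIOp`: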
@@ -44,6 +44,19 @@ public:
 };
 } // namespace

+namespace {
+class ConvertAtenAddIntOp : public OpConversionPattern<AtenAddIntOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenAddIntOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<arith::AddIOp>(op, adaptor.a(), adaptor.b());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertAtenNeIntOp : public OpConversionPattern<AtenNeIntOp> {
 public:
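To make the effect concrete, here is a rough before/after sketch of the rewrite, mirroring the lit test added at the end of this commit (the `torch_c.to_i64`/`torch_c.from_i64` casts are materialized by the type converter; `@f` is just an illustrative name):

// Before TorchToStd:
func @f(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
  %0 = torch.aten.add.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int
  return %0 : !torch.int
}

// After TorchToStd: the !torch.int operands are bridged to i64,
// added with arith.addi, and the sum is bridged back to !torch.int.
func @f(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
  %0 = torch_c.to_i64 %arg0
  %1 = torch_c.to_i64 %arg1
  %2 = arith.addi %0, %1 : i64
  %3 = torch_c.from_i64 %2
  return %3 : !torch.int
}

The pattern and illegality target are then registered with the conversion: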
@@ -129,6 +142,8 @@ public:
     target.addIllegalOp<Torch::ConstantIntOp>();
     patterns.add<ConvertTorchConstantOp<Torch::ConstantIntOp>>(typeConverter,
                                                                context);
+    target.addIllegalOp<AtenAddIntOp>();
+    patterns.add<ConvertAtenAddIntOp>(typeConverter, context);
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       return signalPassFailure();
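The type analysis must also account for the new op. First, `AtenAddIntOp` joins the group of ops whose result lattice state is joined from their first operand: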
@@ -232,7 +232,8 @@ public:
             AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op, AtenCumsumOp,
             AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
             AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp,
-            AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp>(op)) {
+            AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp>(
+                op)) {
       return getLatticeElement(op->getResult(0)).join(*operands[0]);
     }

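In the visitor's else-if dispatch chain, `AtenAddIntOp` is routed to a new `visitBinaryScalarOp` helper: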
@@ -426,6 +427,8 @@ public:
       return visitNumToTensorOp(numToTensorOp);
     } else if (isa<AtenAddCMulOp, AtenAddCDivOp>(op)) {
      return visitAtenAddCLikeOp(op, operands);
+    } else if (auto scalarOp = dyn_cast<AtenAddIntOp>(op)) {
+      return visitBinaryScalarOp(scalarOp);
     }

     // Otherwise, this is an unknown operation. Just mark all results as
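The helper's declaration sits alongside the other visit methods: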
@@ -528,6 +531,8 @@ private:
   ChangeResult
   visitAtenEmbeddingOp(AtenEmbeddingOp op,
                        ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+  template <typename OpTy> ChangeResult visitBinaryScalarOp(OpTy op);
+
   ChangeResult
   visitAtenBmmOp(AtenBmmOp op,
                  ArrayRef<LatticeElement<ValueKnowledge> *> operands);
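A new `getPromotedResultType` overload computes the promoted result type of a list of scalar types by folding them through `updateResultTypeState`: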
@@ -587,6 +592,13 @@ static ResultTypeState updateResultTypeState(ValueKnowledge *tensor,
   return new_state;
 }

+static Type getPromotedResultType(ArrayRef<Type> scalarTypes) {
+  ResultTypeState state = {};
+  for (const Type &scalarType : scalarTypes)
+    state = updateResultTypeState(scalarType, state);
+  return getTypeForScalarType(scalarTypes[0].getContext(), result_type(state));
+}
+
 // Returns most generic type Type() if the tensor dtype is unknown.
 static Type getPromotedResultType(ValueKnowledge *tensor, Type scalarType) {
   if (!tensor->dtype)
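The definition of `visitBinaryScalarOp` then sets the result dtype to the promoted type of the two operand types; for two `!torch.int` operands this works out to `i64`, which is what the `arith.addi` in the lit test below operates on: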
@@ -1086,6 +1098,14 @@ ChangeResult TypeAnalyzer::visitScalarToTensorConversionOp(OpTy op) {
   return getLatticeElement(op.getResult()).join(knowledge);
 }

+template <typename OpTy>
+ChangeResult TypeAnalyzer::visitBinaryScalarOp(OpTy op) {
+  auto knowledge =
+      ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
+  knowledge.dtype = getPromotedResultType({op.a().getType(), op.b().getType()});
+  return getLatticeElement(op.getResult()).join(knowledge);
+}
+
 // `torch.aten.tensor` get a tensor from a list. Each layer of the list
 // corresponds to one dim of the tensor.
 ChangeResult TypeAnalyzer::visitAtenTensorOp(AtenTensorOp op) {
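Finally, a lit test verifies the lowered IR: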
@@ -74,3 +74,14 @@ func @torch.constant.int() -> !torch.int {
   %int1 = torch.constant.int 1
   return %int1 : !torch.int
 }
+
+// CHECK-LABEL:   func @torch.aten.add.int(
+// CHECK-SAME:        %[[LHS:.*]]: !torch.int, %[[RHS:.*]]: !torch.int) -> !torch.int {
+// CHECK:           %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
+// CHECK:           %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]]
+// CHECK:           %[[ADD:.*]] = arith.addi %[[LHS_I64]], %[[RHS_I64]] : i64
+// CHECK:           %[[RET:.*]] = torch_c.from_i64 %[[ADD]]
+// CHECK:           return %[[RET]] : !torch.int
+func @torch.aten.add.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
+  %0 = torch.aten.add.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int
+  return %0 : !torch.int
+}