[Torch] support AtenExp2Op (#3832)

- support AtenExp2Op by decomposing it to aten.pow.scalar
- refine stablehlo pow.scalar pow.Tensor_Scalar pow.Tensor_Tensor
lowering according to https://github.com/llvm/torch-mlir/pull/2983
- Close https://github.com/llvm/torch-mlir/pull/2983
pull/3842/head
yyp0 2024-10-31 19:14:05 +08:00 committed by GitHub
parent 4dd213b042
commit 9ce2a69703
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 153 additions and 95 deletions

View File

@ -996,6 +996,51 @@ def Torch_AtenExp_Op : Torch_Op<"aten.exp_", [
}];
}
def Torch_AtenExp2Op : Torch_Op<"aten.exp2", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::exp2 : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenExp2Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenExp2Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}
def Torch_AtenExp2_Op : Torch_Op<"aten.exp2_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::exp2_ : (Tensor) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self
);
let results = (outs
AnyTorchOptionalNonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenExp2_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenExp2_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}
def Torch_AtenExpm1Op : Torch_Op<"aten.expm1", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -931,79 +931,49 @@ LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(
return success();
}
// AtenPowTensorScalarOp
template <>
LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
AtenPowTensorScalarOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value lhs = adaptor.getSelf();
auto lhsType = dyn_cast<TensorType>(lhs.getType());
Value rhs = adaptor.getExponent();
TensorType rhsType = dyn_cast<TensorType>(rhs.getType());
namespace {
template <typename AtenOpT>
class ConvertAtenPowOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto outType = cast<TensorType>(
OpConversionPattern<AtenPowScalarOp>::getTypeConverter()->convertType(
op.getType()));
if (!lhsType)
return op.emitError("only Tensor types supported in StableHLO");
Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat()) {
return op.emitError(
"only floating-point or integer datatype legalization supported");
}
auto outType = cast<TensorType>(
OpConversionPattern<AtenPowTensorScalarOp>::getTypeConverter()
->convertType(op.getType()));
Value lhs = adaptor.getSelf();
auto lhsType = dyn_cast<TensorType>(lhs.getType());
Value rhs = adaptor.getExponent();
auto rhsType = dyn_cast<TensorType>(rhs.getType());
Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat()) {
return op.emitError(
"only floating-point or integer datatype legalization supported");
if (!lhsType && !rhsType) {
return op.emitError("only Tensor types supported in StableHLO");
}
if (!lhsType) {
lhs = hlo::scalarToStablehloTensor(rewriter, op, lhs, outElemTy);
}
if (!rhsType) {
rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy);
}
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outElemTy);
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy);
DenseI64ArrayAttr bcastDimensions;
rewriter.replaceOpWithNewOp<chlo::BroadcastPowOp>(op, outType, lhs, rhs,
bcastDimensions);
return success();
}
if (!rhsType) {
rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy);
}
DenseI64ArrayAttr bcastDimensions;
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outElemTy);
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy);
auto loc = op.getLoc();
Value result = rewriter.create<chlo::BroadcastPowOp>(loc, outType, lhs, rhs,
bcastDimensions);
rewriter.replaceOp(op, result);
return success();
}
// AtenPowScalarOp
template <>
LogicalResult ConvertAtenOp<AtenPowScalarOp>::matchAndRewrite(
AtenPowScalarOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value lhs = adaptor.getSelf();
auto lhsType = dyn_cast<TensorType>(lhs.getType());
Value rhs = adaptor.getExponent();
auto rhsType = dyn_cast<TensorType>(rhs.getType());
if (!rhsType)
return op.emitError("only Tensor types supported in StableHLO");
auto outType = cast<TensorType>(
OpConversionPattern<AtenPowScalarOp>::getTypeConverter()->convertType(
op.getType()));
Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat()) {
return op.emitError(
"only floating-point or integer datatype legalization supported");
}
if (!lhsType) {
lhs = hlo::scalarToStablehloTensor(rewriter, op, lhs, outElemTy);
}
DenseI64ArrayAttr bcastDimensions;
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outElemTy);
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy);
auto loc = op.getLoc();
Value result = rewriter.create<chlo::BroadcastPowOp>(loc, outType, lhs, rhs,
bcastDimensions);
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace
// PrimNumToTensorScalarOp
template <>
@ -1797,29 +1767,6 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
return success();
}
template <>
LogicalResult ConvertAtenOp<AtenPowTensorTensorOp>::matchAndRewrite(
AtenPowTensorTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value lhs = adaptor.getSelf();
auto lhsTy = cast<TensorType>(lhs.getType());
Value rhs = adaptor.getExponent();
auto rhsTy = cast<TensorType>(rhs.getType());
if (!lhsTy || !rhsTy)
return op.emitError("only Tensor types supported");
auto outTy =
cast<TensorType>(this->getTypeConverter()->convertType(op.getType()));
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy.getElementType());
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy.getElementType());
rewriter.replaceOpWithNewOp<chlo::BroadcastPowOp>(op, outTy, lhs, rhs,
/*broadcast_attr*/ nullptr);
return success();
}
// Converts `aten.empty.memory_format` to `tensor.empty` op.
template <>
LogicalResult ConvertAtenOp<AtenEmptyMemoryFormatOp>::matchAndRewrite(
@ -2250,6 +2197,14 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
#undef INSERT_BINARY_LOGICAL_PATTERN
#define INSERT_BINARY_POW_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenPowOp<AtenOp>>(typeConverter, context)
INSERT_BINARY_POW_PATTERN(AtenPowTensorScalarOp);
INSERT_BINARY_POW_PATTERN(AtenPowTensorTensorOp);
INSERT_BINARY_POW_PATTERN(AtenPowScalarOp);
#undef INSERT_BINARY_ADDSUB_PATTERN
#define INSERT_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
@ -2260,8 +2215,6 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
INSERT_ATENOP_PATTERN(AtenTensorIntOp);
INSERT_ATENOP_PATTERN(AtenReciprocalOp);
INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
INSERT_ATENOP_PATTERN(AtenPowScalarOp);
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
INSERT_ATENOP_PATTERN(AtenScalarImplicitOp);
INSERT_ATENOP_PATTERN(AtenContiguousOp);
@ -2285,7 +2238,6 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenSizeIntOp);
INSERT_ATENOP_PATTERN(AtenToDtypeOp);
INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
INSERT_ATENOP_PATTERN(AtenPowTensorTensorOp);
INSERT_ATENOP_PATTERN(AtenEmptyMemoryFormatOp);
INSERT_ATENOP_PATTERN(AtenFillScalarOp);

View File

@ -6487,6 +6487,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.exp2\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.expm1\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@ -11256,6 +11260,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.exp2\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"

View File

@ -9008,6 +9008,24 @@ class DecomposeAtenBinaryCrossEntropyWithLogitsOp
};
} // namespace
namespace {
class DecomposeAtenExp2Op : public OpRewritePattern<AtenExp2Op> {
using OpRewritePattern<AtenExp2Op>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenExp2Op op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.getSelf();
auto two =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(2));
rewriter.replaceOpWithNewOp<AtenPowScalarOp>(op, op.getType(), two, self);
return success();
}
};
} // namespace
namespace {
class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
using OpRewritePattern<AtenOneHotOp>::OpRewritePattern;
@ -10146,6 +10164,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposePrimTolistOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposePrimsSqueezeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMovedimIntOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenExp2Op>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBinaryCrossEntropyWithLogitsOp>(

View File

@ -2707,6 +2707,7 @@ ONNX_XFAIL_SET = {
"ElementwiseLog2IntModule_basic",
"ElementwiseFminModule_basic",
"ElementwiseFmaxModule_basic",
"Exp2StaticModule_basic",
"MultinomialModule2D_basic",
"MultinomialModule2D_F32",
"PixelShuffleModuleStaticRank4Float32_basic",

View File

@ -216,6 +216,9 @@ def atensilu〡shape(self: List[int]) -> List[int]:
def atenexp〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)
def atenexp2〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)
def atenexpm1〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)
@ -2567,6 +2570,11 @@ def atenexp〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
return _get_dtype_of_floating_point_op(self_dtype)
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def atenexp2〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
return _get_dtype_of_floating_point_op(self_dtype)
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def atenexpm1〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype

View File

@ -317,6 +317,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
"aten::asin : (Tensor) -> (Tensor)",
"aten::asinh : (Tensor) -> (Tensor)",
"aten::exp : (Tensor) -> (Tensor)",
"aten::exp2 : (Tensor) -> (Tensor)",
"aten::expm1 : (Tensor) -> (Tensor)",
"aten::cos : (Tensor) -> (Tensor)",
"aten::cosh : (Tensor) -> (Tensor)",

View File

@ -2881,6 +2881,29 @@ def ElementwiseSgnModule_basic(module, tu: TestUtils):
# ==============================================================================
class Exp2StaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([3, 2], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.exp2(x)
@register_test_case(module_factory=lambda: Exp2StaticModule())
def Exp2StaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 2))
# ==============================================================================
class ElementwisePowModule(torch.nn.Module):
def __init__(self):
super().__init__()