From 9ce2a697034c51715c22a19b88209480f36fc976 Mon Sep 17 00:00:00 2001 From: yyp0 Date: Thu, 31 Oct 2024 19:14:05 +0800 Subject: [PATCH] [Torch] support AtenExp2Op (#3832) - support AtenExp2Op by decomposing it to aten.pow.scalar - refine stablehlo pow.scalar pow.Tensor_Scalar pow.Tensor_Tensor lowering according to https://github.com/llvm/torch-mlir/pull/2983 - Close https://github.com/llvm/torch-mlir/pull/2983 --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 45 ++++++ lib/Conversion/TorchToStablehlo/Basic.cpp | 142 ++++++------------ .../Transforms/AbstractInterpLibrary.cpp | 9 ++ .../Torch/Transforms/DecomposeComplexOps.cpp | 19 +++ projects/pt1/e2e_testing/xfail_sets.py | 1 + .../build_tools/abstract_interp_lib_gen.py | 8 + .../build_tools/torch_ods_gen.py | 1 + .../test_suite/elementwise.py | 23 +++ 8 files changed, 153 insertions(+), 95 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 5ec6a4d1d..199003e72 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -996,6 +996,51 @@ def Torch_AtenExp_Op : Torch_Op<"aten.exp_", [ }]; } +def Torch_AtenExp2Op : Torch_Op<"aten.exp2", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::exp2 : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenExp2Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenExp2Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenExp2_Op : Torch_Op<"aten.exp2_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::exp2_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + AnyTorchOptionalNonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenExp2_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenExp2_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenExpm1Op : Torch_Op<"aten.expm1", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index ab4e284f8..4f521fb9e 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -931,79 +931,49 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } -// AtenPowTensorScalarOp -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenPowTensorScalarOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value lhs = adaptor.getSelf(); - auto lhsType = dyn_cast(lhs.getType()); - Value rhs = adaptor.getExponent(); - TensorType rhsType = dyn_cast(rhs.getType()); +namespace { +template +class ConvertAtenPowOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto outType = cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); - if (!lhsType) - return op.emitError("only Tensor types supported in StableHLO"); + Type outElemTy = outType.getElementType(); + if (!outElemTy.isIntOrFloat()) { + return op.emitError( + "only floating-point or integer datatype legalization supported"); + } - auto outType = cast( - OpConversionPattern::getTypeConverter() - ->convertType(op.getType())); + Value lhs = adaptor.getSelf(); + auto lhsType = dyn_cast(lhs.getType()); + Value rhs = adaptor.getExponent(); + auto rhsType = dyn_cast(rhs.getType()); - Type outElemTy = outType.getElementType(); - if (!outElemTy.isIntOrFloat()) { - return op.emitError( - "only floating-point or integer datatype legalization supported"); + if (!lhsType && !rhsType) { + return op.emitError("only Tensor types supported in StableHLO"); + } + if (!lhsType) { + lhs = hlo::scalarToStablehloTensor(rewriter, op, lhs, outElemTy); + } + if (!rhsType) { + rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy); + } + + lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outElemTy); + rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy); + DenseI64ArrayAttr bcastDimensions; + rewriter.replaceOpWithNewOp(op, outType, lhs, rhs, + bcastDimensions); + return success(); } - - if (!rhsType) { - rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy); - } - DenseI64ArrayAttr bcastDimensions; - lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outElemTy); - rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy); - auto loc = op.getLoc(); - Value result = rewriter.create(loc, outType, lhs, rhs, - bcastDimensions); - - rewriter.replaceOp(op, result); - return success(); -} - -// AtenPowScalarOp -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenPowScalarOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value lhs = adaptor.getSelf(); - auto lhsType = dyn_cast(lhs.getType()); - Value rhs = adaptor.getExponent(); - auto rhsType = dyn_cast(rhs.getType()); - - if (!rhsType) - return op.emitError("only Tensor types supported in StableHLO"); - - auto outType = cast( - OpConversionPattern::getTypeConverter()->convertType( - op.getType())); - - Type outElemTy = outType.getElementType(); - if (!outElemTy.isIntOrFloat()) { - return op.emitError( - "only floating-point or integer datatype legalization supported"); - } - - if (!lhsType) { - lhs = hlo::scalarToStablehloTensor(rewriter, op, lhs, outElemTy); - } - DenseI64ArrayAttr bcastDimensions; - lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outElemTy); - rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy); - auto loc = op.getLoc(); - Value result = rewriter.create(loc, outType, lhs, rhs, - bcastDimensions); - - rewriter.replaceOp(op, result); - return success(); -} +}; +} // namespace // PrimNumToTensorScalarOp template <> @@ -1797,29 +1767,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenPowTensorTensorOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value lhs = adaptor.getSelf(); - auto lhsTy = cast(lhs.getType()); - Value rhs = adaptor.getExponent(); - auto rhsTy = cast(rhs.getType()); - - if (!lhsTy || !rhsTy) - return op.emitError("only Tensor types supported"); - - auto outTy = - cast(this->getTypeConverter()->convertType(op.getType())); - - lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy.getElementType()); - rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy.getElementType()); - - rewriter.replaceOpWithNewOp(op, outTy, lhs, rhs, - /*broadcast_attr*/ nullptr); - return success(); -} - // Converts `aten.empty.memory_format` to `tensor.empty` op. template <> LogicalResult ConvertAtenOp::matchAndRewrite( @@ -2250,6 +2197,14 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( #undef INSERT_BINARY_LOGICAL_PATTERN +#define INSERT_BINARY_POW_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context) + INSERT_BINARY_POW_PATTERN(AtenPowTensorScalarOp); + INSERT_BINARY_POW_PATTERN(AtenPowTensorTensorOp); + INSERT_BINARY_POW_PATTERN(AtenPowScalarOp); +#undef INSERT_BINARY_ADDSUB_PATTERN + #define INSERT_ATENOP_PATTERN(AtenOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context, options) @@ -2260,8 +2215,6 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_ATENOP_PATTERN(ValueTensorLiteralOp); INSERT_ATENOP_PATTERN(AtenTensorIntOp); INSERT_ATENOP_PATTERN(AtenReciprocalOp); - INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp); - INSERT_ATENOP_PATTERN(AtenPowScalarOp); INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp); INSERT_ATENOP_PATTERN(AtenScalarImplicitOp); INSERT_ATENOP_PATTERN(AtenContiguousOp); @@ -2285,7 +2238,6 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_ATENOP_PATTERN(AtenSizeIntOp); INSERT_ATENOP_PATTERN(AtenToDtypeOp); INSERT_ATENOP_PATTERN(AtenWhereSelfOp); - INSERT_ATENOP_PATTERN(AtenPowTensorTensorOp); INSERT_ATENOP_PATTERN(AtenEmptyMemoryFormatOp); INSERT_ATENOP_PATTERN(AtenFillScalarOp); diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 1765786be..b978c3472 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6487,6 +6487,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.exp2\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.expm1\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -11256,6 +11260,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" " return %1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.exp2\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 1fefb59a4..9006f1660 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -9008,6 +9008,24 @@ class DecomposeAtenBinaryCrossEntropyWithLogitsOp }; } // namespace +namespace { +class DecomposeAtenExp2Op : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenExp2Op op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + + auto two = + rewriter.create(loc, rewriter.getI64IntegerAttr(2)); + rewriter.replaceOpWithNewOp(op, op.getType(), two, self); + + return success(); + } +}; + +} // namespace + namespace { class DecomposeAtenOneHotOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -10146,6 +10164,7 @@ public: addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal( diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 854c2d871..d4f470ab4 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2707,6 +2707,7 @@ ONNX_XFAIL_SET = { "ElementwiseLog2IntModule_basic", "ElementwiseFminModule_basic", "ElementwiseFmaxModule_basic", + "Exp2StaticModule_basic", "MultinomialModule2D_basic", "MultinomialModule2D_F32", "PixelShuffleModuleStaticRank4Float32_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 36ab8fe2c..1bb4266d5 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -216,6 +216,9 @@ def aten〇silu〡shape(self: List[int]) -> List[int]: def aten〇exp〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇exp2〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇expm1〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -2567,6 +2570,11 @@ def aten〇exp〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return _get_dtype_of_floating_point_op(self_dtype) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇exp2〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇expm1〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 311636c82..5f614de59 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -317,6 +317,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): "aten::asin : (Tensor) -> (Tensor)", "aten::asinh : (Tensor) -> (Tensor)", "aten::exp : (Tensor) -> (Tensor)", + "aten::exp2 : (Tensor) -> (Tensor)", "aten::expm1 : (Tensor) -> (Tensor)", "aten::cos : (Tensor) -> (Tensor)", "aten::cosh : (Tensor) -> (Tensor)", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index a62b901a9..e9098698f 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -2881,6 +2881,29 @@ def ElementwiseSgnModule_basic(module, tu: TestUtils): # ============================================================================== +class Exp2StaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([3, 2], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.exp2(x) + + +@register_test_case(module_factory=lambda: Exp2StaticModule()) +def Exp2StaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 2)) + + +# ============================================================================== + + class ElementwisePowModule(torch.nn.Module): def __init__(self): super().__init__()