From 0be15605de152ccaf73966f73e3b386e8eb75b64 Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Fri, 15 Nov 2024 11:19:10 +0000
Subject: [PATCH] [MLIR][TORCH] Add aten.special.expm1 op lowering

This commit adds support for the torch.aten.special.expm1 op by
decomposing it into the torch.aten.expm1 op.

Signed-off-by: Vivek Khandelwal
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     | 23 +++++++++
 .../Transforms/AbstractInterpLibrary.cpp      |  9 ++++
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 14 ++++++
 .../Transforms/LowerToBackendContract.cpp     |  1 +
 projects/pt1/e2e_testing/xfail_sets.py        | 12 +++--
 .../build_tools/abstract_interp_lib_gen.py    |  8 +++
 .../build_tools/torch_ods_gen.py              |  1 +
 .../test_suite/elementwise.py                 | 49 ++++++++++++++++++-
 8 files changed, 111 insertions(+), 6 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index c5b491197..93d72763d 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -4610,6 +4610,29 @@ def Torch_AtenTrunc_Op : Torch_Op<"aten.trunc_", [
   }];
 }
 
+def Torch_AtenSpecialExpm1Op : Torch_Op<"aten.special_expm1", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::special_expm1 : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenSpecialExpm1Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenSpecialExpm1Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_AtenSignOp : Torch_Op<"aten.sign", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 7a0a24a28..53a891fbb 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6495,6 +6495,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.special_expm1\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.isfinite\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"
@@ -11435,6 +11439,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
 "    return %1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.special_expm1\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
+"    return %1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.isfinite\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %int11 = torch.constant.int 11\n"
 "    return %int11 : !torch.int\n"
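Note (commentary between the diffs, not part of the patch): the shape and dtype
functions above encode special_expm1 as a unary op whose result shape matches its
input and whose result dtype follows the usual floating-point promotion. A minimal
eager-mode sketch of that behavior in plain PyTorch, independent of torch-mlir:

    import torch

    # Unary shape function: the result shape equals the input shape.
    x = torch.ones(2, 3, dtype=torch.int64)
    y = torch.special.expm1(x)
    assert y.shape == x.shape
    # Dtype function: integer inputs promote to the default floating-point
    # dtype (float32 unless the default has been overridden).
    assert y.dtype == torch.get_default_dtype()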
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 9db8a6949..0f88242ba 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -10501,6 +10501,19 @@ public:
 };
 } // namespace
 
+namespace {
+class DecomposeAtenSpecialExpm1Op
+    : public OpRewritePattern<AtenSpecialExpm1Op> {
+public:
+  using OpRewritePattern<AtenSpecialExpm1Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenSpecialExpm1Op op,
+                                PatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<AtenExpm1Op>(op, op.getType(), op.getSelf());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -10776,6 +10789,7 @@ public:
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(
        patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenSpecialExpm1Op>(patterns);
     addPatternIfTargetOpIsIllegal<
         DecomposeAtenFMaxMinOp>(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index 4bca74470..6818b5d2a 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -566,6 +566,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<AtenSpecialExpm1Op>();
 
   for (auto &opName : backendLegalOpsSet) {
     target.addLegalOp(
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index b5d02034c..b49caf5f8 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -501,8 +501,6 @@ FX_IMPORTER_XFAIL_SET = {
     "AdaptiveMaxPool1dStatic_basic",
     "CrossEntropyLossModule_basic",
     "CrossEntropyLossNoReductionModule_basic",
-    "ElementwiseExpm1IntModule_basic",
-    "ElementwiseExpm1Module_basic",
     "IndexPutImpl1DFloatAccumulateModule_basic",
     "IndexPutImpl1DFloatNonAccumulateModule_basic",
     "IndexPutImpl1DIntAccumulateModule_basic",
@@ -928,8 +926,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "AtenItemIntOpModule_basic",
     "CrossEntropyLossModule_basic",
     "CrossEntropyLossNoReductionModule_basic",
-    "ElementwiseExpm1IntModule_basic",
-    "ElementwiseExpm1Module_basic",
     "InterpolateDynamicModule_sizes_nearest",
     "IouOfModule_basic",
     "IscloseStaticModuleTrue_basic",
@@ -1226,6 +1222,8 @@ STABLEHLO_PASS_SET = {
     "ElementwiseRsqrtModule_basic",
     "ElementwiseSigmoidModule_basic",
     "ElementwiseSinModule_basic",
+    "ElementwiseSpecialExpm1IntModule_basic",
+    "ElementwiseSpecialExpm1Module_basic",
     "ElementwiseSqrtModule_basic",
     "ElementwiseTanIntModule_basic",
     "ElementwiseTanModule_basic",
@@ -2913,6 +2911,8 @@ ONNX_XFAIL_SET = {
     "ElementwiseEluNonDefaultModule_basic",
     "ElementwiseExpm1IntModule_basic",
     "ElementwiseExpm1Module_basic",
+    "ElementwiseSpecialExpm1IntModule_basic",
+    "ElementwiseSpecialExpm1Module_basic",
     "ElementwiseFmodTensor_Int_basic",
     "ElementwiseCreateComplexModule_basic",
     "ElementwiseMulTensorComplexModule_basic",
@@ -3630,6 +3630,8 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "ElementwiseSigmoidIntModule_basic",
     "ElementwiseSinhIntModule_basic",
     "ElementwiseSinhModule_basic",
+    "ElementwiseSpecialExpm1IntModule_basic",
+    "ElementwiseSpecialExpm1Module_basic",
     "ElementwiseTanIntModule_basic",
     "ElementwiseTanModule_basic",
     "ElementwiseToDtypeF32ToI64Module_basic",
@@ -4342,6 +4344,8 @@ ONNX_TOSA_XFAIL_SET = {
     "ElementwiseSinIntModule_basic",
     "ElementwiseSinhIntModule_basic",
     "ElementwiseSinhModule_basic",
+    "ElementwiseSpecialExpm1IntModule_basic",
+    "ElementwiseSpecialExpm1Module_basic",
     "ElementwiseSqrtIntModule_basic",
     "ElementwiseSubScalarIntModule_basic",
     "ElementwiseTanIntModule_basic",
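Note (commentary between the diffs, not part of the patch): the decomposition above
is a pure aliasing rewrite. PyTorch documents torch.special.expm1 as an alias for
torch.expm1, so replacing one op with the other preserves semantics exactly. A
minimal eager-mode check of that equivalence in plain PyTorch:

    import torch

    # special.expm1 is an alias of expm1, so the results are bitwise identical.
    x = torch.randn(3, 4)
    assert torch.equal(torch.special.expm1(x), torch.expm1(x))

Because the new op is always decomposed, LowerToBackendContract marks it illegal so
any aten.special_expm1 surviving decomposition fails the backend contract.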
"ElementwiseSpecialExpm1Module_basic", "ElementwiseSqrtIntModule_basic", "ElementwiseSubScalarIntModule_basic", "ElementwiseTanIntModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index e78b3d49d..77a5a90a2 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -222,6 +222,9 @@ def aten〇exp2〡shape(self: List[int]) -> List[int]: def aten〇expm1〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇special_expm1〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇isfinite〡shape(self: List[int]) -> List[int]: return self @@ -2656,6 +2659,11 @@ def aten〇expm1〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return _get_dtype_of_floating_point_op(self_dtype) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇special_expm1〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + def aten〇isfinite〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return torch.bool diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 371b73347..7dedffc16 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -452,6 +452,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit_with_mutating_variants("aten::ceil : (Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::trunc : (Tensor) -> (Tensor)", has_folder=True) + emit("aten::special_expm1 : (Tensor) -> (Tensor)") emit_with_mutating_variants( "aten::sign : (Tensor) -> (Tensor)", has_canonicalizer=True ) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index a6679ec4d..e767d90ef 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -5207,7 +5207,7 @@ class ElementwiseExpm1Module(torch.nn.Module): ] ) def forward(self, a): - return torch.special.expm1(a) + return torch.expm1(a) @register_test_case(module_factory=lambda: ElementwiseExpm1Module()) @@ -5230,7 +5230,7 @@ class ElementwiseExpm1IntModule(torch.nn.Module): ] ) def forward(self, a): - return torch.special.expm1(a) + return torch.expm1(a) @register_test_case(module_factory=lambda: ElementwiseExpm1IntModule()) @@ -5241,6 +5241,51 @@ def ElementwiseExpm1IntModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseSpecialExpm1Module(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, a): + return torch.special.expm1(a) + + +@register_test_case(module_factory=lambda: ElementwiseSpecialExpm1Module()) +def ElementwiseSpecialExpm1Module_basic(module, tu: 
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
index a6679ec4d..e767d90ef 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -5207,7 +5207,7 @@ class ElementwiseExpm1Module(torch.nn.Module):
         ]
     )
     def forward(self, a):
-        return torch.special.expm1(a)
+        return torch.expm1(a)
 
 
 @register_test_case(module_factory=lambda: ElementwiseExpm1Module())
@@ -5230,7 +5230,7 @@ class ElementwiseExpm1IntModule(torch.nn.Module):
         ]
     )
     def forward(self, a):
-        return torch.special.expm1(a)
+        return torch.expm1(a)
 
 
 @register_test_case(module_factory=lambda: ElementwiseExpm1IntModule())
 def ElementwiseExpm1IntModule_basic(module, tu: TestUtils):
@@ -5241,6 +5241,51 @@ def ElementwiseExpm1IntModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseSpecialExpm1Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, a):
+        return torch.special.expm1(a)
+
+
+@register_test_case(module_factory=lambda: ElementwiseSpecialExpm1Module())
+def ElementwiseSpecialExpm1Module_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
+
+
+# ==============================================================================
+
+
+class ElementwiseSpecialExpm1IntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.int32, True),
+        ]
+    )
+    def forward(self, a):
+        return torch.special.expm1(a)
+
+
+@register_test_case(module_factory=lambda: ElementwiseSpecialExpm1IntModule())
+def ElementwiseSpecialExpm1IntModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32))
+
+# ==============================================================================
+
+
 class ElementwiseRad2DegModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
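Note (commentary after the patch, not part of it): the two new test classes mirror
the existing Expm1 tests one-for-one; the only difference is that they call the
torch.special.expm1 entry point, so the new decomposition path is covered end to
end, while the existing Expm1 tests are switched to torch.expm1 to keep covering
the base op directly. A standalone eager-mode approximation of what these cases
check (plain PyTorch; the real harness compares compiled output against eager
golden values):

    import torch

    # Float case: inputs in [0, 1), matching tu.rand(3, 4).
    float_in = torch.rand(3, 4)
    assert torch.allclose(torch.special.expm1(float_in), torch.expm1(float_in))

    # Int case: inputs in [1, 10), matching tu.randint(3, 4, low=1, high=10).
    int_in = torch.randint(1, 10, (3, 4), dtype=torch.int32)
    assert torch.special.expm1(int_in).dtype == torch.get_default_dtype()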