diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 281637f15..4a234307d 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -4810,6 +4810,53 @@ def Torch_AtenPreluOp : Torch_Op<"aten.prelu", [
   }];
 }
 
+def Torch_AtenCeluOp : Torch_Op<"aten.celu", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::celu : (Tensor, Scalar) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchScalarType:$alpha
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenCeluOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenCeluOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
+def Torch_AtenCelu_Op : Torch_Op<"aten.celu_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::celu_ : (Tensor, Scalar) -> (Tensor)`";
+  let arguments = (ins
+    Torch_NonValueTensorType:$self,
+    AnyTorchScalarType:$alpha
+  );
+  let results = (outs
+    AnyTorchOptionalNonValueTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenCelu_Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenCelu_Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenRealOp : Torch_Op<"aten.real", [
     AllowsTypeRefinement,
     ReadOnly
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index faf3d0a25..553a8dc74 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6998,6 +6998,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.celu\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
@@ -10480,6 +10484,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.celu\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %none = torch.constant.none\n"
 "    %str = torch.constant.str \"AssertionError: \"\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 6cb02297d..677ccc4f2 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -2415,6 +2415,50 @@ public:
 } // namespace
 
+// CELU(x)=max(0,x)+min(0,alpha∗(exp(x/alpha)−1))
+namespace {
+class DecomposeAtenCeluOp : public OpRewritePattern<AtenCeluOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenCeluOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value input = op.getSelf();
+    Value alpha = op.getAlpha();
+    auto resType = cast<BaseTensorType>(op.getType());
+    if (!resType.hasDtype()) {
+      return rewriter.notifyMatchFailure(op, "result should have dtype");
+    }
+
+    Value constantZero =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
+    Value constantOne =
+        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
+
+    // positiveOutput = max(0,x)
+    Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero);
+    Value positiveOutput =
+        rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, input);
+
+    // negativeOutput = min(0,alpha∗(exp(x/alpha)−1))
+    Value scaledInput =
+        rewriter.create<AtenDivScalarOp>(loc, resType, input, alpha);
+    Value expX = rewriter.create<AtenExpOp>(loc, resType, scaledInput);
+    Value expXM1 = rewriter.create<AtenSubScalarOp>(loc, resType, expX,
+                                                    constantOne, constantOne);
+    Value scaledExpXM1 =
+        rewriter.create<AtenMulScalarOp>(loc, resType, expXM1, alpha);
+    Value negativeOutput =
+        rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, scaledExpXM1);
+    Value celuOutput = rewriter.create<AtenAddTensorOp>(
+        loc, resType, positiveOutput, negativeOutput, constantOne);
+
+    rewriter.replaceOp(op, celuOutput);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenLerpScalarOp : public OpRewritePattern<AtenLerpScalarOp> {
 public:
@@ -7705,6 +7749,7 @@ public:
   addPatternIfTargetOpIsIllegal(patterns);
   addPatternIfTargetOpIsIllegal(patterns);
   addPatternIfTargetOpIsIllegal(patterns);
+  addPatternIfTargetOpIsIllegal<DecomposeAtenCeluOp>(patterns);
   addPatternIfTargetOpIsIllegal(patterns);
   addPatternIfTargetOpIsIllegal(patterns);
   addPatternIfTargetOpIsIllegal(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index c5855a1fa..e7bed6463 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -474,6 +474,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<AtenCeluOp>();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 87344fb99..276cc47c1 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -951,6 +951,7 @@ STABLEHLO_PASS_SET = {
     "ElementwiseBitwiseRightShiftInt64Module_basic",
     "ElementwiseBitwiseRightShiftInt8Module_basic",
     "ElementwiseCeilModule_basic",
+    "ElementwiseCeluStaticModule_basic",
     "ElementwiseClampMaxModule_basic",
     "ElementwiseClampMinModule_basic",
     "ElementwiseClampMinTensorFloatModule_basic",
@@ -1571,6 +1572,8 @@ TOSA_PASS_SET = {
     "ElementwiseBitwiseXorModule_basic",
     "ElementwiseBitwiseXorStaticShapeModule_basic",
     "ElementwiseCeilModule_basic",
+    "ElementwiseCeluModule_basic",
+    "ElementwiseCeluStaticModule_basic",
     "ElementwiseClampMaxModule_basic",
     "ElementwiseClampMinModule_basic",
     "ElementwiseClampModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 6574c0bdc..da486fe46 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -526,6 +526,9 @@ def aten〇elu〡shape(self: List[int], alpha: float = 1, scale: float = 1, inpu
 def aten〇prelu〡shape(self: List[int], weight: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇celu〡shape(self: List[int], alpha: float = 1.) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇selu〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
@@ -2652,6 +2655,11 @@ def aten〇prelu〡dtype(self_rank_dtype: Tuple[int, int], weight_rank_dtype: Tu
     assert self_dtype == weight_dtype
     return self_dtype
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, alpha=1.))
+def aten〇celu〡dtype(self_rank_dtype: Tuple[int, int], alpha: Union[int, float, complex] = 1.) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool}))
 def aten〇relu6〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 3ebd00753..6e449c277 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -472,6 +472,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::floor_divide : (Tensor, Tensor) -> (Tensor)")
     emit("aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)")
     emit("aten::prelu : (Tensor, Tensor) -> (Tensor)")
+    emit_with_mutating_variants("aten::celu : (Tensor, Scalar) -> (Tensor)")
     emit("aten::real : (Tensor) -> (Tensor)")
     emit("aten::imag : (Tensor) -> (Tensor)")
     emit("aten::view_as_complex : (Tensor) -> (Tensor)")
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
index 47f4a6403..d034e6d1f 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -1016,6 +1016,52 @@ def ElementwisePreluStaticModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseCeluModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.celu(x, 0.5)
+
+
+@register_test_case(module_factory=lambda: ElementwiseCeluModule())
+def ElementwiseCeluModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 3, low=-1, high=1))
+
+
+# ==============================================================================
+
+
+class ElementwiseCeluStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([5, 3], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.celu(x)
+
+
+@register_test_case(module_factory=lambda: ElementwiseCeluStaticModule())
+def ElementwiseCeluStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 3, low=-1, high=1))
+
+
+# ==============================================================================
+
+
 class ElementwiseGeluModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
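
Note for reviewers (not part of the patch): the DecomposeAtenCeluOp pattern added above expands `aten.celu` into `aten.maximum` / `aten.minimum` / `aten.exp` plus scalar sub/mul, following CELU(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1)). A minimal eager-mode sketch, assuming only a stock PyTorch install, that cross-checks this formula against `torch.ops.aten.celu` for the alpha values used by the two new e2e tests:

# sanity_check_celu_decomposition.py (illustrative only, not part of this diff)
import torch

def celu_decomposed(x: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    # Same formula the decomposition emits:
    # CELU(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
    zero = torch.zeros_like(x)
    return torch.maximum(zero, x) + torch.minimum(
        zero, alpha * (torch.exp(x / alpha) - 1.0))

x = torch.rand(5, 3) * 2 - 1  # mirrors tu.rand(5, 3, low=-1, high=1)
for alpha in (0.5, 1.0):  # 0.5 from ElementwiseCeluModule, 1.0 is the default
    expected = torch.ops.aten.celu(x, alpha)
    assert torch.allclose(celu_decomposed(x, alpha), expected, atol=1e-6)

Sampling inputs in [-1, 1] in the e2e tests exercises both the positive branch (max) and the negative branch (min/exp) of the decomposition.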