diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 85c47c0c9..b40eaaafd 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -565,6 +565,51 @@ def Torch_AtenExp_Op : Torch_Op<"aten.exp_", [
   }];
 }
 
+def Torch_AtenExpm1Op : Torch_Op<"aten.expm1", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::expm1 : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenExpm1Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenExpm1Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
+def Torch_AtenExpm1_Op : Torch_Op<"aten.expm1_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::expm1_ : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenExpm1_Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenExpm1_Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_AtenCosOp : Torch_Op<"aten.cos", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index 5efe5e185..b982ca54b 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -131,6 +131,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     return createCalculationForMathOpWithDtypeConversion<math::ExpOp>(
         b, converter, payloadArgs[0], op);
   }
+  if (isa<AtenExpm1Op>(op)) {
+    return createCalculationForMathOpWithDtypeConversion<math::ExpM1Op>(
+        b, converter, payloadArgs[0], op);
+  }
   if (isa(op)) {
     return createCalculationForMathOpWithDtypeConversion(
         b, converter, payloadArgs[0], op);
   }
@@ -923,15 +927,15 @@ public:
     if (!isa(op))
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 8f08621f6..1eff5b7ca 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -663,9 +663,9 @@ void TypeAnalysis::visitOperation(Operation *op,
   }
 
   // Dtype is always float32, except for bfloat16, float64 and nullptr.
-  if (isa(op)) {
+  if (isa(op)) {
     ValueKnowledge knowledge =
         ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     Type dtype = operands[0]->getValue().dtype;
diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
index b060c8744..ad5274493 100644
--- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -5337,6 +5337,10 @@ module {
     %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
     return %0 : !torch.list<int>
   }
+  func.func @"__torch_mlir_shape_fn.aten.expm1"(%arg0: !torch.list<int>) -> !torch.list<int> {
+    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
+    return %0 : !torch.list<int>
+  }
   func.func @"__torch_mlir_shape_fn.aten.sin"(%arg0: !torch.list<int>) -> !torch.list<int> {
     %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
     return %0 : !torch.list<int>
   }
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
index 0009f3d30..1aa13ef5e 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
@@ -315,6 +315,9 @@ def aten〇silu(self: List[int]) -> List[int]:
 def aten〇exp(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇expm1(self: List[int]) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇sin(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index d3d4784c8..48ee9722f 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -250,6 +250,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
             "aten::silu : (Tensor) -> (Tensor)",
             "aten::sin : (Tensor) -> (Tensor)",
             "aten::exp : (Tensor) -> (Tensor)",
+            "aten::expm1 : (Tensor) -> (Tensor)",
             "aten::cos : (Tensor) -> (Tensor)",
             "aten::neg : (Tensor) -> (Tensor)",
             "aten::floor : (Tensor) -> (Tensor)",
diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py
index f6df3f66a..f13a0e172 100644
--- a/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -1528,6 +1528,50 @@ def ElementwiseExpIntModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
 
+class ElementwiseExpm1Module(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.special.expm1(a)
+
+
+@register_test_case(module_factory=lambda: ElementwiseExpm1Module())
+def ElementwiseExpm1Module_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
+
+
+# ==============================================================================
+
+
+class ElementwiseExpm1IntModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int32, True),
+    ])
+    def forward(self, a):
+        return torch.special.expm1(a)
+
+
+@register_test_case(module_factory=lambda: ElementwiseExpm1IntModule())
+def ElementwiseExpm1IntModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
+
+
+# ==============================================================================
+
+
 class ElementwiseSinModule(torch.nn.Module):
 
     def __init__(self):