diff --git a/e2e_testing/torchscript/elementwise.py b/e2e_testing/torchscript/elementwise.py
index 0a052451f..d1d4cea6d 100644
--- a/e2e_testing/torchscript/elementwise.py
+++ b/e2e_testing/torchscript/elementwise.py
@@ -310,3 +310,36 @@ class ElementwiseClampModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: ElementwiseClampModule())
 def ElementwiseClampModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 5, low=-10, high=10))
+
+
+class RsubModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.rsub(x, 3.0, alpha=1.0)
+
+@register_test_case(module_factory=lambda: RsubModule())
+def RsubModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
+
+class RsubModule_noalpha(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.rsub(x, 2.0)
+
+@register_test_case(module_factory=lambda: RsubModule_noalpha())
+def RsubModule_noalpha_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
index 1bd7ac5df..7913a2a61 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
@@ -824,6 +824,22 @@ def Torch_AtenMinimumOp : Torch_Op<"aten.minimum", [
   let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
 }
 
+def Torch_AtenRsubScalarOp : Torch_Op<"aten.rsub.Scalar", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchScalarType:$other,
+    AnyTorchScalarType:$alpha
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self `,` $other `,` $alpha attr-dict `:` type($self) `,` type($other) `,` type($alpha) `->` type($result)";
+}
+
 def Torch_AtenGeluOp : Torch_Op<"aten.gelu", [
     AllowsTypeRefinement,
     HasValueSemantics
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index 37a4c5482..c3e4b7859 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -1438,6 +1438,20 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     }
     return result;
   }
+  if (auto rsub = dyn_cast<AtenRsubScalarOp>(op)) {
+    if (!rsub.getType()
+             .cast<ValueTensorType>()
+             .getDtype()
+             .isa<mlir::FloatType>()) {
+      rsub.emitError("unimplemented: non-floating point dtype");
+      return nullptr;
+    }
+    Value self = payloadArgs[0];
+    Value other = promoteScalarToDtype(b, loc, operands[1], self.getType());
+    Value alpha = promoteScalarToDtype(b, loc, operands[2], self.getType());
+    Value mult = b.create<arith::MulFOp>(loc, self, alpha);
+    return b.create<arith::SubFOp>(loc, other, mult);
+  }
   op->emitError("unimplemented lowering in "
                 "createLinalgPayloadCalculationForElementwiseOp");
   return nullptr;
 }
@@ -1647,7 +1661,7 @@ struct ConvertElementwiseOp : ConversionPattern {
-             AtenMaximumOp, AtenClampOp>(op))
+             AtenMaximumOp, AtenClampOp, AtenRsubScalarOp>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -2783,7 +2797,7 @@ public:
-                        AtenMaximumOp, AtenClampOp>();
+                        AtenMaximumOp, AtenClampOp, AtenRsubScalarOp>();
     patterns.add<ConvertElementwiseOp>(typeConverter, context);
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index eca847b8b..15352b2a1 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -230,7 +230,7 @@ public:
             DerefineOp, AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp,
             AtenFill_ScalarOp, AtenDetachOp, AtenMaskedFill_ScalarOp,
             AtenCopy_Op, AtenIndexPut_Op, AtenCopy_Op, AtenCumsumOp,
-            AtenLayerNormOp, AtenClampOp>(op)) {
+            AtenLayerNormOp, AtenClampOp, AtenRsubScalarOp>(op)) {
       return getLatticeElement(op->getResult(0)).join(*operands[0]);
     }
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 4ca8c9353..88339861d 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -469,8 +469,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
     # variants.
     emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
     emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
-
-
+    emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)")
    emit("aten::gelu : (Tensor) -> (Tensor)")
    emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")
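
Note on the new lowering: `aten.rsub.Scalar` computes `other - alpha * self`, which is what the payload added above implements (promote the two scalars to the tensor dtype, multiply `self` by `alpha`, subtract the product from `other`). The snippet below is an illustrative check of that identity and is not part of the patch; the helper name `check_rsub_semantics` is made up, and it assumes stock PyTorch behavior for `torch.rsub`.

```python
# Illustrative only: the identity the RsubModule e2e tests rely on,
# torch.rsub(x, other, alpha) == other - alpha * x.
import torch

def check_rsub_semantics():
    x = torch.rand(3, 4)
    # Mirrors RsubModule.forward: explicit alpha.
    assert torch.allclose(torch.rsub(x, 3.0, alpha=1.0), 3.0 - 1.0 * x)
    # Mirrors RsubModule_noalpha.forward: alpha defaults to 1.
    assert torch.allclose(torch.rsub(x, 2.0), 2.0 - x)

if __name__ == "__main__":
    check_rsub_semantics()
    print("aten.rsub.Scalar semantics check passed")
```

`RsubModule_basic` and `RsubModule_noalpha_basic` exercise the same two call shapes end to end through the torch-mlir pipeline.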