From a778f990e97de355af4be33d92dbb8cf77cffa19 Mon Sep 17 00:00:00 2001
From: Gaurav Shukla
Date: Sat, 11 Dec 2021 22:50:24 +0530
Subject: [PATCH] [TORCH][MLIR] Add E2E support for `aten.ceil` op

This commit adds lowering of `aten.ceil` op as a part of element-wise
ops lowering.

Signed-Off-by: Gaurav Shukla
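
For reference, a small plain-PyTorch sketch of the element-wise behavior the
new `ElementwiseCeilModule` e2e test exercises (illustrative only; it does not
go through the torch-mlir pipeline, and the tensor values are made up):

    import torch

    # torch.ceil rounds every element up to the nearest integer while keeping
    # the floating-point dtype; this is the per-element semantics the new
    # lowering maps to a math.ceil payload on the linalg side.
    a = torch.tensor([[0.1, -1.7, 2.0],
                      [3.5, -2.2, 4.9]], dtype=torch.float32)
    print(torch.ceil(a))
    # tensor([[ 1., -1.,  2.],
    #         [ 4., -2.,  5.]])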
---
 e2e_testing/torchscript/elementwise.py        | 16 +++++++++++
 .../Dialect/Torch/IR/GeneratedAtenOps.td      | 28 +++++++++++++++++++
 .../TorchToLinalg/TorchToLinalg.cpp           |  9 ++++--
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  |  2 +-
 .../jit_ir/build_tools/torch_ods_gen.py       |  1 +
 5 files changed, 52 insertions(+), 4 deletions(-)

diff --git a/e2e_testing/torchscript/elementwise.py b/e2e_testing/torchscript/elementwise.py
index 9bb5ac7c1..f297d3fdc 100644
--- a/e2e_testing/torchscript/elementwise.py
+++ b/e2e_testing/torchscript/elementwise.py
@@ -521,6 +521,22 @@ class ElementwiseFloorModule(torch.nn.Module):
 def ElementwiseFloorModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
+class ElementwiseCeilModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.ceil(a)
+
+@register_test_case(module_factory=lambda: ElementwiseCeilModule())
+def ElementwiseCeilModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
+
 class ElementwisePowModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
index 9c8220b8b..55d670491 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
@@ -298,6 +298,34 @@ def Torch_AtenFloor_Op : Torch_Op<"aten.floor_", [
   let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
 }
 
+def Torch_AtenCeilOp : Torch_Op<"aten.ceil", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::ceil : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
+}
+
+def Torch_AtenCeil_Op : Torch_Op<"aten.ceil_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::ceil_ : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
+}
+
 def Torch_AtenBitwiseNotOp : Torch_Op<"aten.bitwise_not", [
     AllowsTypeRefinement,
     HasValueSemantics
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index 46531cfde..807e0b77a 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -1512,6 +1512,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     return b.create(loc, payloadArgs[0]);
   if (isa(op))
     return b.create(loc, payloadArgs[0]);
+  if (isa<AtenCeilOp>(op))
+    return b.create<math::CeilOp>(loc, payloadArgs[0]);
   if (isa(op))
     return b.create(loc, payloadArgs[0]);
   if (isa(op))
     return b.create(loc, payloadArgs[0]);
@@ -2067,7 +2069,8 @@ struct ConvertElementwiseOp : ConversionPattern {
              AtenClampOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp,
              AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op,
              AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp, AtenReciprocalOp,
-             AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenWhereSelfOp>(op))
+             AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenWhereSelfOp,
+             AtenCeilOp>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -3635,8 +3638,8 @@ public:
         AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
         AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, AtenMaximumOp,
         AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, AtenSqrtOp,
-        AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp, AtenAbsOp,
-        AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
+        AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp,
+        AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
         AtenWhereSelfOp>();
     patterns.add<ConvertElementwiseOp>(typeConverter, context);
     target.addIllegalOp();
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 5c57a240c..4563631e2 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -242,7 +242,7 @@ public:
             AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp,
             AtenFloorOp, AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp,
             AtenDropoutOp, AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp,
-            AtenAddIntOp, AtenAbsOp, AtenReciprocalOp>(op)) {
+            AtenAddIntOp, AtenAbsOp, AtenReciprocalOp, AtenCeilOp>(op)) {
       return getLatticeElement(op->getResult(0)).join(*operands[0]);
     }
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 2c3d9b174..804cafaba 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -447,6 +447,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
             "aten::cos : (Tensor) -> (Tensor)",
             "aten::neg : (Tensor) -> (Tensor)",
             "aten::floor : (Tensor) -> (Tensor)",
+            "aten::ceil : (Tensor) -> (Tensor)",
             "aten::bitwise_not : (Tensor) -> (Tensor)",
             "aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)",
             "aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)",
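
As a rough mental model for the createLinalgPayloadCalculationForElementwiseOp
change above (a sketch, not the project's actual code): the linalg.generic
produced for an element-wise op applies a single scalar payload, here
math.ceil, independently to every element of the input tensor. A pure-Python
analogue of that per-element mapping, using a hypothetical helper name, is:

    import math

    def elementwise_ceil(values):
        # Hypothetical stand-in for the linalg.generic payload emitted for
        # aten.ceil: apply the same scalar operation to each element and keep
        # the result floating point (Python's math.ceil returns int).
        return [float(math.ceil(x)) for x in values]

    print(elementwise_ceil([0.1, -1.7, 2.0, 3.5]))  # [1.0, -1.0, 2.0, 4.0]

Because the payload is one scalar op, supporting a new unary op on this path
needs only the small additions in this patch: the ODS registration in
torch_ods_gen.py, the regenerated GeneratedAtenOps.td entries, the payload case
and op lists in TorchToLinalg.cpp, the RefineTypes.cpp op list, and an e2e test.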