diff --git a/e2e_testing/torchscript/elementwise.py b/e2e_testing/torchscript/elementwise.py index e7b128a8d..caf15bb0e 100644 --- a/e2e_testing/torchscript/elementwise.py +++ b/e2e_testing/torchscript/elementwise.py @@ -378,3 +378,19 @@ class ElementwiseSqrtModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseSqrtModule()) def ElementwiseSqrtModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4)) + +class ElementwiseFloorModule(torch.nn.Module): + def __init__(self): + super().__init__() + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + + def forward(self, a): + return torch.floor(a) + +@register_test_case(module_factory=lambda: ElementwiseFloorModule()) +def ElementwiseFloorModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td index 66354639a..58bf93e5d 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td @@ -240,6 +240,34 @@ def Torch_AtenNeg_Op : Torch_Op<"aten.neg_", [ let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)"; } +def Torch_AtenFloorOp : Torch_Op<"aten.floor", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::floor : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)"; +} + +def Torch_AtenFloor_Op : Torch_Op<"aten.floor_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::floor_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)"; +} + def Torch_AtenBitwiseNotOp : Torch_Op<"aten.bitwise_not", [ AllowsTypeRefinement, HasValueSemantics diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp index c41c533ae..db0aa91c5 100644 --- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp +++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp @@ -1278,6 +1278,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return b.create(loc, payloadArgs[0]); if (isa(op)) return b.create(loc, payloadArgs[0]); + if (isa(op)) + return b.create(loc, payloadArgs[0]); if (isa(op)) return b.create(loc, payloadArgs[0]); if (isa(op)) @@ -1666,7 +1668,7 @@ struct ConvertElementwiseOp : ConversionPattern { AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenMinimumOp, AtenMaximumOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, - AtenSqrtOp>(op)) + AtenSqrtOp, AtenFloorOp>(op)) return rewriter.notifyMatchFailure(op, "not a supported elementwise op"); if (failed(verifyLinalgCompatibleTypes(op, rewriter))) @@ -2803,7 +2805,7 @@ public: AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, AtenMaximumOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, - AtenSqrtOp>(); + AtenSqrtOp, AtenFloorOp>(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 687aa2520..886ca2c2a 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -231,7 +231,7 @@ public: AtenFill_ScalarOp, AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op, AtenCopy_Op, AtenCumsumOp, AtenLayerNormOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, - AtenSqrtOp>(op)) { + AtenSqrtOp, AtenFloorOp>(op)) { return getLatticeElement(op->getResult(0)).join(*operands[0]); } diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index d41b9ba31..4c24a1c90 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -445,6 +445,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry): "aten::exp : (Tensor) -> (Tensor)", "aten::cos : (Tensor) -> (Tensor)", "aten::neg : (Tensor) -> (Tensor)", + "aten::floor : (Tensor) -> (Tensor)", "aten::bitwise_not : (Tensor) -> (Tensor)", "aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", "aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)",