From 365655ca29c8823a9e41e3074b148b7c114d8846 Mon Sep 17 00:00:00 2001
From: Yuanqiang Liu
Date: Thu, 2 Nov 2023 09:51:31 +0800
Subject: [PATCH] [Torch Dialect] add canonicalize pattern for aten.floor with
 integer … (#2534)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

…type
---
 e2e_testing/xfail_sets.py                     |  1 +
 .../Dialect/Torch/IR/GeneratedTorchOps.td     | 91 ++++++++++---------
 lib/Dialect/Torch/IR/TorchOps.cpp             | 16 ++++
 .../jit_ir/build_tools/torch_ods_gen.py       |  2 +-
 .../test_suite/elementwise.py                 | 18 ++++
 5 files changed, 82 insertions(+), 46 deletions(-)

diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index c841a1611..7e5ec09d0 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -957,6 +957,7 @@ TOSA_PASS_SET = {
     "ElementwiseEluModule_basic",
     "ElementwiseEluNonDefaultModule_basic",
     "ElementwiseFloorModule_basic",
+    "ElementwiseFloorIntModule_basic",
     "ElementwiseLogModule_basic",
     "ElementwiseBinaryStaticShapeModule_basic",
     "ElementwiseMinimumModule_basic",
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index c22a252de..072fdd7df 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -1023,51 +1023,6 @@ def Torch_AtenNeg_Op : Torch_Op<"aten.neg_", [
   }];
 }
 
-def Torch_AtenFloorOp : Torch_Op<"aten.floor", [
-    AllowsTypeRefinement,
-    HasValueSemantics,
-    ReadOnly
-  ]> {
-  let summary = "Generated op for `aten::floor : (Tensor) -> (Tensor)`";
-  let arguments = (ins
-    AnyTorchTensorType:$self
-  );
-  let results = (outs
-    AnyTorchTensorType:$result
-  );
-  let hasCustomAssemblyFormat = 1;
-  let extraClassDefinition = [{
-    ParseResult AtenFloorOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 1, 1);
-    }
-    void AtenFloorOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 1, 1);
-    }
-  }];
-}
-
-def Torch_AtenFloor_Op : Torch_Op<"aten.floor_", [
-    IsTrailingUnderscoreInplaceVariant,
-    AllowsTypeRefinement
-  ]> {
-  let summary = "Generated op for `aten::floor_ : (Tensor) -> (Tensor)`";
-  let arguments = (ins
-    Torch_NonValueTensorType:$self
-  );
-  let results = (outs
-    Torch_NonValueTensorType:$result
-  );
-  let hasCustomAssemblyFormat = 1;
-  let extraClassDefinition = [{
-    ParseResult AtenFloor_Op::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 1, 1);
-    }
-    void AtenFloor_Op::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 1, 1);
-    }
-  }];
-}
-
 def Torch_AtenCeilOp : Torch_Op<"aten.ceil", [
     AllowsTypeRefinement,
     HasValueSemantics,
@@ -3657,6 +3612,52 @@ def Torch_AtenMul_ScalarOp : Torch_Op<"aten.mul_.Scalar", [
   }];
 }
 
+def Torch_AtenFloorOp : Torch_Op<"aten.floor", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::floor : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenFloorOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenFloorOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+  let hasCanonicalizer = 1;
+}
+
+def Torch_AtenFloor_Op : Torch_Op<"aten.floor_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::floor_ : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    Torch_NonValueTensorType:$self
+  );
+  let results = (outs
+    Torch_NonValueTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenFloor_Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenFloor_Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index bf930d68b..721a676f7 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -1117,6 +1117,22 @@ void AtenMulTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   });
 }
 
+//===----------------------------------------------------------------------===//
+// AtenFloorOp
+//===----------------------------------------------------------------------===//
+void AtenFloorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                              MLIRContext *context) {
+  patterns.add(+[](AtenFloorOp op, PatternRewriter &rewriter) {
+    auto outputTy = op.getType().dyn_cast<BaseTensorType>();
+    if (outputTy && outputTy.hasDtype() &&
+        outputTy.getDtype().isa<mlir::IntegerType>()) {
+      rewriter.replaceOp(op, op.getSelf());
+      return success();
+    }
+    return failure();
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // AtenMulScalarOp
 //===----------------------------------------------------------------------===//
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 1a39989f8..55e124d92 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -273,7 +273,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
             "aten::atan : (Tensor) -> (Tensor)",
             "aten::atan2 : (Tensor, Tensor) -> (Tensor)",
             "aten::neg : (Tensor) -> (Tensor)",
-            "aten::floor : (Tensor) -> (Tensor)",
             "aten::ceil : (Tensor) -> (Tensor)",
             "aten::bitwise_not : (Tensor) -> (Tensor)",
             "aten::div.Tensor : (Tensor, Tensor) -> (Tensor)",
@@ -333,6 +332,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit_with_mutating_variants("aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
     emit_with_mutating_variants("aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
     emit_with_mutating_variants("aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
+    emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_canonicalizer=True)
     emit_with_mutating_variants("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
     emit_with_mutating_variants("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
 
diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py
index 918a22b86..00b07cbb6 100644
--- a/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -1420,6 +1420,24 @@ class ElementwiseFloorModule(torch.nn.Module):
 def ElementwiseFloorModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
+class ElementwiseFloorIntModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int32, True),
+    ])
+    def forward(self, a):
+        return torch.floor(a)
+
+
+@register_test_case(module_factory=lambda: ElementwiseFloorIntModule())
+def ElementwiseFloorIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, low=-10, high=10).to(torch.int32))
+
 # ==============================================================================
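
Reviewer note (not part of the patch): the pattern added in TorchOps.cpp
rewrites `aten.floor` into its own operand whenever the result dtype is
integral, since floor is the identity on integer tensors. The sketch below
is one way to observe the fold end to end from Python; it is illustrative
only. The `FloorInt` module name is made up, and it assumes the
`torch_mlir.compile` entry point of this era of the project, with
`output_type="torch"` stopping at the Torch backend contract and the
torch-level lowering pipeline running the canonicalizer along the way.

import torch
import torch_mlir

class FloorInt(torch.nn.Module):
    def forward(self, a):
        # floor of an integer tensor is the tensor itself
        return torch.floor(a)

# Illustrative check (assumption: the pipeline behind torch_mlir.compile
# runs canonicalization): with an int32 example input, the emitted
# Torch-dialect IR should contain no torch.aten.floor op at all, because
# the new pattern replaced it with its operand.
example = torch.randint(-10, 10, (3, 4), dtype=torch.int32)
module = torch_mlir.compile(FloorInt(), example, output_type="torch")
print(module)
assert "torch.aten.floor" not in str(module)

Note that the fold keys off the *result* dtype, not the operand dtype:
because `aten.floor` preserves its input type, an integral result type
guarantees an integral input, so replacing the op with `getSelf()` is
type-correct.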