From 81ee5bb58c57cb88ea69005db8ce655bc8b3bb78 Mon Sep 17 00:00:00 2001 From: Prateek Gupta Date: Tue, 26 Apr 2022 12:18:09 +0000 Subject: [PATCH] [TORCH][MLIR] Fix ConstantPad2dStaticModule test. This commit fixes the `ConstantPad2dStaticModule` test case by adding the lowering of `aten.pad` operation. Previously the test case mapped to `aten.constant_pad_nd` operation. The `aten.pad` now decomposes into `aten.constant_pad_nd` operation. Signed-Off-By: Prateek Gupta --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 26 ++++++++++ .../torch-mlir/Dialect/Torch/IR/TorchTypes.td | 1 + .../Torch/Transforms/DecomposeComplexOps.cpp | 24 +++++++++ lib/Dialect/Torch/Transforms/RefineTypes.cpp | 14 +++--- lib/Dialect/Torch/Transforms/ShapeLibrary.cpp | 16 ++++-- .../jit_ir/build_tools/shape_lib_gen.py | 12 ++--- .../jit_ir/build_tools/torch_ods_gen.py | 2 + .../build_tools/upstream_shape_helpers.py | 9 ++++ .../test_suite/__init__.py | 1 - .../torch_mlir_e2e_test/test_suite/basic.py | 49 +++++++++++++++++++ test/Dialect/Torch/decompose-complex-ops.mlir | 19 +++++++ 11 files changed, 154 insertions(+), 19 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 37fc651c9..11fb93a9b 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -3597,6 +3597,32 @@ def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [ }]; } +def Torch_AtenPadOp : Torch_Op<"aten.pad", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::pad : (Tensor, int[], str, float?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$pad, + Torch_StringType:$mode, + AnyTorchOptionalFloatType:$value + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenPadOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenPadOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenSqueezeDimOp : Torch_Op<"aten.squeeze.dim", [ AllowsTypeRefinement, ReadOnly diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td index d961bb44c..6bc2b388e 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td @@ -355,6 +355,7 @@ class OptionalOf : def AnyTorchOptionalTensorType : OptionalOf; def AnyTorchOptionalIntType: OptionalOf; +def AnyTorchOptionalFloatType: OptionalOf; def AnyTorchOptionalBoolType: OptionalOf; def AnyTorchOptionalStringType: diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 7c9ba3533..45dc23e1e 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -13,6 +13,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "llvm/ADT/StringExtras.h" @@ -1664,6 +1665,27 @@ public: }; } // namespace +namespace { +// Decompose `aten.pad` op into `aten.constant_pad_nd` op. +class DecomposeAtenPadOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenPadOp op, + PatternRewriter &rewriter) const override { + + Value value = op.value(); + if (value.getType().isa()) + return rewriter.notifyMatchFailure(op, "optional type not supported"); + if (value.getType().isa()) + value = rewriter.create( + op.getLoc(), rewriter.getF64FloatAttr(0)); + + rewriter.replaceOpWithNewOp( + op, op.getType(), op.self(), op.pad(), value); + return success(); + } +}; +} // namespace + namespace { class DecomposeComplexOpsPass : public DecomposeComplexOpsBase { @@ -1793,6 +1815,8 @@ class DecomposeComplexOpsPass patterns.add(context); patterns.add(context); target.addIllegalOp(); + target.addIllegalOp(); + patterns.add(context); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 28667dfbb..2845641a7 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -489,15 +489,15 @@ ChangeResult TypeAnalyzer::visitOperation( AtenBernoulliOp, AtenBernoulli_FloatOp, AtenBernoulli_TensorOp, ValsemVariantAtenBernoulliFloatOp, ValsemVariantAtenBernoulliTensorOp, ValsemVariantAtenFillScalarOp, AtenHardsigmoidOp, AtenCloneOp, - AtenHardswishOp, AtenSiluOp, AtenHardtanhOp, AtenMaskedSelectOp, + AtenHardswishOp, AtenSiluOp, AtenHardtanhOp, AtenMaskedSelectOp, AtenMaxPool2dOp, AtenAdaptiveAvgPool2dOp, AtenFlattenUsingIntsOp, AtenSqueezeOp, AtenSqueezeDimOp, AtenUnsqueezeOp, AtenViewOp, - Aten_UnsafeViewOp, AtenReshapeOp, Aten_ReshapeAliasOp, - AtenResize_Op, AtenTransposeIntOp, AtenTOp, AtenPermuteOp, - AtenIndexSelectOp, AtenSelectIntOp, AtenSliceTensorOp, AtenGatherOp, - AtenExpandOp, AtenExpandAsOp, AtenBroadcastToOp, AtenRepeatOp, - AtenConstantPadNdOp, AtenZero_Op, - AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp, + Aten_UnsafeViewOp, AtenReshapeOp, Aten_ReshapeAliasOp, AtenResize_Op, + AtenTransposeIntOp, AtenTOp, AtenPermuteOp, AtenIndexSelectOp, + AtenSelectIntOp, AtenSliceTensorOp, AtenGatherOp, AtenExpandOp, + AtenExpandAsOp, AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp, + AtenPadOp, AtenZero_Op, AtenIndexTensorOp, + ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp, ValsemVariantAtenCopyOp, ValsemVariantAtenZeroOp, AtenIndexPutHackedTwinOp>(op)) { ValueKnowledge knowledge = diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp index 89d6d86a2..a7a912551 100644 --- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp @@ -2629,6 +2629,10 @@ module { return %0 : !torch.tuple, list, list> } func @"__torch_mlir_shape_fn.aten.constant_pad_nd"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float) -> !torch.list { + %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.pad(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list + return %0 : !torch.list + } + func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.pad(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { %int1 = torch.constant.int 1 %int0 = torch.constant.int 0 %int2 = torch.constant.int 2 @@ -2658,12 +2662,12 @@ module { %7 = torch.aten.len.t %arg1 : !torch.list -> !torch.int %8 = torch.aten.floordiv.int %7, %int2 : !torch.int, !torch.int -> !torch.int torch.prim.Loop %8, %true, init() { - ^bb0(%arg3: !torch.int): - %9 = torch.aten.add.int %arg3, %int1 : !torch.int, !torch.int -> !torch.int + ^bb0(%arg2: !torch.int): + %9 = torch.aten.add.int %arg2, %int1 : !torch.int, !torch.int -> !torch.int %10 = torch.aten.neg.int %9 : !torch.int -> !torch.int - %11 = torch.aten.mul.int %int2, %arg3 : !torch.int, !torch.int -> !torch.int + %11 = torch.aten.mul.int %int2, %arg2 : !torch.int, !torch.int -> !torch.int %12 = torch.aten.__getitem__.t %arg1, %11 : !torch.list, !torch.int -> !torch.int - %13 = torch.aten.mul.int %int2, %arg3 : !torch.int, !torch.int -> !torch.int + %13 = torch.aten.mul.int %int2, %arg2 : !torch.int, !torch.int -> !torch.int %14 = torch.aten.add.int %13, %int1 : !torch.int, !torch.int -> !torch.int %15 = torch.aten.__getitem__.t %arg1, %14 : !torch.list, !torch.int -> !torch.int %16 = torch.aten.add.int %12, %15 : !torch.int, !torch.int -> !torch.int @@ -2674,6 +2678,10 @@ module { } : (!torch.int, !torch.bool) -> () return %arg0 : !torch.list } + func @"__torch_mlir_shape_fn.aten.pad"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.str, %arg3: !torch.optional) -> !torch.list { + %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.pad(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list + return %0 : !torch.list + } func @"__torch_mlir_shape_fn.aten.index.Tensor"(%arg0: !torch.list, %arg1: !torch.list>>) -> !torch.list { %str = torch.constant.str "AssertionError: More indices than dimensions to index" %none = torch.constant.none diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py index 17f38277a..442f0d730 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. +import string from typing import List, Optional, Any, Tuple, Union import os @@ -860,13 +861,10 @@ def aten〇native_batch_norm(input: List[int], weight: Optional[List[int]], bias ErrorInvocation(TensorOfShape(2), [1]), # Unpaired pad value. ]) def aten〇constant_pad_nd(self: List[int], pad: List[int], value: float = 0) -> List[int]: - assert len(pad) % 2 == 0, "Must have paired low-high pad amount values" - assert len(pad) // 2 <= len(self), "Number of padded dimensions must be less than or equal to the input dimension" - # The `pad` list takes the form of Low-high pairs starting at the - # *rightmost* dimension of `self`. - for i in range(len(pad) // 2): - self[-(i + 1)] += pad[2 * i] + pad[2 * i + 1] - return self + return upstream_shape_helpers.pad(self, pad) + +def aten〇pad(self: List[int], pad: List[int], mode: str = "constant", value: Optional[float] = None) -> List[int]: + return upstream_shape_helpers.pad(self, pad) @check_shape_function([ Invocation(TensorOfShape(2), [LongTensorOfShape(4)]), # Basic case. diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index d3147fd20..593d3c528 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -32,6 +32,7 @@ TORCH_TYPE_TO_ODS_TYPE = { "bool[]": "AnyTorchListOfTorchBoolType", "bool?": "AnyTorchOptionalBoolType", "float": "Torch_FloatType", + "float?": "AnyTorchOptionalFloatType", "t[]": "AnyTorchListType", "t": "AnyTorchType", "t1": "AnyTorchType", @@ -363,6 +364,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): # Misc tensor ops. emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)") + emit("aten::pad : (Tensor, int[], str, float?) -> (Tensor)") emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True) emit("aten::unsqueeze : (Tensor, int) -> (Tensor)") emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True) diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/upstream_shape_helpers.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/upstream_shape_helpers.py index f62b5aa06..e0316ef0d 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/upstream_shape_helpers.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/upstream_shape_helpers.py @@ -617,3 +617,12 @@ def quantized_prepacked_conv2d(input: List[int], conv2dOpContext: Any): assert isinstance(conv2dOpContext, __torch__.torch.classes.quantized.Conv2dPackedParamsBase) (weight, bias, stride, padding, dilation, groups) = unchecked_cast(Tuple[List[int], Optional[List[int]], List[int], List[int], List[int], int], ops.quantized.conv2d_unpack_sizes(conv2dOpContext)) return conv2d(input, weight, bias, stride, padding, dilation, groups) + +def pad(input: List[int], pad: List[int]): + assert len(pad) % 2 == 0, "Must have paired low-high pad amount values" + assert len(pad) // 2 <= len(input), "Number of padded dimensions must be less than or equal to the input dimension" + # The `pad` list takes the form of Low-high pairs starting at the + # *rightmost* dimension of `self`. + for i in range(len(pad) // 2): + input[-(i + 1)] += pad[2 * i] + pad[2 * i + 1] + return input diff --git a/python/torch_mlir_e2e_test/test_suite/__init__.py b/python/torch_mlir_e2e_test/test_suite/__init__.py index b376afc81..c8696a3e4 100644 --- a/python/torch_mlir_e2e_test/test_suite/__init__.py +++ b/python/torch_mlir_e2e_test/test_suite/__init__.py @@ -15,7 +15,6 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = { "ConvolutionModule1D_basic", "MaxPool2dWith3dInputModule_basic", "MaxPool2dWithIndicesWith3dInputModule_basic", - "ConstantPad2dStaticModule_basic", } def register_all_tests(): diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py index 8550e6e46..46330f641 100644 --- a/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/python/torch_mlir_e2e_test/test_suite/basic.py @@ -275,6 +275,55 @@ class ConstantPad2dStaticModule(torch.nn.Module): def ConstantPad2dStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 1, 20, 20) - 0.5) +# ============================================================================== + + +class PadModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + pad = [0, 1, 2, 3] + mode = "constant" + return torch.ops.aten.pad(x, pad, mode, float(1.5)) + + +@register_test_case(module_factory=lambda: PadModule()) +def PadModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 20, 20) - 0.5) + + +# ============================================================================== + + +class PadWithNoneValModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + pad = [0, 1, 2, 3] + mode = "constant" + return torch.ops.aten.pad(x, pad, mode, None) + + +@register_test_case(module_factory=lambda: PadWithNoneValModule()) +def PadWithNoneValModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 20, 20) - 0.5) + + + # ============================================================================== diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index 4a88d4605..1ce568e28 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -854,3 +854,22 @@ func @torch.aten.where.ScalarOther(%arg0: !torch.vtensor<[?,?,?],i1>, %arg1: !to %0 = torch.aten.where.ScalarOther %arg0, %arg1, %cst : !torch.vtensor<[?,?,?],i1>, !torch.vtensor<[?,?],f64>, !torch.float -> !torch.vtensor<[?,?,?],f64> return %0 : !torch.vtensor<[?,?,?],f64> } + +// ----- +// CHECK-LABEL: func @torch.aten.pad +// CHECK-SAME: (%[[SELF:.*]]: !torch.vtensor<[?,?,?],f64>, %[[VALUE:.*]]: !torch.float) -> !torch.vtensor<[?,?,?],f64> { +// CHECK-NOT: torch.aten.pad +// CHECK: %[[STRING:.*]] = torch.constant.str "constant" +// CHECK-NEXT: %[[LIST:.*]] = torch.prim.ListConstruct +// CHECK-NEXT: %[[PAD_ND:.*]] = torch.aten.constant_pad_nd %[[SELF]], %[[LIST]], %[[VALUE]] +// CHECK-NEXT: return %[[PAD_ND]] +func @torch.aten.pad(%arg0: !torch.vtensor<[?,?,?],f64>, %arg1: !torch.float) -> !torch.vtensor<[?,?,?],f64> { + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %str = torch.constant.str "constant" + %0 = torch.prim.ListConstruct %int0, %int1, %int2, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1 = torch.aten.pad %arg0, %0, %str, %arg1 : !torch.vtensor<[?,?,?],f64>, !torch.list, !torch.str, !torch.float -> !torch.vtensor<[?,?,?],f64> + return %1 : !torch.vtensor<[?,?,?],f64> +}