[TORCH][MLIR] Fix ConstantPad2dStaticModule test.

This commit fixes the `ConstantPad2dStaticModule` test case by adding
a lowering for the `aten.pad` operation. Previously, the test case
mapped to the `aten.constant_pad_nd` operation. The `aten.pad` op now
decomposes into the `aten.constant_pad_nd` operation.

Signed-Off-By: Prateek Gupta <prateek@nod-labs.com>
Prateek Gupta, 2022-04-26 12:18:09 +00:00
commit 81ee5bb58c (parent 809f240f01), refs: pull/818/head, snapshot-20220429.422
11 changed files with 154 additions and 19 deletions
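
`aten.pad` with mode "constant" has the same semantics as
`aten.constant_pad_nd`, which is what makes the decomposition valid. A
minimal eager-mode check of that equivalence (illustrative only, not part
of this commit):

import torch

x = torch.rand(1, 1, 20, 20)
pad = [0, 1, 2, 3]  # low-high pairs, starting at the rightmost dimension
assert torch.equal(
    torch.ops.aten.pad(x, pad, "constant", 1.5),
    torch.ops.aten.constant_pad_nd(x, pad, 1.5))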

@@ -3597,6 +3597,32 @@ def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
}];
}
def Torch_AtenPadOp : Torch_Op<"aten.pad", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::pad : (Tensor, int[], str, float?) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchListOfTorchIntType:$pad,
    Torch_StringType:$mode,
    AnyTorchOptionalFloatType:$value
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenPadOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 4, 1);
    }
    void AtenPadOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 4, 1);
    }
  }];
}
def Torch_AtenSqueezeDimOp : Torch_Op<"aten.squeeze.dim", [
    AllowsTypeRefinement,
    ReadOnly
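
Note: the ODS definition above mirrors the registered PyTorch signature
`aten::pad : (Tensor, int[], str, float?) -> (Tensor)`: `$mode` maps to
Torch_StringType and the optional `$value` to AnyTorchOptionalFloatType.
A usage sketch of that signature from Python (an illustration, not part
of this diff):

import torch

# [1, 1] pads the last dimension by one element on each side.
y = torch.ops.aten.pad(torch.zeros(2, 3), [1, 1], "constant", 0.0)
print(y.shape)  # torch.Size([2, 5])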

@@ -355,6 +355,7 @@ class OptionalOf<Type type, string descr> :
def AnyTorchOptionalTensorType :
    OptionalOf<AnyTorchTensorType, "Optional torch tensor type">;
def AnyTorchOptionalIntType: OptionalOf<Torch_IntType, "Optional torch int type">;
def AnyTorchOptionalFloatType: OptionalOf<Torch_FloatType, "Optional torch float type">;
def AnyTorchOptionalBoolType:
    OptionalOf<Torch_BoolType, "Optional torch bool type">;
def AnyTorchOptionalStringType:

@@ -13,6 +13,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/StringExtras.h"
@@ -1664,6 +1665,27 @@ public:
};
} // namespace
namespace {
// Decompose `aten.pad` op into `aten.constant_pad_nd` op.
class DecomposeAtenPadOp : public OpRewritePattern<AtenPadOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenPadOp op,
                                PatternRewriter &rewriter) const override {
    // The `value` operand is `float?`; a `none` value means "pad with 0.0".
    Value value = op.value();
    if (value.getType().isa<Torch::OptionalType>())
      return rewriter.notifyMatchFailure(op, "optional type not supported");
    if (value.getType().isa<Torch::NoneType>())
      value = rewriter.create<Torch::ConstantFloatOp>(
          op.getLoc(), rewriter.getF64FloatAttr(0));
    rewriter.replaceOpWithNewOp<AtenConstantPadNdOp>(
        op, op.getType(), op.self(), op.pad(), value);
    return success();
  }
};
} // namespace
namespace {
class DecomposeComplexOpsPass
    : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -1793,6 +1815,8 @@ class DecomposeComplexOpsPass
    patterns.add<DecomposeAtenNewEmptyOp>(context);
    patterns.add<DecomposeAtenIndexPutHackedTwinOp>(context);
    target.addIllegalOp<AtenIndexPutHackedTwinOp>();
    target.addIllegalOp<AtenPadOp>();
    patterns.add<DecomposeAtenPadOp>(context);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns)))) {
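
In Python terms, the effect of `DecomposeAtenPadOp` is roughly the
following (`decompose_pad` is a hypothetical helper, shown only to
illustrate the rewrite):

import torch

def decompose_pad(self, pad, mode="constant", value=None):
    # Mirrors the pattern: a None fill value is materialized as the
    # float constant 0.0, then aten.pad folds into aten.constant_pad_nd.
    if value is None:
        value = 0.0
    return torch.ops.aten.constant_pad_nd(self, pad, value)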

@@ -489,15 +489,15 @@ ChangeResult TypeAnalyzer::visitOperation(
        AtenBernoulliOp, AtenBernoulli_FloatOp, AtenBernoulli_TensorOp,
        ValsemVariantAtenBernoulliFloatOp, ValsemVariantAtenBernoulliTensorOp,
        ValsemVariantAtenFillScalarOp, AtenHardsigmoidOp, AtenCloneOp,
        AtenHardswishOp, AtenSiluOp, AtenHardtanhOp, AtenMaskedSelectOp,
        AtenMaxPool2dOp, AtenAdaptiveAvgPool2dOp, AtenFlattenUsingIntsOp,
        AtenSqueezeOp, AtenSqueezeDimOp, AtenUnsqueezeOp, AtenViewOp,
-       Aten_UnsafeViewOp, AtenReshapeOp, Aten_ReshapeAliasOp,
-       AtenResize_Op, AtenTransposeIntOp, AtenTOp, AtenPermuteOp,
-       AtenIndexSelectOp, AtenSelectIntOp, AtenSliceTensorOp, AtenGatherOp,
-       AtenExpandOp, AtenExpandAsOp, AtenBroadcastToOp, AtenRepeatOp,
-       AtenConstantPadNdOp, AtenZero_Op,
-       AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp,
+       Aten_UnsafeViewOp, AtenReshapeOp, Aten_ReshapeAliasOp, AtenResize_Op,
+       AtenTransposeIntOp, AtenTOp, AtenPermuteOp, AtenIndexSelectOp,
+       AtenSelectIntOp, AtenSliceTensorOp, AtenGatherOp, AtenExpandOp,
+       AtenExpandAsOp, AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp,
+       AtenPadOp, AtenZero_Op, AtenIndexTensorOp,
+       ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp,
        ValsemVariantAtenCopyOp, ValsemVariantAtenZeroOp,
        AtenIndexPutHackedTwinOp>(op)) {
      ValueKnowledge knowledge =

@@ -2629,6 +2629,10 @@ module {
    return %0 : !torch.tuple<list<int>, list<int>, list<int>>
  }
  func @"__torch_mlir_shape_fn.aten.constant_pad_nd"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float) -> !torch.list<int> {
    %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.pad(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
    return %0 : !torch.list<int>
  }
  func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.pad(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
    %int1 = torch.constant.int 1
    %int0 = torch.constant.int 0
    %int2 = torch.constant.int 2
@@ -2658,12 +2662,12 @@ module {
    %7 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
    %8 = torch.aten.floordiv.int %7, %int2 : !torch.int, !torch.int -> !torch.int
    torch.prim.Loop %8, %true, init() {
-   ^bb0(%arg3: !torch.int):
-     %9 = torch.aten.add.int %arg3, %int1 : !torch.int, !torch.int -> !torch.int
+   ^bb0(%arg2: !torch.int):
+     %9 = torch.aten.add.int %arg2, %int1 : !torch.int, !torch.int -> !torch.int
      %10 = torch.aten.neg.int %9 : !torch.int -> !torch.int
-     %11 = torch.aten.mul.int %int2, %arg3 : !torch.int, !torch.int -> !torch.int
+     %11 = torch.aten.mul.int %int2, %arg2 : !torch.int, !torch.int -> !torch.int
      %12 = torch.aten.__getitem__.t %arg1, %11 : !torch.list<int>, !torch.int -> !torch.int
-     %13 = torch.aten.mul.int %int2, %arg3 : !torch.int, !torch.int -> !torch.int
+     %13 = torch.aten.mul.int %int2, %arg2 : !torch.int, !torch.int -> !torch.int
      %14 = torch.aten.add.int %13, %int1 : !torch.int, !torch.int -> !torch.int
      %15 = torch.aten.__getitem__.t %arg1, %14 : !torch.list<int>, !torch.int -> !torch.int
      %16 = torch.aten.add.int %12, %15 : !torch.int, !torch.int -> !torch.int
@@ -2674,6 +2678,10 @@ module {
    } : (!torch.int, !torch.bool) -> ()
    return %arg0 : !torch.list<int>
  }
  func @"__torch_mlir_shape_fn.aten.pad"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.str, %arg3: !torch.optional<float>) -> !torch.list<int> {
    %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.pad(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
    return %0 : !torch.list<int>
  }
  func @"__torch_mlir_shape_fn.aten.index.Tensor"(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>) -> !torch.list<int> {
    %str = torch.constant.str "AssertionError: More indices than dimensions to index"
    %none = torch.constant.none

@@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import string
from typing import List, Optional, Any, Tuple, Union
import os
@@ -860,13 +861,10 @@ def aten〇native_batch_norm(input: List[int], weight: Optional[List[int]], bias
    ErrorInvocation(TensorOfShape(2), [1]), # Unpaired pad value.
])
def aten〇constant_pad_nd(self: List[int], pad: List[int], value: float = 0) -> List[int]:
-    assert len(pad) % 2 == 0, "Must have paired low-high pad amount values"
-    assert len(pad) // 2 <= len(self), "Number of padded dimensions must be less than or equal to the input dimension"
-    # The `pad` list takes the form of Low-high pairs starting at the
-    # *rightmost* dimension of `self`.
-    for i in range(len(pad) // 2):
-        self[-(i + 1)] += pad[2 * i] + pad[2 * i + 1]
-    return self
+    return upstream_shape_helpers.pad(self, pad)
+
+def aten〇pad(self: List[int], pad: List[int], mode: str = "constant", value: Optional[float] = None) -> List[int]:
+    return upstream_shape_helpers.pad(self, pad)

@check_shape_function([
    Invocation(TensorOfShape(2), [LongTensorOfShape(4)]), # Basic case.

@@ -32,6 +32,7 @@ TORCH_TYPE_TO_ODS_TYPE = {
    "bool[]": "AnyTorchListOfTorchBoolType",
    "bool?": "AnyTorchOptionalBoolType",
    "float": "Torch_FloatType",
    "float?": "AnyTorchOptionalFloatType",
    "t[]": "AnyTorchListType",
    "t": "AnyTorchType",
    "t1": "AnyTorchType",
@@ -363,6 +364,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
    # Misc tensor ops.
    emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
    emit("aten::pad : (Tensor, int[], str, float?) -> (Tensor)")
    emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True)
    emit("aten::unsqueeze : (Tensor, int) -> (Tensor)")
    emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True)

@@ -617,3 +617,12 @@ def quantized_prepacked_conv2d(input: List[int], conv2dOpContext: Any):
    assert isinstance(conv2dOpContext, __torch__.torch.classes.quantized.Conv2dPackedParamsBase)
    (weight, bias, stride, padding, dilation, groups) = unchecked_cast(Tuple[List[int], Optional[List[int]], List[int], List[int], List[int], int], ops.quantized.conv2d_unpack_sizes(conv2dOpContext))
    return conv2d(input, weight, bias, stride, padding, dilation, groups)

def pad(input: List[int], pad: List[int]):
    assert len(pad) % 2 == 0, "Must have paired low-high pad amount values"
    assert len(pad) // 2 <= len(input), "Number of padded dimensions must be less than or equal to the input dimension"
    # The `pad` list takes the form of low-high pairs starting at the
    # *rightmost* dimension of `input`.
    for i in range(len(pad) // 2):
        input[-(i + 1)] += pad[2 * i] + pad[2 * i + 1]
    return input
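
For example, with the helper above, the low-high pair (0, 1) widens the
last dimension and (2, 3) the one before it:

>>> pad([1, 1, 20, 20], [0, 1, 2, 3])
[1, 1, 25, 21]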

@@ -15,7 +15,6 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = {
    "ConvolutionModule1D_basic",
    "MaxPool2dWith3dInputModule_basic",
    "MaxPool2dWithIndicesWith3dInputModule_basic",
-   "ConstantPad2dStaticModule_basic",
}
def register_all_tests():

@@ -275,6 +275,55 @@ class ConstantPad2dStaticModule(torch.nn.Module):
def ConstantPad2dStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(1, 1, 20, 20) - 0.5)

# ==============================================================================

class PadModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        pad = [0, 1, 2, 3]
        mode = "constant"
        return torch.ops.aten.pad(x, pad, mode, float(1.5))


@register_test_case(module_factory=lambda: PadModule())
def PadModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(1, 1, 20, 20) - 0.5)
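
For the input used by `PadModule_basic`, the expected result shape follows
from the shape function above: the last dimension grows by 0 + 1 and the
second-to-last by 2 + 3. A quick eager-mode sanity check (a sketch, not
part of the test suite):

import torch

x = torch.rand(1, 1, 20, 20) - 0.5
out = torch.ops.aten.pad(x, [0, 1, 2, 3], "constant", 1.5)
assert out.shape == (1, 1, 25, 21)
assert out[0, 0, -1, -1] == 1.5  # the padded border holds the fill value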
# ==============================================================================
class PadWithNoneValModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        pad = [0, 1, 2, 3]
        mode = "constant"
        return torch.ops.aten.pad(x, pad, mode, None)


@register_test_case(module_factory=lambda: PadWithNoneValModule())
def PadWithNoneValModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(1, 1, 20, 20) - 0.5)
# ==============================================================================

@@ -854,3 +854,22 @@ func @torch.aten.where.ScalarOther(%arg0: !torch.vtensor<[?,?,?],i1>, %arg1: !to
  %0 = torch.aten.where.ScalarOther %arg0, %arg1, %cst : !torch.vtensor<[?,?,?],i1>, !torch.vtensor<[?,?],f64>, !torch.float -> !torch.vtensor<[?,?,?],f64>
  return %0 : !torch.vtensor<[?,?,?],f64>
}
// -----
// CHECK-LABEL: func @torch.aten.pad
// CHECK-SAME: (%[[SELF:.*]]: !torch.vtensor<[?,?,?],f64>, %[[VALUE:.*]]: !torch.float) -> !torch.vtensor<[?,?,?],f64> {
// CHECK-NOT: torch.aten.pad
// CHECK: %[[STRING:.*]] = torch.constant.str "constant"
// CHECK-NEXT: %[[LIST:.*]] = torch.prim.ListConstruct
// CHECK-NEXT: %[[PAD_ND:.*]] = torch.aten.constant_pad_nd %[[SELF]], %[[LIST]], %[[VALUE]]
// CHECK-NEXT: return %[[PAD_ND]]
func @torch.aten.pad(%arg0: !torch.vtensor<[?,?,?],f64>, %arg1: !torch.float) -> !torch.vtensor<[?,?,?],f64> {
  %int0 = torch.constant.int 0
  %int1 = torch.constant.int 1
  %int2 = torch.constant.int 2
  %int3 = torch.constant.int 3
  %str = torch.constant.str "constant"
  %0 = torch.prim.ListConstruct %int0, %int1, %int2, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.aten.pad %arg0, %0, %str, %arg1 : !torch.vtensor<[?,?,?],f64>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[?,?,?],f64>
  return %1 : !torch.vtensor<[?,?,?],f64>
}