mirror of https://github.com/llvm/torch-mlir
[TORCH][MLIR] Fix ConstantPad2dStaticModule test.
This commit fixes the `ConstantPad2dStaticModule` test case by adding the lowering of `aten.pad` operation. Previously the test case mapped to `aten.constant_pad_nd` operation. The `aten.pad` now decomposes into `aten.constant_pad_nd` operation. Signed-Off-By: Prateek Gupta <prateek@nod-labs.com>pull/818/head snapshot-20220429.422
parent
809f240f01
commit
81ee5bb58c
|
@ -3597,6 +3597,32 @@ def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenPadOp : Torch_Op<"aten.pad", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::pad : (Tensor, int[], str, float?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchListOfTorchIntType:$pad,
|
||||
Torch_StringType:$mode,
|
||||
AnyTorchOptionalFloatType:$value
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenPadOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||
}
|
||||
void AtenPadOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 4, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenSqueezeDimOp : Torch_Op<"aten.squeeze.dim", [
|
||||
AllowsTypeRefinement,
|
||||
ReadOnly
|
||||
|
|
|
@ -355,6 +355,7 @@ class OptionalOf<Type type, string descr> :
|
|||
def AnyTorchOptionalTensorType :
|
||||
OptionalOf<AnyTorchTensorType, "Optional torch tensor type">;
|
||||
def AnyTorchOptionalIntType: OptionalOf<Torch_IntType, "Optional torch int type">;
|
||||
def AnyTorchOptionalFloatType: OptionalOf<Torch_FloatType, "Optional torch float type">;
|
||||
def AnyTorchOptionalBoolType:
|
||||
OptionalOf<Torch_BoolType, "Optional torch bool type">;
|
||||
def AnyTorchOptionalStringType:
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
|
@ -1664,6 +1665,27 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose `aten.pad` op into `aten.constant_pad_nd` op.
|
||||
class DecomposeAtenPadOp : public OpRewritePattern<AtenPadOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenPadOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
|
||||
Value value = op.value();
|
||||
if (value.getType().isa<Torch::OptionalType>())
|
||||
return rewriter.notifyMatchFailure(op, "optional type not supported");
|
||||
if (value.getType().isa<Torch::NoneType>())
|
||||
value = rewriter.create<Torch::ConstantFloatOp>(
|
||||
op.getLoc(), rewriter.getF64FloatAttr(0));
|
||||
|
||||
rewriter.replaceOpWithNewOp<AtenConstantPadNdOp>(
|
||||
op, op.getType(), op.self(), op.pad(), value);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeComplexOpsPass
|
||||
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
||||
|
@ -1793,6 +1815,8 @@ class DecomposeComplexOpsPass
|
|||
patterns.add<DecomposeAtenNewEmptyOp>(context);
|
||||
patterns.add<DecomposeAtenIndexPutHackedTwinOp>(context);
|
||||
target.addIllegalOp<AtenIndexPutHackedTwinOp>();
|
||||
target.addIllegalOp<AtenPadOp>();
|
||||
patterns.add<DecomposeAtenPadOp>(context);
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns)))) {
|
||||
|
|
|
@ -489,15 +489,15 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
AtenBernoulliOp, AtenBernoulli_FloatOp, AtenBernoulli_TensorOp,
|
||||
ValsemVariantAtenBernoulliFloatOp, ValsemVariantAtenBernoulliTensorOp,
|
||||
ValsemVariantAtenFillScalarOp, AtenHardsigmoidOp, AtenCloneOp,
|
||||
AtenHardswishOp, AtenSiluOp, AtenHardtanhOp, AtenMaskedSelectOp,
|
||||
AtenHardswishOp, AtenSiluOp, AtenHardtanhOp, AtenMaskedSelectOp,
|
||||
AtenMaxPool2dOp, AtenAdaptiveAvgPool2dOp, AtenFlattenUsingIntsOp,
|
||||
AtenSqueezeOp, AtenSqueezeDimOp, AtenUnsqueezeOp, AtenViewOp,
|
||||
Aten_UnsafeViewOp, AtenReshapeOp, Aten_ReshapeAliasOp,
|
||||
AtenResize_Op, AtenTransposeIntOp, AtenTOp, AtenPermuteOp,
|
||||
AtenIndexSelectOp, AtenSelectIntOp, AtenSliceTensorOp, AtenGatherOp,
|
||||
AtenExpandOp, AtenExpandAsOp, AtenBroadcastToOp, AtenRepeatOp,
|
||||
AtenConstantPadNdOp, AtenZero_Op,
|
||||
AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp,
|
||||
Aten_UnsafeViewOp, AtenReshapeOp, Aten_ReshapeAliasOp, AtenResize_Op,
|
||||
AtenTransposeIntOp, AtenTOp, AtenPermuteOp, AtenIndexSelectOp,
|
||||
AtenSelectIntOp, AtenSliceTensorOp, AtenGatherOp, AtenExpandOp,
|
||||
AtenExpandAsOp, AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp,
|
||||
AtenPadOp, AtenZero_Op, AtenIndexTensorOp,
|
||||
ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp,
|
||||
ValsemVariantAtenCopyOp, ValsemVariantAtenZeroOp,
|
||||
AtenIndexPutHackedTwinOp>(op)) {
|
||||
ValueKnowledge knowledge =
|
||||
|
|
|
@ -2629,6 +2629,10 @@ module {
|
|||
return %0 : !torch.tuple<list<int>, list<int>, list<int>>
|
||||
}
|
||||
func @"__torch_mlir_shape_fn.aten.constant_pad_nd"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float) -> !torch.list<int> {
|
||||
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.pad(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
}
|
||||
func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.pad(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
|
||||
%int1 = torch.constant.int 1
|
||||
%int0 = torch.constant.int 0
|
||||
%int2 = torch.constant.int 2
|
||||
|
@ -2658,12 +2662,12 @@ module {
|
|||
%7 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
|
||||
%8 = torch.aten.floordiv.int %7, %int2 : !torch.int, !torch.int -> !torch.int
|
||||
torch.prim.Loop %8, %true, init() {
|
||||
^bb0(%arg3: !torch.int):
|
||||
%9 = torch.aten.add.int %arg3, %int1 : !torch.int, !torch.int -> !torch.int
|
||||
^bb0(%arg2: !torch.int):
|
||||
%9 = torch.aten.add.int %arg2, %int1 : !torch.int, !torch.int -> !torch.int
|
||||
%10 = torch.aten.neg.int %9 : !torch.int -> !torch.int
|
||||
%11 = torch.aten.mul.int %int2, %arg3 : !torch.int, !torch.int -> !torch.int
|
||||
%11 = torch.aten.mul.int %int2, %arg2 : !torch.int, !torch.int -> !torch.int
|
||||
%12 = torch.aten.__getitem__.t %arg1, %11 : !torch.list<int>, !torch.int -> !torch.int
|
||||
%13 = torch.aten.mul.int %int2, %arg3 : !torch.int, !torch.int -> !torch.int
|
||||
%13 = torch.aten.mul.int %int2, %arg2 : !torch.int, !torch.int -> !torch.int
|
||||
%14 = torch.aten.add.int %13, %int1 : !torch.int, !torch.int -> !torch.int
|
||||
%15 = torch.aten.__getitem__.t %arg1, %14 : !torch.list<int>, !torch.int -> !torch.int
|
||||
%16 = torch.aten.add.int %12, %15 : !torch.int, !torch.int -> !torch.int
|
||||
|
@ -2674,6 +2678,10 @@ module {
|
|||
} : (!torch.int, !torch.bool) -> ()
|
||||
return %arg0 : !torch.list<int>
|
||||
}
|
||||
func @"__torch_mlir_shape_fn.aten.pad"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.str, %arg3: !torch.optional<float>) -> !torch.list<int> {
|
||||
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.pad(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
}
|
||||
func @"__torch_mlir_shape_fn.aten.index.Tensor"(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>) -> !torch.list<int> {
|
||||
%str = torch.constant.str "AssertionError: More indices than dimensions to index"
|
||||
%none = torch.constant.none
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import string
|
||||
from typing import List, Optional, Any, Tuple, Union
|
||||
|
||||
import os
|
||||
|
@ -860,13 +861,10 @@ def aten〇native_batch_norm(input: List[int], weight: Optional[List[int]], bias
|
|||
ErrorInvocation(TensorOfShape(2), [1]), # Unpaired pad value.
|
||||
])
|
||||
def aten〇constant_pad_nd(self: List[int], pad: List[int], value: float = 0) -> List[int]:
|
||||
assert len(pad) % 2 == 0, "Must have paired low-high pad amount values"
|
||||
assert len(pad) // 2 <= len(self), "Number of padded dimensions must be less than or equal to the input dimension"
|
||||
# The `pad` list takes the form of Low-high pairs starting at the
|
||||
# *rightmost* dimension of `self`.
|
||||
for i in range(len(pad) // 2):
|
||||
self[-(i + 1)] += pad[2 * i] + pad[2 * i + 1]
|
||||
return self
|
||||
return upstream_shape_helpers.pad(self, pad)
|
||||
|
||||
def aten〇pad(self: List[int], pad: List[int], mode: str = "constant", value: Optional[float] = None) -> List[int]:
|
||||
return upstream_shape_helpers.pad(self, pad)
|
||||
|
||||
@check_shape_function([
|
||||
Invocation(TensorOfShape(2), [LongTensorOfShape(4)]), # Basic case.
|
||||
|
|
|
@ -32,6 +32,7 @@ TORCH_TYPE_TO_ODS_TYPE = {
|
|||
"bool[]": "AnyTorchListOfTorchBoolType",
|
||||
"bool?": "AnyTorchOptionalBoolType",
|
||||
"float": "Torch_FloatType",
|
||||
"float?": "AnyTorchOptionalFloatType",
|
||||
"t[]": "AnyTorchListType",
|
||||
"t": "AnyTorchType",
|
||||
"t1": "AnyTorchType",
|
||||
|
@ -363,6 +364,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
|
||||
# Misc tensor ops.
|
||||
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
|
||||
emit("aten::pad : (Tensor, int[], str, float?) -> (Tensor)")
|
||||
emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True)
|
||||
emit("aten::unsqueeze : (Tensor, int) -> (Tensor)")
|
||||
emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True)
|
||||
|
|
|
@ -617,3 +617,12 @@ def quantized_prepacked_conv2d(input: List[int], conv2dOpContext: Any):
|
|||
assert isinstance(conv2dOpContext, __torch__.torch.classes.quantized.Conv2dPackedParamsBase)
|
||||
(weight, bias, stride, padding, dilation, groups) = unchecked_cast(Tuple[List[int], Optional[List[int]], List[int], List[int], List[int], int], ops.quantized.conv2d_unpack_sizes(conv2dOpContext))
|
||||
return conv2d(input, weight, bias, stride, padding, dilation, groups)
|
||||
|
||||
def pad(input: List[int], pad: List[int]):
|
||||
assert len(pad) % 2 == 0, "Must have paired low-high pad amount values"
|
||||
assert len(pad) // 2 <= len(input), "Number of padded dimensions must be less than or equal to the input dimension"
|
||||
# The `pad` list takes the form of Low-high pairs starting at the
|
||||
# *rightmost* dimension of `self`.
|
||||
for i in range(len(pad) // 2):
|
||||
input[-(i + 1)] += pad[2 * i] + pad[2 * i + 1]
|
||||
return input
|
||||
|
|
|
@ -15,7 +15,6 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = {
|
|||
"ConvolutionModule1D_basic",
|
||||
"MaxPool2dWith3dInputModule_basic",
|
||||
"MaxPool2dWithIndicesWith3dInputModule_basic",
|
||||
"ConstantPad2dStaticModule_basic",
|
||||
}
|
||||
|
||||
def register_all_tests():
|
||||
|
|
|
@ -275,6 +275,55 @@ class ConstantPad2dStaticModule(torch.nn.Module):
|
|||
def ConstantPad2dStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 1, 20, 20) - 0.5)
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class PadModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
pad = [0, 1, 2, 3]
|
||||
mode = "constant"
|
||||
return torch.ops.aten.pad(x, pad, mode, float(1.5))
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: PadModule())
|
||||
def PadModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 1, 20, 20) - 0.5)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class PadWithNoneValModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
pad = [0, 1, 2, 3]
|
||||
mode = "constant"
|
||||
return torch.ops.aten.pad(x, pad, mode, None)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: PadWithNoneValModule())
|
||||
def PadWithNoneValModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 1, 20, 20) - 0.5)
|
||||
|
||||
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
|
|
@ -854,3 +854,22 @@ func @torch.aten.where.ScalarOther(%arg0: !torch.vtensor<[?,?,?],i1>, %arg1: !to
|
|||
%0 = torch.aten.where.ScalarOther %arg0, %arg1, %cst : !torch.vtensor<[?,?,?],i1>, !torch.vtensor<[?,?],f64>, !torch.float -> !torch.vtensor<[?,?,?],f64>
|
||||
return %0 : !torch.vtensor<[?,?,?],f64>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @torch.aten.pad
|
||||
// CHECK-SAME: (%[[SELF:.*]]: !torch.vtensor<[?,?,?],f64>, %[[VALUE:.*]]: !torch.float) -> !torch.vtensor<[?,?,?],f64> {
|
||||
// CHECK-NOT: torch.aten.pad
|
||||
// CHECK: %[[STRING:.*]] = torch.constant.str "constant"
|
||||
// CHECK-NEXT: %[[LIST:.*]] = torch.prim.ListConstruct
|
||||
// CHECK-NEXT: %[[PAD_ND:.*]] = torch.aten.constant_pad_nd %[[SELF]], %[[LIST]], %[[VALUE]]
|
||||
// CHECK-NEXT: return %[[PAD_ND]]
|
||||
func @torch.aten.pad(%arg0: !torch.vtensor<[?,?,?],f64>, %arg1: !torch.float) -> !torch.vtensor<[?,?,?],f64> {
|
||||
%int0 = torch.constant.int 0
|
||||
%int1 = torch.constant.int 1
|
||||
%int2 = torch.constant.int 2
|
||||
%int3 = torch.constant.int 3
|
||||
%str = torch.constant.str "constant"
|
||||
%0 = torch.prim.ListConstruct %int0, %int1, %int2, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
%1 = torch.aten.pad %arg0, %0, %str, %arg1 : !torch.vtensor<[?,?,?],f64>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[?,?,?],f64>
|
||||
return %1 : !torch.vtensor<[?,?,?],f64>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue