mirror of https://github.com/llvm/torch-mlir
Generalize max_unpool lowering
parent 67732883fa
commit 9dcde43719
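Max-unpooling inverts a max-pool: each pooled value is scattered back to the position recorded by the pooling indices, and every other output element is zero. The 2d and 3d variants differ only in the number of spatial dimensions, which is what lets this commit fold `aten.max_unpool2d` and `aten.max_unpool3d` into one rank-generic `aten.max_unpool` op. A minimal PyTorch sketch of the semantics being lowered (standard torch.nn.functional API, independent of torch-mlir):

import torch
import torch.nn.functional as F

x = torch.arange(16.0).reshape(1, 1, 4, 4)
# Keep the argmax positions so the pooling can be inverted later.
pooled, indices = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
# Scatter each max back to where it came from; all other positions become 0.
restored = F.max_unpool2d(pooled, indices, kernel_size=2, stride=2)
assert restored.shape == x.shape  # (1, 1, 4, 4)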
@@ -7106,31 +7106,6 @@ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
   }];
 }
 
-def Torch_AtenMaxUnpool2dOp : Torch_Op<"aten.max_unpool2d", [
-  AllowsTypeRefinement,
-  HasValueSemantics,
-  ReadOnly
-]> {
-  let summary = "Generated op for `aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)`";
-  let arguments = (ins
-    AnyTorchTensorType:$self,
-    AnyTorchTensorType:$indices,
-    AnyTorchListOfTorchIntType:$output_size
-  );
-  let results = (outs
-    AnyTorchOptionalTensorType:$result
-  );
-  let hasCustomAssemblyFormat = 1;
-  let extraClassDefinition = [{
-    ParseResult AtenMaxUnpool2dOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 3, 1);
-    }
-    void AtenMaxUnpool2dOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 3, 1);
-    }
-  }];
-}
-
 def Torch_AtenMaxPool2dWithIndicesOp : Torch_Op<"aten.max_pool2d_with_indices", [
   AllowsTypeRefinement,
   HasValueSemantics,
@@ -7219,12 +7194,12 @@ def Torch_AtenMaxPool3dOp : Torch_Op<"aten.max_pool3d", [
   }];
 }
 
-def Torch_AtenMaxUnpool3dOp : Torch_Op<"aten.max_unpool3d", [
+def Torch_AtenMaxUnpoolOp : Torch_Op<"aten.max_unpool", [
   AllowsTypeRefinement,
   HasValueSemantics,
   ReadOnly
 ]> {
-  let summary = "Generated op for `aten::max_unpool3d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)`";
+  let summary = "Generated op for `aten::max_unpool : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
     AnyTorchTensorType:$indices,
@@ -7237,10 +7212,10 @@ def Torch_AtenMaxUnpool3dOp : Torch_Op<"aten.max_unpool3d", [
   );
   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
-    ParseResult AtenMaxUnpool3dOp::parse(OpAsmParser &parser, OperationState &result) {
+    ParseResult AtenMaxUnpoolOp::parse(OpAsmParser &parser, OperationState &result) {
       return parseDefaultTorchOp(parser, result, 5, 1);
     }
-    void AtenMaxUnpool3dOp::print(OpAsmPrinter &printer) {
+    void AtenMaxUnpoolOp::print(OpAsmPrinter &printer) {
       printDefaultTorchOp(printer, *this, 5, 1);
     }
   }];
@@ -3430,11 +3430,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         SmallVector<int64_t> resultShape(resultType.getSizes());
         Value resultShapeList =
             createConstantIntList(binder, rewriter, resultShape);
-        if (rank == 4) {
-          rewriter.replaceOpWithNewOp<Torch::AtenMaxUnpool2dOp>(
-              binder.op, resultType, data, indices, resultShapeList);
-          return success();
-        }
-
         SmallVector<int64_t> padding, strides;
         if (binder.s64IntegerArrayAttr(padding, "pads", {}))
@@ -3469,7 +3464,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         Value paddingList = createConstantIntList(binder, rewriter, padding);
         Value stridesList = createConstantIntList(binder, rewriter, strides);
 
-        rewriter.replaceOpWithNewOp<Torch::AtenMaxUnpool3dOp>(
+        rewriter.replaceOpWithNewOp<Torch::AtenMaxUnpoolOp>(
             binder.op, resultType, data, indices, resultShapeList, stridesList,
             paddingList);
         return success();
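With the rank-4 early exit removed, the ONNX importer emits the five-operand `aten.max_unpool` for both 2d and 3d, always materializing output-size, stride, and padding lists. When the ONNX node carries no explicit output_shape attribute, the output spatial dimensions follow the MaxUnpool rule from the ONNX spec; a small Python sketch of that rule (function name is illustrative):

def onnx_maxunpool_out_dim(in_dim: int, kernel: int, stride: int, pad: int) -> int:
    # ONNX MaxUnpool: output = (input - 1) * stride + kernel - 2 * pad
    return (in_dim - 1) * stride + kernel - 2 * pad

# A 2x2 pooled input with kernel 2, stride 2, no padding unpools to 4x4,
# matching the lit tests further down.
assert onnx_maxunpool_out_dim(2, kernel=2, stride=2, pad=0) == 4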
@@ -596,12 +596,12 @@ namespace {
 // input_size=2, output_size=5 and stride=2, kernel_size can be either 2 or 3).
 // What worse, without knowing kernel size we cannot even reliably detect such
 // cases and this conversion will just return invalid values.
-class ConvertAtenMaxUnpool3dOp final
-    : public OpConversionPattern<AtenMaxUnpool3dOp> {
+class ConvertAtenMaxUnpoolOp final
+    : public OpConversionPattern<AtenMaxUnpoolOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(AtenMaxUnpool3dOp op, OpAdaptor adaptor,
+  matchAndRewrite(AtenMaxUnpoolOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
@@ -611,21 +611,22 @@ public:
     Value self = adaptor.getSelf();
     auto selfType = cast<RankedTensorType>(self.getType());
 
-    ArrayRef<int64_t> inputSize = selfType.getShape().take_back(3);
+    size_t spatial = selfType.getRank() - 2;
+    ArrayRef<int64_t> inputSize = selfType.getShape().take_back(spatial);
     if (ShapedType::isDynamicShape(inputSize))
       return rewriter.notifyMatchFailure(op,
                                          "input type must be of static shape");
 
     Value indices = adaptor.getIndices();
     auto indicesType = cast<RankedTensorType>(indices.getType());
-    if (inputSize != indicesType.getShape().take_back(3))
+    if (inputSize != indicesType.getShape().take_back(spatial))
       return rewriter.notifyMatchFailure(op, "input/indices shape mismatch");
 
     auto resType = typeConverter->convertType<RankedTensorType>(op.getType());
     if (!resType)
       return rewriter.notifyMatchFailure(op, "invalid result type");
 
-    ArrayRef<int64_t> inferredOutSize = resType.getShape().take_back(3);
+    ArrayRef<int64_t> inferredOutSize = resType.getShape().take_back(spatial);
     if (ShapedType::isDynamicShape(inferredOutSize))
       return rewriter.notifyMatchFailure(op,
                                          "output type must be of static shape");
@@ -636,7 +637,7 @@ public:
       return rewriter.notifyMatchFailure(op,
                                          "only support constant int output");
 
-      if (inferredOutSize != ArrayRef(output))
+      if (inferredOutSize != ArrayRef(output).take_back(spatial))
        return rewriter.notifyMatchFailure(op, "Invalid output size");
     }
     SmallVector<int64_t> stride;
@@ -652,12 +653,12 @@ public:
 
     // TODO: add support for asymmetric padding coming from "onnx.MaxUnpool"
     // (padding.size() == 6).
-    if (stride.size() != 3 || padding.size() != 3)
+    if (stride.size() != spatial || padding.size() != spatial)
       return rewriter.notifyMatchFailure(
           op, "stride and padding must be of size 3");
 
     int64_t outRank = resType.getRank();
-    int64_t NC = outRank - 3;
+    int64_t NC = outRank - spatial;
 
     for (auto &&[inDim, outDim, str, pad] :
          llvm::zip_equal(inputSize, inferredOutSize, stride, padding)) {
@@ -694,7 +695,7 @@ public:
     // (e.g. pooling_input_size=5, kernel_size=2, stride=2, output_size=2)
     // pad self and indices tensors to avoid out of bounds access.
     SmallVector<int64_t> expectedInputShape =
-        llvm::to_vector(resType.getShape().drop_back(3));
+        llvm::to_vector(resType.getShape().drop_back(spatial));
     for (auto &&[str, pad, resSize] :
          llvm::zip_equal(stride, padding, inferredOutSize))
       expectedInputShape.emplace_back(ceilDiv(resSize, str) + pad * 2);
@@ -707,7 +708,7 @@ public:
     SmallVector<int64_t> low(outRank, 0);
     SmallVector<int64_t> high(NC, 0);
     for (auto &&[inpSize, outSize] : llvm::zip_equal(
-             inputSize, ArrayRef(expectedInputShape).take_back(3))) {
+             inputSize, ArrayRef(expectedInputShape).take_back(spatial))) {
       high.emplace_back(outSize - inpSize);
     }
 
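The `spatial` variable generalizes the old hard-coded 3 to "rank minus the N and C dims", so the same over-allocation trick now serves both ranks: when the stride does not evenly divide the output size, the pattern pads self and indices up to ceil(out / stride) + 2 * pad per spatial dimension so the gather never reads out of bounds. A Python restatement of that computation (names are illustrative):

def expected_input_shape(nc_dims, out_shape, stride, padding):
    # Mirrors the C++ above: keep the N/C dims, then allow room for a
    # pooled input of ceil(out / stride) plus padding on both sides.
    shape = list(nc_dims)
    for out, st, pad in zip(out_shape, stride, padding):
        shape.append(-(-out // st) + 2 * pad)  # ceil division
    return shape

# The 3d test configuration: output (4, 5, 6), stride (2, 3, 2), padding (0, 0, 1).
assert expected_input_shape((2, 2), (4, 5, 6), (2, 3, 2), (0, 0, 1)) == [2, 2, 2, 2, 5]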
@@ -1526,8 +1527,8 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
   patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool3dWithIndicesOp>>(typeConverter,
                                                                  context);
 
-  target.addIllegalOp<AtenMaxUnpool3dOp>();
-  patterns.add<ConvertAtenMaxUnpool3dOp>(typeConverter, context);
+  target.addIllegalOp<AtenMaxUnpoolOp>();
+  patterns.add<ConvertAtenMaxUnpoolOp>(typeConverter, context);
 
   target.addIllegalOp<AtenAvgPool1dOp, AtenAvgPool2dOp, AtenAvgPool3dOp>();
   patterns
@@ -8176,7 +8176,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %1 = torch.prim.TupleConstruct %0, %0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
 "    return %1 : !torch.tuple<list<int>, list<int>>\n"
 "  }\n"
-"  func.func @\"__torch_mlir_shape_fn.aten.max_unpool3d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.list<int> {\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.max_unpool\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.list<int> {\n"
 "    %str = torch.constant.str \"AssertionError: Input and indices must be of the same rank\"\n"
 "    %str_0 = torch.constant.str \"AssertionError: output_size must have 3 elements\"\n"
 "    %none = torch.constant.none\n"
@@ -12092,7 +12092,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
 "    return %1 : !torch.tuple<int, int>\n"
 "  }\n"
-"  func.func @\"__torch_mlir_dtype_fn.aten.max_unpool3d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.int {\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.max_unpool\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
@@ -557,8 +557,8 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "ElementwiseRreluTrainStaticModule_basic",
     "MaxPool1dCeilModeTrueModule_basic",
     "MaxPool1dStaticCeilModeTrueModule_basic",
-    "MaxUnpool3dModulePad0_basic",
-    "MaxUnpool3dModule_basic",
+    "MaxUnpoolModulePad0_basic",
+    "MaxUnpoolModule_basic",
     "MultinomialModule2D_F32",
     "MultinomialModule2D_basic",
     "MultinomialModule_basic",
@@ -2837,8 +2837,8 @@ ONNX_XFAIL_SET = {
     "MaxPool3dWithIndicesNonDefaultDilationModule_basic",
     "MaxPool3dWithIndicesNonDefaultParamsModule_basic",
     "MaxPool3dWithIndicesNonDefaultStrideModule_basic",
-    "MaxUnpool3dModule_basic",
-    "MaxUnpool3dModulePad0_basic",
+    "MaxUnpoolModule_basic",
+    "MaxUnpoolModulePad0_basic",
     "MeanDimEmptyDimModule_basic",
     "Mlp1LayerModule_basic",
     "Mlp2LayerModuleNoBias_basic",
@@ -3224,8 +3224,8 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "FakeQuantizePerTensorAffineCachemaskModule_basic",
     "IndexPutWithNoneAndBroadcastModule_basic",
     "MaskedScatterStaticBasic_basic",
-    "MaxUnpool3dModulePad0_basic",
-    "MaxUnpool3dModule_basic",
+    "MaxUnpoolModulePad0_basic",
+    "MaxUnpoolModule_basic",
     "MultinomialModule2D_F32",
     "MultinomialModule2D_basic",
     "MultinomialModule_basic",
@@ -3892,8 +3892,8 @@ ONNX_TOSA_XFAIL_SET = {
     "FakeQuantizePerTensorAffineCachemaskModule_basic",
     "IndexPutWithNoneAndBroadcastModule_basic",
     "MaskedScatterStaticBasic_basic",
-    "MaxUnpool3dModulePad0_basic",
-    "MaxUnpool3dModule_basic",
+    "MaxUnpoolModulePad0_basic",
+    "MaxUnpoolModule_basic",
     "MultinomialModule2D_F32",
     "MultinomialModule2D_basic",
     "MultinomialModule_basic",
@@ -1056,14 +1056,14 @@ def aten〇max_pool3d_with_indices〡shape(self: List[int], kernel_size: List[in
     maxpool3d = indices = _max_pool3d(self, kernel_size, stride, padding, dilation, ceil_mode)
     return maxpool3d, indices
 
-def aten〇max_unpool3d〡shape(self: List[int], indices: List[int], output_size: List[int], stride: List[int], padding: List[int]) -> List[int]:
+def aten〇max_unpool〡shape(self: List[int], indices: List[int], output_size: List[int], stride: List[int], padding: List[int]) -> List[int]:
     assert (len(self) == 5 or len(self) == 4), "Input be of rank 4 or 5"
-    assert (len(output_size) == 3), "output_size must have 3 elements"
+    assert (len(output_size) == 3 or len(output_size) == 2), "output_size has 3 or 2 elements"
     assert (len(self) == len(indices)), "Input and indices must be of the same rank"
     if len(self) == 5:
         return [self[0], self[1], output_size[0], output_size[1], output_size[2]]
     else:
-        return [self[0], output_size[0], output_size[1], output_size[2]]
+        return [self[0], self[1], output_size[0], output_size[1]]
 
 def aten〇upsample_nearest2d_backward〡shape(grad_output: List[int], output_size: List[int], input_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> List[int]:
     return input_size
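The generalized shape function now accepts rank-4 (NCHW) and rank-5 (NCDHW) inputs, pairing them with 2- and 3-element output_size lists respectively. The same logic under an ASCII name (the 〇 and 〡 characters above are just torch-mlir's mangling of characters that are not valid in Python identifiers), with example invocations matching the e2e tests:

from typing import List

def max_unpool_shape(self: List[int], indices: List[int], output_size: List[int],
                     stride: List[int], padding: List[int]) -> List[int]:
    assert len(self) in (4, 5)
    assert len(output_size) in (2, 3)
    assert len(self) == len(indices)
    if len(self) == 5:  # NCDHW: keep N and C, take the three spatial dims
        return [self[0], self[1], output_size[0], output_size[1], output_size[2]]
    # NCHW: keep N and C, take the two spatial dims
    return [self[0], self[1], output_size[0], output_size[1]]

assert max_unpool_shape([2, 2, 2, 2, 3], [2, 2, 2, 2, 3], [4, 5, 6], [2, 3, 2], [0, 0, 1]) == [2, 2, 4, 5, 6]
assert max_unpool_shape([1, 1, 2, 2], [1, 1, 2, 2], [4, 4], [2, 2], [0, 0]) == [1, 1, 4, 4]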
@@ -3179,7 +3179,7 @@ def aten〇max_pool3d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], ker
     self_rank, self_dtype = self_rank_dtype
     return self_dtype, torch.int64
 
-def aten〇max_unpool3d〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: Tuple[int, int], output_size: List[int], stride: List[int], padding: List[int]) -> int:
+def aten〇max_unpool〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: Tuple[int, int], output_size: List[int], stride: List[int], padding: List[int]) -> int:
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
@@ -618,7 +618,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     )
     emit("aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
     emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
-    emit("aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)")
     emit(
         "aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)",
         has_canonicalizer=True,
@@ -627,7 +626,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
         "aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"
     )
     emit("aten::max_pool3d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
-    emit("aten::max_unpool3d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)")
+    emit("aten::max_unpool : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)")
     emit(
         "aten::max_pool3d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
     )
@@ -1972,7 +1972,7 @@ def AdaptiveMaxPool3dStaticWithIndices_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
-class MaxUnpool3dModule(torch.nn.Module):
+class MaxUnpoolModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
 
@@ -1985,11 +1985,11 @@ class MaxUnpool3dModule(torch.nn.Module):
         ]
     )
     def forward(self, x, indices):
-        return torch.ops.aten.max_unpool3d(x, indices, (4, 5, 6), (2, 3, 2), (0, 0, 1))
+        return torch.ops.aten.max_unpool(x, indices, (4, 5, 6), (2, 3, 2), (0, 0, 1))
 
 
-@register_test_case(module_factory=lambda: MaxUnpool3dModule())
-def MaxUnpool3dModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: MaxUnpoolModule())
+def MaxUnpoolModule_basic(module, tu: TestUtils):
     input = tu.rand(2, 2, 4, 5, 6)
     pool = torch.nn.MaxPool3d(
         kernel_size=(2, 2, 2), stride=(2, 3, 2), padding=(0, 0, 1), return_indices=True
@@ -2000,7 +2000,7 @@ def MaxUnpool3dModule_basic(module, tu: TestUtils):
 
 
 # We have a special case for all-zeros padding, test it too.
-class MaxUnpool3dModulePad0(torch.nn.Module):
+class MaxUnpoolModulePad0(torch.nn.Module):
     def __init__(self):
         super().__init__()
 
@@ -2013,11 +2013,11 @@ class MaxUnpool3dModulePad0(torch.nn.Module):
         ]
     )
     def forward(self, x, indices):
-        return torch.ops.aten.max_unpool3d(x, indices, (4, 5, 6), (2, 3, 2), (0, 0, 0))
+        return torch.ops.aten.max_unpool(x, indices, (4, 5, 6), (2, 3, 2), (0, 0, 0))
 
 
-@register_test_case(module_factory=lambda: MaxUnpool3dModulePad0())
-def MaxUnpool3dModulePad0_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: MaxUnpoolModulePad0())
+def MaxUnpoolModulePad0_basic(module, tu: TestUtils):
     input = tu.rand(2, 2, 4, 5, 6)
     pool = torch.nn.MaxPool3d(
         kernel_size=(2, 2, 2), stride=(2, 3, 2), padding=(0, 0, 0), return_indices=True
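Both renamed e2e tests still exercise only the rank-5 path. A rank-4 companion in the same style would look like the sketch below (hypothetical test, not part of this commit; `aten.max_unpool` is the torch-mlir-registered op, so this only runs under the torch-mlir test harness):

class MaxUnpool2dLikeModule(torch.nn.Module):
    # Hypothetical rank-4 (NCHW) counterpart of MaxUnpoolModule.
    def forward(self, x, indices):
        return torch.ops.aten.max_unpool(x, indices, (4, 4), (2, 2), (0, 0))

# Usage would mirror MaxUnpoolModule_basic: pool an (N, C, 4, 4) input with
# torch.nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True), then feed
# the pooled values and indices through the module.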
@@ -1667,14 +1667,20 @@ func.func @test_lpnormalization(%arg0: !torch.vtensor<[3,4,5,6,7],f32>) -> !torc
 
 // -----
 
-// CHECK-LABEL: func.func @test_maxunpool_export_without_output_shape
-func.func @test_maxunpool_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+// CHECK-LABEL: func.func @test_maxunpool_2d_export_without_output_shape
+func.func @test_maxunpool_2d_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[INT1:.*]] = torch.constant.int 1
   // CHECK: %[[INT1_0:.*]] = torch.constant.int 1
   // CHECK: %[[INT4:.*]] = torch.constant.int 4
   // CHECK: %[[INT4_0:.*]] = torch.constant.int 4
   // CHECK: %[[OUTPUT_SHAPE:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1_0]], %[[INT4]], %[[INT4_0]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[RESULT:.*]] = torch.aten.max_unpool2d %arg0, %arg1, %[[OUTPUT_SHAPE]] : !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>, !torch.list<int> -> !torch.vtensor<[1,1,4,4],f32>
+  // CHECK: %[[INT0:.*]] = torch.constant.int 0
+  // CHECK: %[[INT0_1:.*]] = torch.constant.int 0
+  // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT0_1]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[INT2:.*]] = torch.constant.int 2
+  // CHECK: %[[INT2_1:.*]] = torch.constant.int 2
+  // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_1]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[RESULT:.*]] = torch.aten.max_unpool %arg0, %arg1, %[[OUTPUT_SHAPE]], %[[STRIDE]], %[[PADDING]] : !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,4,4],f32>
   // return %[[RESULT]] : !torch.vtensor<[1,1,4,4],f32>
   %0 = torch.operator "onnx.MaxUnpool"(%arg0, %arg1) {torch.onnx.kernel_shape = [2 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32>
   return %0 : !torch.vtensor<[1,1,4,4],f32>
@@ -1682,8 +1688,8 @@ func.func @test_maxunpool_export_without_output_shape(%arg0: !torch.vtensor<[1,1
 
 // -----
 
-// CHECK-LABEL: func.func @test_maxunpool3d_export_without_output_shape
-func.func @test_maxunpool3d_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2,2],si64>) -> !torch.vtensor<[1,1,4,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+// CHECK-LABEL: func.func @test_maxunpool_3d_export_without_output_shape
+func.func @test_maxunpool_3d_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2,2],si64>) -> !torch.vtensor<[1,1,4,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[INT1:.*]] = torch.constant.int 1
   // CHECK: %[[INT1_0:.*]] = torch.constant.int 1
   // CHECK: %[[INT4:.*]] = torch.constant.int 4
@@ -1698,7 +1704,7 @@ func.func @test_maxunpool3d_export_without_output_shape(%arg0: !torch.vtensor<[1
   // CHECK: %[[INT2_1:.*]] = torch.constant.int 2
   // CHECK: %[[INT2_2:.*]] = torch.constant.int 2
   // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_1]], %[[INT2_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[RESULT:.*]] = torch.aten.max_unpool3d %arg0, %arg1, %[[OUTPUT_SHAPE]], %[[STRIDE]], %[[PADDING]] : !torch.vtensor<[1,1,2,2,2],f32>, !torch.vtensor<[1,1,2,2,2],si64>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,4,4,4],f32>
+  // CHECK: %[[RESULT:.*]] = torch.aten.max_unpool %arg0, %arg1, %[[OUTPUT_SHAPE]], %[[STRIDE]], %[[PADDING]] : !torch.vtensor<[1,1,2,2,2],f32>, !torch.vtensor<[1,1,2,2,2],si64>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,4,4,4],f32>
   // return %[[RESULT]] : !torch.vtensor<[1,1,4,4,4],f32>
   %0 = torch.operator "onnx.MaxUnpool"(%arg0, %arg1) {torch.onnx.kernel_shape = [2 : si64, 2 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,2,2,2],f32>, !torch.vtensor<[1,1,2,2,2],si64>) -> !torch.vtensor<[1,1,4,4,4],f32>
   return %0 : !torch.vtensor<[1,1,4,4,4],f32>
@@ -95,3 +95,46 @@ func.func @forward_max_pool3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.
 // CHECK: } -> tensor<?x?x?x?x?xf32>
   return %4 : !torch.vtensor<[?,?,?,?,?],f32>
 }
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 2, d3 floordiv 2)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func.func @forward_max_unpool2d(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  %int1 = torch.constant.int 1
+  %int1_0 = torch.constant.int 1
+  %int4 = torch.constant.int 4
+  %int4_1 = torch.constant.int 4
+  %0 = torch.prim.ListConstruct %int1, %int1_0, %int4, %int4_1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %int0 = torch.constant.int 0
+  %int0_2 = torch.constant.int 0
+  %1 = torch.prim.ListConstruct %int0, %int0_2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %int2 = torch.constant.int 2
+  %int2_3 = torch.constant.int 2
+  %2 = torch.prim.ListConstruct %int2, %int2_3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.max_unpool %arg0, %arg1, %0, %2, %1 : !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,4,4],f32>
+  // CHECK: %[[INDICES:.*]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[1,1,2,2],si64> -> tensor<1x1x2x2xi64>
+  // CHECK: %[[INPUT:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,1,2,2],f32> -> tensor<1x1x2x2xf32>
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[DIM0:.*]] = tensor.dim %[[INPUT]], %[[C0]] : tensor<1x1x2x2xf32>
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK: %[[DIM1:.*]] = tensor.dim %[[INPUT]], %[[C1]] : tensor<1x1x2x2xf32>
+  // CHECK: %[[SHAPE:.*]] = tensor.empty(%[[DIM0]], %[[DIM1]]) : tensor<?x?x4x4xf32>
+  // CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#map, #map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[INPUT]], %[[INDICES]] : tensor<1x1x2x2xf32>, tensor<1x1x2x2xi64>) outs(%[[SHAPE]] : tensor<?x?x4x4xf32>) {
+  // CHECK-NEXT: ^bb0(%[[CURRENT_VALUE:.*]]: f32, %[[CURRENT_INDEX:.*]]: i64, %[[OUT:.*]]: f32):
+  // CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK-NEXT: %[[INDEX_CAST:.*]] = arith.index_cast %[[CURRENT_INDEX]] : i64 to index
+  // CHECK-NEXT: %[[INDEX2:.*]] = linalg.index 2 : index
+  // CHECK-NEXT: %[[INDEX3:.*]] = linalg.index 3 : index
+  // CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK-NEXT: %[[MULI:.*]] = arith.muli %[[INDEX2]], %[[C4]] : index
+  // CHECK-NEXT: %[[ADDI:.*]] = arith.addi %[[MULI]], %[[INDEX3]] : index
+  // CHECK-NEXT: %[[CMPI:.*]] = arith.cmpi eq, %[[INDEX_CAST]], %[[ADDI]] : index
+  // CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[CMPI]], %[[CURRENT_VALUE]], %[[CST]] : f32
+  // CHECK-NEXT: linalg.yield %[[SELECT]] : f32
+  // CHECK: } -> tensor<?x?x4x4xf32>
+  // CHECK: %[[CAST:.*]] = tensor.cast %[[GENERIC]] : tensor<?x?x4x4xf32> to tensor<1x1x4x4xf32>
+  // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CAST]] : tensor<1x1x4x4xf32> -> !torch.vtensor<[1,1,4,4],f32>
+  // CHECK: return %[[RESULT]] : !torch.vtensor<[1,1,4,4],f32>
+  return %3 : !torch.vtensor<[1,1,4,4],f32>
+}
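The new FileCheck test pins down the linalg lowering: a single linalg.generic walks every output element, gathers the pooled value at (i floordiv 2, j floordiv 2) via #map, and keeps it only when the recorded flat index equals that output position. A NumPy sketch of the same computation for the stride-2, zero-padding case in the test (illustrative reference, not the conversion code):

import numpy as np

def max_unpool2d_ref(x, indices, out_h, out_w):
    # x, indices: (N, C, H, W); indices hold positions flattened over the
    # (out_h, out_w) output plane, as produced by max-pooling.
    n, c, _, _ = x.shape
    out = np.zeros((n, c, out_h, out_w), dtype=x.dtype)
    for i in range(out_h):
        for j in range(out_w):
            val = x[:, :, i // 2, j // 2]          # gather via #map
            idx = indices[:, :, i // 2, j // 2]
            # keep only where the stored index equals i * out_w + j
            out[:, :, i, j] = np.where(idx == i * out_w + j, val, 0)
    return out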