mirror of https://github.com/llvm/torch-mlir
[ONNX] Add OnnxToTorch Lowering for MaxUnpool op (#3413)
This commit also adds the Torch declarations for the aten.max_unpool2d and aten.max_unpool3d ops. The TorchToLinalg lowering for these ops will be added in a follow-up commit.

Signed-off-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
parent 89f7d24fdc
commit 35dd8c52cd
@@ -6819,6 +6819,31 @@ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
   }];
 }
 
+def Torch_AtenMaxUnpool2dOp : Torch_Op<"aten.max_unpool2d", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$indices,
+    AnyTorchListOfTorchIntType:$output_size
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenMaxUnpool2dOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void AtenMaxUnpool2dOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+}
+
 def Torch_AtenMaxPool2dWithIndicesOp : Torch_Op<"aten.max_pool2d_with_indices", [
     AllowsTypeRefinement,
     HasValueSemantics,
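As a reference for what the new op models: aten::max_unpool2d scatters each pooled maximum back to the position recorded in indices and zero-fills the rest of the output. A minimal sketch of those semantics, calling torch.ops.aten directly to match the (Tensor, Tensor, int[]) schema above; the shapes and values are illustrative, not part of this commit:

# Minimal sketch of the semantics behind aten.max_unpool2d; requires PyTorch.
import torch
import torch.nn.functional as F

x = torch.arange(16.0).reshape(1, 1, 4, 4)
pooled, indices = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
# aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)
# Each maximum is scattered back to the position in `indices`; every other
# element of the (1, 1, 4, 4) output is zero.
unpooled = torch.ops.aten.max_unpool2d(pooled, indices, [4, 4])
assert unpooled.shape == (1, 1, 4, 4)
assert unpooled[0, 0, 1, 1] == x[0, 0, 1, 1]  # 5.0, the top-left window's max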
@@ -6907,6 +6932,33 @@ def Torch_AtenMaxPool3dOp : Torch_Op<"aten.max_pool3d", [
   }];
 }
 
+def Torch_AtenMaxUnpool3dOp : Torch_Op<"aten.max_unpool3d", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::max_unpool3d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$indices,
+    AnyTorchListOfTorchIntType:$output_size,
+    AnyTorchListOfTorchIntType:$stride,
+    AnyTorchListOfTorchIntType:$padding
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenMaxUnpool3dOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 5, 1);
+    }
+    void AtenMaxUnpool3dOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 5, 1);
+    }
+  }];
+}
+
 def Torch_AtenMaxPool3dWithIndicesOp : Torch_Op<"aten.max_pool3d_with_indices", [
     AllowsTypeRefinement,
     HasValueSemantics,
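The 3d variant additionally threads per-axis stride and padding lists through to the op. A similarly hedged sketch with illustrative shapes, again using torch.ops.aten to match the five-operand schema:

# Sketch of the extra stride/padding operands on aten.max_unpool3d.
import torch
import torch.nn.functional as F

x = torch.arange(64.0).reshape(1, 1, 4, 4, 4)
pooled, indices = F.max_pool3d(x, kernel_size=2, stride=2, return_indices=True)
# aten::max_unpool3d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)
unpooled = torch.ops.aten.max_unpool3d(
    pooled, indices, [4, 4, 4], [2, 2, 2], [0, 0, 0])
assert unpooled.shape == (1, 1, 4, 4, 4)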
@@ -1926,4 +1926,82 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         return success();
       });
+  patterns.onOp(
+      "MaxUnpool", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        // TODO: Add support for `output_shape` arg.
+        if (binder.op->getNumOperands() == 3)
+          return rewriter.notifyMatchFailure(
+              binder.op, "unimplemented: output_shape arg is not supported");
+
+        Torch::ValueTensorType resultType;
+        Value data, indices;
+        if (binder.tensorOperandAtIndex(data, 0) ||
+            binder.tensorOperandAtIndex(indices, 1) ||
+            binder.tensorResultType(resultType))
+          return rewriter.notifyMatchFailure(
+              binder.op, "data/indices/resultType bind failure");
+        std::optional<unsigned> maybeRank = Torch::getTensorRank(data);
+        if (!maybeRank)
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "Unimplemented: unranked tensor");
+        int64_t rank = *maybeRank;
+        int64_t spatial = rank - 2;
+
+        if (rank <= 3 || rank > 5)
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "Unimplemented: MaxUnpool support "
+                                             "only present for rank 4/5 input");
+
+        if (!(resultType.hasSizes() && resultType.areAllSizesKnown()))
+          return rewriter.notifyMatchFailure(
+              binder.op, "unimplemented: expected result to have all shapes "
+                         "statically known");
+
+        SmallVector<int64_t> resultShape(resultType.getSizes());
+        Value resultShapeList =
+            createConstantIntList(binder, rewriter, resultShape);
+        if (rank == 4) {
+          rewriter.replaceOpWithNewOp<Torch::AtenMaxUnpool2dOp>(
+              binder.op, resultType, data, indices, resultShapeList);
+          return success();
+        }
+
+        SmallVector<int64_t> padding, strides;
+        if (binder.s64IntegerArrayAttr(padding, "pads", {}))
+          return rewriter.notifyMatchFailure(binder.op, "pads bind failure");
+        if (!padding.empty() &&
+            padding.size() != static_cast<size_t>(2 * spatial))
+          return rewriter.notifyMatchFailure(
+              binder.op, "padding list must contain (begin,end) pair for each "
+                         "spatial axis");
+        if (binder.s64IntegerArrayAttr(strides, "strides", {}))
+          return rewriter.notifyMatchFailure(binder.op, "strides bind failure");
+        if (!strides.empty() && strides.size() != static_cast<size_t>(spatial))
+          return rewriter.notifyMatchFailure(
+              binder.op, "strides list size does not match the number of axes");
+
+        if (padding.empty())
+          padding.resize(spatial, 0);
+        if (strides.empty())
+          strides.resize(spatial, 1);
+
+        // If the padding is symmetric we can push the padding
+        // operation to the torch operator.
+        if (padding.size() == static_cast<size_t>(2 * spatial)) {
+          bool equal = true;
+          for (int i = 0; i < spatial; ++i) {
+            equal = equal && (padding[i] == padding[i + spatial]);
+          }
+          if (equal)
+            padding.resize(spatial);
+        }
+
+        Value paddingList = createConstantIntList(binder, rewriter, padding);
+        Value stridesList = createConstantIntList(binder, rewriter, strides);
+
+        rewriter.replaceOpWithNewOp<Torch::AtenMaxUnpool3dOp>(
+            binder.op, resultType, data, indices, resultShapeList, stridesList,
+            paddingList);
+        return success();
+      });
 }
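The least obvious part of the lowering is the attribute handling: ONNX pads arrive as a flat (begin..., end...) list covering every spatial axis, and only a symmetric begin/end pair can be folded into the per-axis padding list that aten.max_unpool3d takes. A standalone Python sketch of that decision, with a hypothetical helper name but the same observable behavior as the C++ above:

# Default pads/strides, then fold a symmetric ONNX pads list down to one
# entry per spatial axis (hypothetical helper; mirrors the lowering above).
def normalize_unpool_attrs(rank, pads=None, strides=None):
    spatial = rank - 2
    pads = list(pads) if pads else [0] * spatial
    strides = list(strides) if strides else [1] * spatial
    # ONNX pads are (x1_begin, x2_begin, ..., x1_end, x2_end, ...); only a
    # symmetric begin/end pair can be pushed into the torch operator.
    if len(pads) == 2 * spatial:
        if all(pads[i] == pads[i + spatial] for i in range(spatial)):
            pads = pads[:spatial]
    return pads, strides

# 5-D input, symmetric pads fold to one entry per spatial axis.
assert normalize_unpool_attrs(5, pads=[1, 1, 1, 1, 1, 1]) == ([1, 1, 1], [1, 1, 1])
# No attributes: ONNX defaults of zero padding and unit strides.
assert normalize_unpool_attrs(5) == ([0, 0, 0], [1, 1, 1])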
@@ -597,6 +597,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     )
     emit("aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
     emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
+    emit("aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)")
     emit(
         "aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)",
         has_canonicalizer=True,
@@ -605,6 +606,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
         "aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"
     )
     emit("aten::max_pool3d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
+    emit("aten::max_unpool3d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)")
     emit(
         "aten::max_pool3d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
     )
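Each emit() line records the op's Torch schema, and the operand/result counts are what surface as the (3, 1) and (5, 1) arguments to parseDefaultTorchOp/printDefaultTorchOp in the generated ODS above. A simplified stand-in for that counting, not the real torch_ods_gen parser:

# Count operands and results from an emit()-style schema string.
def signature_arity(sig):
    name, rest = sig.split(" : ", 1)
    args, results = rest.split(" -> ")
    count = lambda s: len([t for t in s.strip("()").split(",") if t.strip()])
    return name, count(args), count(results)

assert signature_arity(
    "aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)"
) == ("aten::max_unpool2d", 3, 1)
assert signature_arity(
    "aten::max_unpool3d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)"
) == ("aten::max_unpool3d", 5, 1)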
@@ -1087,3 +1087,42 @@ func.func @test_lpnormalization(%arg0: !torch.vtensor<[3,4,5,6,7],f32>) -> !torc
   %0 = torch.operator "onnx.LpNormalization"(%arg0) {torch.onnx.axis = 2 : si64, torch.onnx.p = 2 : si64} : (!torch.vtensor<[3,4,5,6,7],f32>) -> !torch.vtensor<[3,4,1,6,7],f32>
   return %0 : !torch.vtensor<[3,4,1,6,7],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @test_maxunpool_export_without_output_shape
+func.func @test_maxunpool_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INT1:.*]] = torch.constant.int 1
+  // CHECK: %[[INT1_0:.*]] = torch.constant.int 1
+  // CHECK: %[[INT4:.*]] = torch.constant.int 4
+  // CHECK: %[[INT4_0:.*]] = torch.constant.int 4
+  // CHECK: %[[OUTPUT_SHAPE:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1_0]], %[[INT4]], %[[INT4_0]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[RESULT:.*]] = torch.aten.max_unpool2d %arg0, %arg1, %[[OUTPUT_SHAPE]] : !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>, !torch.list<int> -> !torch.vtensor<[1,1,4,4],f32>
+  // CHECK: return %[[RESULT]] : !torch.vtensor<[1,1,4,4],f32>
+  %0 = torch.operator "onnx.MaxUnpool"(%arg0, %arg1) {torch.onnx.kernel_shape = [2 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32>
+  return %0 : !torch.vtensor<[1,1,4,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_maxunpool3d_export_without_output_shape
+func.func @test_maxunpool3d_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2,2],si64>) -> !torch.vtensor<[1,1,4,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INT1:.*]] = torch.constant.int 1
+  // CHECK: %[[INT1_0:.*]] = torch.constant.int 1
+  // CHECK: %[[INT4:.*]] = torch.constant.int 4
+  // CHECK: %[[INT4_0:.*]] = torch.constant.int 4
+  // CHECK: %[[INT4_1:.*]] = torch.constant.int 4
+  // CHECK: %[[OUTPUT_SHAPE:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1_0]], %[[INT4]], %[[INT4_0]], %[[INT4_1]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[INT0:.*]] = torch.constant.int 0
+  // CHECK: %[[INT0_1:.*]] = torch.constant.int 0
+  // CHECK: %[[INT0_2:.*]] = torch.constant.int 0
+  // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT0_1]], %[[INT0_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[INT2:.*]] = torch.constant.int 2
+  // CHECK: %[[INT2_1:.*]] = torch.constant.int 2
+  // CHECK: %[[INT2_2:.*]] = torch.constant.int 2
+  // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_1]], %[[INT2_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[RESULT:.*]] = torch.aten.max_unpool3d %arg0, %arg1, %[[OUTPUT_SHAPE]], %[[STRIDE]], %[[PADDING]] : !torch.vtensor<[1,1,2,2,2],f32>, !torch.vtensor<[1,1,2,2,2],si64>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,4,4,4],f32>
+  // CHECK: return %[[RESULT]] : !torch.vtensor<[1,1,4,4,4],f32>
+  %0 = torch.operator "onnx.MaxUnpool"(%arg0, %arg1) {torch.onnx.kernel_shape = [2 : si64, 2 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,2,2,2],f32>, !torch.vtensor<[1,1,2,2,2],si64>) -> !torch.vtensor<[1,1,4,4,4],f32>
+  return %0 : !torch.vtensor<[1,1,4,4,4],f32>
+}
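The expected 4x4 (and 4x4x4) result shapes in these tests follow from the ONNX MaxUnpool shape rule when output_shape is omitted: per spatial axis, out = (in - 1) * stride + kernel - pad_begin - pad_end. A one-line check of the tests' numbers:

# ONNX MaxUnpool output-shape rule (no output_shape input given).
def unpool_dim(in_dim, kernel, stride, pad_begin=0, pad_end=0):
    return (in_dim - 1) * stride + kernel - pad_begin - pad_end

# 2x2 input, kernel 2, stride 2 -> 4 along every spatial axis.
assert unpool_dim(2, kernel=2, stride=2) == 4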