[torch] Add OnnxToTorch lowering for the Col2Im op (#3424)

Adds OnnxToTorch lowering for the `onnx.Col2Im` op.
pull/3406/merge
Vinayak Dev 2024-06-13 14:12:06 +05:30 committed by GitHub
parent de7f058a0e
commit 39d882f7c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 591 additions and 0 deletions

View File

@ -12398,6 +12398,34 @@ def Torch_AtenLinalgCrossOp : Torch_Op<"aten.linalg_cross", [
let hasVerifier = 1; let hasVerifier = 1;
} }
def Torch_AtenCol2imOp : Torch_Op<"aten.col2im", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::col2im : (Tensor, int[], int[], int[], int[], int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$output_size,
AnyTorchListOfTorchIntType:$kernel_size,
AnyTorchListOfTorchIntType:$dilation,
AnyTorchListOfTorchIntType:$padding,
AnyTorchListOfTorchIntType:$stride
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenCol2imOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenCol2imOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}
def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [ def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,

View File

@ -949,6 +949,130 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
return failure(); return failure();
}); });
patterns.onOp(
"Col2Im", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value input, blockShape, imageShape;
SmallVector<int64_t> dilations, strides, pads;
// TODO: The length of dilations should be len(imageShape), and the same
// goes for strides. The length of pads should be 2 * len(imageShape).
// But, as at the moment we are only supporting 3D or 4D input,
// len(imageShape) must necessarily be 2, hence the lengths of the
// default values.
if (binder.tensorOperandAtIndex(input, 0) ||
binder.tensorOperandAtIndex(imageShape, 1) ||
binder.tensorOperandAtIndex(blockShape, 2) ||
binder.tensorResultType(resultType) ||
binder.s64IntegerArrayAttr(dilations, "dilations",
SmallVector<int64_t>{1, 1}) ||
binder.s64IntegerArrayAttr(strides, "strides",
SmallVector<int64_t>{1, 1}) ||
binder.s64IntegerArrayAttr(pads, "pads",
SmallVector<int64_t>{0, 0, 0, 0}))
return failure();
auto imageShapeTy = cast<Torch::ValueTensorType>(imageShape.getType());
auto imageShapeSizes = imageShapeTy.getSizes();
auto blockShapeTy = cast<Torch::ValueTensorType>(blockShape.getType());
auto blockShapeSizes = blockShapeTy.getSizes();
// Check that neither imageShape nor blockShape have dynamic shapes.
if (imageShapeSizes[0] == Torch::kUnknownSize ||
blockShapeSizes[0] == Torch::kUnknownSize) {
return rewriter.notifyMatchFailure(
binder.op,
"Dynamic shapes are not allowed for imageShape and blockShape");
}
// TODO: Add support for 5D input tensors.
if (imageShapeSizes[0] != 2) {
return rewriter.notifyMatchFailure(
binder.op, "Expected length of imageShape to be equal to 2");
}
if (blockShapeSizes[0] != 2) {
return rewriter.notifyMatchFailure(
binder.op, "Expected length of blockShape to be equal to 2");
}
if (dilations.size() != 2) {
return rewriter.notifyMatchFailure(
binder.op, "Expected length of dilations to be equal to 2");
}
if (strides.size() != 2) {
return rewriter.notifyMatchFailure(
binder.op, "Expected length of strides to be equal to 2");
}
// TODO: Disable this check and add support for different
// paddings on lower and higher ends of each axis.
// Because we have already checked that imageShape has 2 elements,
// we can safely assume that len(padding) will be 4.
if (pads[0] != pads[2] || pads[1] != pads[3])
return rewriter.notifyMatchFailure(
binder.op, "padding on the lower end and the higher end "
"on each axis should be the same");
// Since we know that the padding on the lower end and the higher
// end on each axis is the same, we can reduce the size of the
// padding list, and filter out the duplicate elements.
// (Also, Torch::AtenCol2imOp requires len(padding) to be 2).
SmallVector<int64_t> padOnEachAxis = {pads[0], pads[1]};
Value dilationsList =
createConstantIntList(binder, rewriter, dilations);
Value stridesList = createConstantIntList(binder, rewriter, strides);
Value paddingList =
createConstantIntList(binder, rewriter, padOnEachAxis);
Value zero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0));
// Index the imageShape and blockShape tensors, as AtenCol2imOp expects
// them to be int lists.
auto select = [&](Value v, Value k,
Torch::ValueTensorType ty) -> Value {
Value kTensor = rewriter.create<Torch::PrimNumToTensorScalarOp>(
binder.getLoc(),
Torch::ValueTensorType::get(
binder.op->getContext(), ArrayRef<int64_t>{1},
rewriter.getIntegerType(64, /*signed*/ 1)),
k);
auto sel = rewriter.create<Torch::AtenIndexSelectOp>(
binder.getLoc(),
Torch::ValueTensorType::get(ty.getContext(), ArrayRef<int64_t>{1},
ty.getOptionalDtype()),
v, zero, kTensor);
Value item = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), sel);
return item;
};
SmallVector<Value> imageShapeContainer, blockShapeContainer;
for (int64_t i = 0; i < imageShapeSizes[0]; ++i) {
Value k = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(i));
// Passing in the shapeType of each of these tensors avoids
// repeated casts, as these have already been calculated.
imageShapeContainer.push_back(select(imageShape, k, imageShapeTy));
blockShapeContainer.push_back(select(blockShape, k, blockShapeTy));
}
Value imageShapeAsList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
imageShapeContainer);
Value blockShapeAsList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
blockShapeContainer);
rewriter.replaceOpWithNewOp<Torch::AtenCol2imOp>(
binder.op, resultType, input, imageShapeAsList, blockShapeAsList,
dilationsList, paddingList, stridesList);
return success();
});
patterns.onOp( patterns.onOp(
"Conv", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { "Conv", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
std::string autoPad; std::string autoPad;

View File

@ -9016,6 +9016,248 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %0) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n" " %1 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %0) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %1 : !torch.list<int>\n" " return %1 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.col2im\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>) -> !torch.list<int> {\n"
" %str = torch.constant.str \"AssertionError: Expected size of input's dimension 2 to match the calculated number of sliding blocks\"\n"
" %str_0 = torch.constant.str \"AssertionError: Expected size of input's dimension 1 to be divisible by the product of kernel_size\"\n"
" %int-1 = torch.constant.int -1\n"
" %str_1 = torch.constant.str \"AssertionError: stride must be greater than 0\"\n"
" %str_2 = torch.constant.str \"AssertionError: padding should be non negative\"\n"
" %str_3 = torch.constant.str \"AssertionError: dilation should be greater than 0\"\n"
" %str_4 = torch.constant.str \"AssertionError: kernel size should be greater than 0\"\n"
" %str_5 = torch.constant.str \"AssertionError: padding is expected to have length 2\"\n"
" %str_6 = torch.constant.str \"AssertionError: stride is expected to have length 2\"\n"
" %str_7 = torch.constant.str \"AssertionError: dilation is expected to have length 2\"\n"
" %str_8 = torch.constant.str \"AssertionError: kernel_size is expected to have length 2\"\n"
" %str_9 = torch.constant.str \"AssertionError: output_size is expected to have length 2\"\n"
" %none = torch.constant.none\n"
" %str_10 = torch.constant.str \"AssertionError: Expected 2D or 3D (batch mode) tensor for input with possibly 0 batch size and non zero dimensions for input\"\n"
" %true = torch.constant.bool true\n"
" %false = torch.constant.bool false\n"
" %int2 = torch.constant.int 2\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %int3 = torch.constant.int 3\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.bool) {\n"
" %75 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %76 = torch.aten.ne.int %75, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %76 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %3 = torch.prim.If %2 -> (!torch.bool) {\n"
" %75 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %76 = torch.aten.ne.int %75, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %76 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %75 = torch.aten.eq.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n"
" %76 = torch.prim.If %75 -> (!torch.bool) {\n"
" %78 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %79 = torch.aten.ne.int %78, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %79 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %77 = torch.prim.If %76 -> (!torch.bool) {\n"
" %78 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %79 = torch.aten.ne.int %78, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %79 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If.yield %77 : !torch.bool\n"
" }\n"
" torch.prim.If %4 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_10, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %5 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
" %6 = torch.aten.eq.int %5, %int2 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %6 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_9, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %7 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int\n"
" %8 = torch.aten.eq.int %7, %int2 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %8 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_8, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %9 = torch.aten.len.t %arg3 : !torch.list<int> -> !torch.int\n"
" %10 = torch.aten.eq.int %9, %int2 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %10 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_7, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %11 = torch.aten.len.t %arg5 : !torch.list<int> -> !torch.int\n"
" %12 = torch.aten.eq.int %11, %int2 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %12 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_6, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %13 = torch.aten.len.t %arg4 : !torch.list<int> -> !torch.int\n"
" %14 = torch.aten.eq.int %13, %int2 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %14 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_5, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %15 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %16 = torch.aten.gt.int %15, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %17 = torch.prim.If %16 -> (!torch.bool) {\n"
" %75 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %76 = torch.aten.gt.int %75, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %76 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %17 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_4, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %18 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %19 = torch.aten.gt.int %18, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %20 = torch.prim.If %19 -> (!torch.bool) {\n"
" %75 = torch.aten.__getitem__.t %arg3, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %76 = torch.aten.gt.int %75, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %76 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %20 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %21 = torch.aten.__getitem__.t %arg4, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %22 = torch.aten.ge.int %21, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %23 = torch.prim.If %22 -> (!torch.bool) {\n"
" %75 = torch.aten.__getitem__.t %arg4, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %76 = torch.aten.ge.int %75, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %76 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %23 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %24 = torch.aten.__getitem__.t %arg5, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %25 = torch.aten.gt.int %24, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %26 = torch.prim.If %25 -> (!torch.bool) {\n"
" %75 = torch.aten.__getitem__.t %arg5, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %76 = torch.aten.gt.int %75, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %76 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %26 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %27 = torch.aten.eq.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n"
" %28 = torch.prim.If %27 -> (!torch.int) {\n"
" torch.prim.If.yield %int0 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %int-1 : !torch.int\n"
" }\n"
" %29 = torch.aten.add.int %28, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %30 = torch.aten.__getitem__.t %arg0, %29 : !torch.list<int>, !torch.int -> !torch.int\n"
" %31 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %32 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %33 = torch.aten.mul.int %31, %32 : !torch.int, !torch.int -> !torch.int\n"
" %34 = torch.aten.remainder.int %30, %33 : !torch.int, !torch.int -> !torch.int\n"
" %35 = torch.aten.eq.int %34, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %35 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %36 = torch.aten.add.int %28, %int2 : !torch.int, !torch.int -> !torch.int\n"
" %37 = torch.aten.__getitem__.t %arg0, %36 : !torch.list<int>, !torch.int -> !torch.int\n"
" %38 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %39 = torch.aten.__getitem__.t %arg4, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %40 = torch.aten.mul.int %int2, %39 : !torch.int, !torch.int -> !torch.int\n"
" %41 = torch.aten.add.int %38, %40 : !torch.int, !torch.int -> !torch.int\n"
" %42 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %43 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %44 = torch.aten.sub.int %43, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %45 = torch.aten.mul.int %42, %44 : !torch.int, !torch.int -> !torch.int\n"
" %46 = torch.aten.sub.int %41, %45 : !torch.int, !torch.int -> !torch.int\n"
" %47 = torch.aten.sub.int %46, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %48 = torch.aten.__getitem__.t %arg5, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %49 = torch.aten.floordiv.int %47, %48 : !torch.int, !torch.int -> !torch.int\n"
" %50 = torch.aten.add.int %49, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %51 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %52 = torch.aten.__getitem__.t %arg4, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %53 = torch.aten.mul.int %int2, %52 : !torch.int, !torch.int -> !torch.int\n"
" %54 = torch.aten.add.int %51, %53 : !torch.int, !torch.int -> !torch.int\n"
" %55 = torch.aten.__getitem__.t %arg3, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %56 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %57 = torch.aten.sub.int %56, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %58 = torch.aten.mul.int %55, %57 : !torch.int, !torch.int -> !torch.int\n"
" %59 = torch.aten.sub.int %54, %58 : !torch.int, !torch.int -> !torch.int\n"
" %60 = torch.aten.sub.int %59, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %61 = torch.aten.__getitem__.t %arg5, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %62 = torch.aten.floordiv.int %60, %61 : !torch.int, !torch.int -> !torch.int\n"
" %63 = torch.aten.add.int %62, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %64 = torch.aten.mul.int %50, %63 : !torch.int, !torch.int -> !torch.int\n"
" %65 = torch.aten.eq.int %37, %64 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %65 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %66 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %67 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %68 = torch.aten.mul.int %66, %67 : !torch.int, !torch.int -> !torch.int\n"
" %69 = torch.aten.floordiv.int %30, %68 : !torch.int, !torch.int -> !torch.int\n"
" %70 = torch.aten.eq.int %28, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %71 = torch.prim.If %70 -> (!torch.list<int>) {\n"
" %75 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %76 = torch.prim.ListConstruct %75, %69 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %76 : !torch.list<int>\n"
" } else {\n"
" %75 = torch.prim.ListConstruct %69 : (!torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %75 : !torch.list<int>\n"
" }\n"
" %72 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %73 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
" torch.prim.Loop %73, %true, init() {\n"
" ^bb0(%arg6: !torch.int):\n"
" %75 = torch.aten.__getitem__.t %arg1, %arg6 : !torch.list<int>, !torch.int -> !torch.int\n"
" %76 = torch.aten.append.t %72, %75 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %74 = torch.operator \"aten.add_.t\"(%71, %72) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int> \n"
" return %74 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.topk\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n" " func.func @\"__torch_mlir_shape_fn.aten.topk\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.topk(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.int) -> !torch.tuple<list<int>, list<int>>\n" " %0 = call @__torch__.torch.jit._shape_functions.topk(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.int) -> !torch.tuple<list<int>, list<int>>\n"
" return %0 : !torch.tuple<list<int>, list<int>>\n" " return %0 : !torch.tuple<list<int>, list<int>>\n"
@ -12049,6 +12291,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n" " }\n"
" return %4 : !torch.int\n" " return %4 : !torch.int\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.col2im\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.nonzero\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.nonzero\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %int4 = torch.constant.int 4\n" " %int4 = torch.constant.int 4\n"
" return %int4 : !torch.int\n" " return %int4 : !torch.int\n"

View File

@ -1501,6 +1501,46 @@ def atenaddcmul〡shape(self: List[int], tensor1: List[int], tensor2: List[in
def atenaddcdiv〡shape(self: List[int], tensor1: List[int], tensor2: List[int], value: float = 1) -> List[int]: def atenaddcdiv〡shape(self: List[int], tensor1: List[int], tensor2: List[int], value: float = 1) -> List[int]:
return upstream_shape_functions.broadcast(self, upstream_shape_functions.broadcast(tensor1, tensor2)) return upstream_shape_functions.broadcast(self, upstream_shape_functions.broadcast(tensor1, tensor2))
@check_shape_function([
Invocation(TensorOfShape(1,5,5), [5,5], [1,5], [1,1], [0,0], [1,1]), # basic case
Invocation(TensorOfShape(1,4,5), [6,6], [2,2], [1,5], [0,0], [1,1]), # dilation
Invocation(TensorOfShape(1,5,15), [5,5], [1,5], [1,1], [0,1], [1,1]), # padding
Invocation(TensorOfShape(1,9,4), [5,5], [3,3], [1,1], [0,0], [2,2]), # stride
ErrorInvocation(TensorOfShape(1,5,5), [5,5], [1,7], [1,1], [0,0], [1,1]), # mismatch of sliding blocks
])
def atencol2im〡shape(self: List[int], output_size: List[int], kernel_size: List[int], dilation: List[int], padding: List[int], stride: List[int]) -> List[int]:
ndim = len(self)
assert (ndim == 2 and self[0] != 0 and self[1] != 0) or (ndim == 3 and self[1] != 0 and self[2] != 0), "Expected 2D or 3D (batch mode) tensor for input with possibly 0 batch size and non zero dimensions for input"
assert len(output_size) == 2, "output_size is expected to have length 2"
assert len(kernel_size) == 2, "kernel_size is expected to have length 2"
assert len(dilation) == 2, "dilation is expected to have length 2"
assert len(stride) == 2, "stride is expected to have length 2"
assert len(padding) == 2, "padding is expected to have length 2"
assert kernel_size[0] > 0 and kernel_size[1] > 0, "kernel size should be greater than 0"
assert dilation[0] > 0 and dilation[1] > 0, "dilation should be greater than 0"
assert padding[0] >= 0 and padding[1] >= 0, "padding should be non negative"
assert stride[0] > 0 and stride[1] > 0, "stride must be greater than 0"
batch_dim = 0 if ndim == 3 else -1
n_input_plane = self[batch_dim + 1]
assert n_input_plane % (kernel_size[0] * kernel_size[1]) == 0, "Expected size of input's dimension 1 to be divisible by the product of kernel_size"
input_length = self[batch_dim + 2]
n_blocks_height = (output_size[0] + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) // stride[0] + 1
n_blocks_width = (output_size[1] + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[1] + 1
assert input_length == n_blocks_height * n_blocks_width, "Expected size of input's dimension 2 to match the calculated number of sliding blocks"
# compute the shape of the output
num_channels = n_input_plane // (kernel_size[0] * kernel_size[1])
out: List[int] = [self[0], num_channels] if batch_dim == 0 else [num_channels]
out += [elem for elem in output_size]
return out
@check_shape_function([ @check_shape_function([
Invocation(TensorOfShape(2, 3), 1), # Basic case. Invocation(TensorOfShape(2, 3), 1), # Basic case.
Invocation(TensorOfShape(2, 3), 2, dim=0), # Test explicit `dim`. Invocation(TensorOfShape(2, 3), 2, dim=0), # Test explicit `dim`.
@ -3708,6 +3748,16 @@ def atenbincount〡dtype(self_rank_dtype: Tuple[int, int], weights_rank_dtype
return torch.int64 return torch.int64
return torch.float64 return torch.float64
@check_dtype_function([
Invocation(TensorOfShape(1, 5, 5, dtype=torch.int64), [5,5], [1,5], [1,1], [0,0], [1,1]), # int type
Invocation(TensorOfShape(1, 5, 5, dtype=torch.float64), [5,5], [1,5], [1,1], [0,0], [1,1]), # float type
Invocation(TensorOfShape(1, 5, 5, dtype=torch.complex64), [5,5], [1,5], [1,1], [0,0], [1,1]), # complex type
Invocation(TensorOfShape(1, 5, 5, dtype=torch.bool), [5,5], [1,5], [1,1], [0,0], [1,1]), # boolean type
])
def atencol2im〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int], kernel_size: List[int], dilation: List[int], padding: List[int], stride: List[int]) -> int:
_, self_dtype = self_rank_dtype
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, tensor_device=torch.device("cpu"))) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, tensor_device=torch.device("cpu")))
def atennonzero〡dtype(self_rank_dtype: Tuple[int, int]) -> int: def atennonzero〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
return torch.int64 return torch.int64

View File

@ -911,6 +911,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
"aten::linspace : (Scalar, Scalar, int, int?, int?, Device?, bool?) -> (Tensor)" "aten::linspace : (Scalar, Scalar, int, int?, int?, Device?, bool?) -> (Tensor)"
) )
emit("aten::linalg_cross : (Tensor, Tensor, int) -> (Tensor)", has_verifier=True) emit("aten::linalg_cross : (Tensor, Tensor, int) -> (Tensor)", has_verifier=True)
emit("aten::col2im : (Tensor, int[], int[], int[], int[], int[]) -> (Tensor)")
# Functionalization ops # Functionalization ops
emit("aten::alias_copy : (Tensor) -> (Tensor)") emit("aten::alias_copy : (Tensor) -> (Tensor)")

View File

@ -2325,3 +2325,145 @@ func.func @test_hammingwindow(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<
%0 = torch.operator "onnx.HammingWindow"(%arg0) {torch.onnx.periodic = 1 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> %0 = torch.operator "onnx.HammingWindow"(%arg0) {torch.onnx.periodic = 1 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
return %0 : !torch.vtensor<[10],f32> return %0 : !torch.vtensor<[10],f32>
} }
// -----
// CHECK-LABEL: func.func @test_col2im
func.func @test_col2im(%arg0: !torch.vtensor<[1,5,5],f32>, %arg1: !torch.vtensor<[2],si64>, %arg2: !torch.vtensor<[2],si64>) -> !torch.vtensor<[1,1,5,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[INT1_0:.*]] = torch.constant.int 1
// CHECK-DAG: %[[INT1_1:.*]] = torch.constant.int 1
// CHECK-DAG: %[[DILATIONSLIST:.*]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK-DAG: %[[INT1_2:.*]] = torch.constant.int 1
// CHECK-DAG: %[[INT1_3:.*]] = torch.constant.int 1
// CHECK-DAG: %[[STRIDESLIST:.*]] = torch.prim.ListConstruct %[[INT1_2]], %[[INT1_3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK-DAG: %[[INT0_0:.*]] = torch.constant.int 0
// CHECK-DAG: %[[INT0_1:.*]] = torch.constant.int 0
// CHECK-DAG: %[[PADLIST:.*]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK-DAG: %[[INT0_2:.*]] = torch.constant.int 0
// CHECK-DAG: %[[INT0_3:.*]] = torch.constant.int 0
// CHECK-DAG: %[[NUMTOTENSOR_0:.*]] = torch.prim.NumToTensor.Scalar %[[INT0_3]] : !torch.int -> !torch.vtensor<[1],si64>
// CHECK-DAG: %[[INDEXSELECT_0:.*]] = torch.aten.index_select %arg1, %[[INT0_2]], %[[NUMTOTENSOR_0]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
// CHECK-DAG: %[[ITEM0:.*]] = torch.aten.item %[[INDEXSELECT_0]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK-DAG: %[[NUMTOTENSOR_1:.*]] = torch.prim.NumToTensor.Scalar %[[INT0_3]] : !torch.int -> !torch.vtensor<[1],si64>
// CHECK-DAG: %[[INDEXSELECT_1:.*]] = torch.aten.index_select %arg2, %[[INT0_2]], %[[NUMTOTENSOR_1]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
// CHECK-DAG: %[[ITEM_1:.*]] = torch.aten.item %[[INDEXSELECT_1]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK-DAG: %[[INT1_4:.*]] = torch.constant.int 1
// CHECK-DAG: %[[NUMTOTENSOR_2:.*]] = torch.prim.NumToTensor.Scalar %[[INT1_4]] : !torch.int -> !torch.vtensor<[1],si64>
// CHECK-DAG: %[[INDEXSELECT_2:.*]] = torch.aten.index_select %arg1, %[[INT0_2]], %[[NUMTOTENSOR_2]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
// CHECK-DAG: %[[ITEM_2:.*]] = torch.aten.item %[[INDEXSELECT_2]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK-DAG: %[[NUMTOTENSOR_3:.*]] = torch.prim.NumToTensor.Scalar %[[INT1_4]] : !torch.int -> !torch.vtensor<[1],si64>
// CHECK-DAG: %[[INDEXSELECT_3:.*]] = torch.aten.index_select %arg2, %[[INT0_2]], %[[NUMTOTENSOR_3]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
// CHECK-DAG: %[[ITEM_3:.*]] = torch.aten.item %[[INDEXSELECT_3]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK-DAG: %[[IMAGESHAPELIST:.*]] = torch.prim.ListConstruct %[[ITEM0]], %[[ITEM_2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK-DAG: %[[BLOCKSHAPELIST:.*]] = torch.prim.ListConstruct %[[ITEM_1]], %[[ITEM_3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK-DAG: %[[COL2IM:.*]] = torch.aten.col2im %arg0, %[[IMAGESHAPELIST]], %[[BLOCKSHAPELIST]], %[[DILATIONSLIST]], %[[PADLIST]], %[[STRIDESLIST]] : !torch.vtensor<[1,5,5],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,5,5],f32>
// CHECK-DAG: return %[[COL2IM]] : !torch.vtensor<[1,1,5,5],f32>
%0 = torch.operator "onnx.Col2Im"(%arg0, %arg1, %arg2) : (!torch.vtensor<[1,5,5],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[1,1,5,5],f32>
return %0 : !torch.vtensor<[1,1,5,5],f32>
}
// -----
// CHECK-LABEL: func.func @test_col2im_pads
func.func @test_col2im_pads(%arg0: !torch.vtensor<[1,5,15],f32>, %arg1: !torch.vtensor<[2],si64>, %arg2: !torch.vtensor<[2],si64>) -> !torch.vtensor<[1,1,5,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[INT1_0:.*]] = torch.constant.int 1
// CHECK-DAG: %[[INT1_1:.*]] = torch.constant.int 1
// CHECK-DAG: %[[DILATIONSLIST:.*]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK-DAG: %[[INT1_2:.*]] = torch.constant.int 1
// CHECK-DAG: %[[INT1_3:.*]] = torch.constant.int 1
// CHECK-DAG: %[[STRIDESLIST:.*]] = torch.prim.ListConstruct %[[INT1_2]], %[[INT1_3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK-DAG: %[[INT0_0:.*]] = torch.constant.int 0
// CHECK-DAG: %[[INT1_4:.*]] = torch.constant.int 1
// CHECK-DAG: %[[PADLIST:.*]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT1_4]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK-DAG: %[[INT0_2:.*]] = torch.constant.int 0
// CHECK-DAG: %[[INT0_3:.*]] = torch.constant.int 0
// CHECK-DAG: %[[NUMTOTENSOR_0:.*]] = torch.prim.NumToTensor.Scalar %[[INT0_3]] : !torch.int -> !torch.vtensor<[1],si64>
// CHECK-DAG: %[[INDEXSELECT_0:.*]] = torch.aten.index_select %arg1, %[[INT0_2]], %[[NUMTOTENSOR_0]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
// CHECK-DAG: %[[ITEM0:.*]] = torch.aten.item %[[INDEXSELECT_0]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK-DAG: %[[NUMTOTENSOR_1:.*]] = torch.prim.NumToTensor.Scalar %[[INT0_3]] : !torch.int -> !torch.vtensor<[1],si64>
// CHECK-DAG: %[[INDEXSELECT_1:.*]] = torch.aten.index_select %arg2, %[[INT0_2]], %[[NUMTOTENSOR_1]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
// CHECK-DAG: %[[ITEM_1:.*]] = torch.aten.item %[[INDEXSELECT_1]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK-DAG: %[[INT1_5:.*]] = torch.constant.int 1
// CHECK-DAG: %[[NUMTOTENSOR_2:.*]] = torch.prim.NumToTensor.Scalar %[[INT1_5]] : !torch.int -> !torch.vtensor<[1],si64>
// CHECK-DAG: %[[INDEXSELECT_2:.*]] = torch.aten.index_select %arg1, %[[INT0_2]], %[[NUMTOTENSOR_2]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
// CHECK-DAG: %[[ITEM_2:.*]] = torch.aten.item %[[INDEXSELECT_2]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK-DAG: %[[NUMTOTENSOR_3:.*]] = torch.prim.NumToTensor.Scalar %[[INT1_5]] : !torch.int -> !torch.vtensor<[1],si64>
// CHECK-DAG: %[[INDEXSELECT_3:.*]] = torch.aten.index_select %arg2, %[[INT0_2]], %[[NUMTOTENSOR_3]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
// CHECK-DAG: %[[ITEM_3:.*]] = torch.aten.item %[[INDEXSELECT_3]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK-DAG: %[[IMAGESHAPELIST:.*]] = torch.prim.ListConstruct %[[ITEM0]], %[[ITEM_2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK-DAG: %[[BLOCKSHAPELIST:.*]] = torch.prim.ListConstruct %[[ITEM_1]], %[[ITEM_3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[COL2IM:.*]] = torch.aten.col2im %arg0, %[[IMAGESHAPELIST]], %[[BLOCKSHAPELIST]], %[[DILATIONSLIST]], %[[PADLIST]], %[[STRIDESLIST]] : !torch.vtensor<[1,5,15],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,5,5],f32>
// CHECK: return %[[COL2IM]] : !torch.vtensor<[1,1,5,5],f32>
%0 = torch.operator "onnx.Col2Im"(%arg0, %arg1, %arg2) {torch.onnx.pads = [0 : si64, 1 : si64, 0 : si64, 1 : si64]} : (!torch.vtensor<[1,5,15],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[1,1,5,5],f32>
return %0 : !torch.vtensor<[1,1,5,5],f32>
}
// -----
// CHECK-LABEL: func.func @test_col2im_dilations
func.func @test_col2im_dilations(%arg0: !torch.vtensor<[1,4,5],f32>, %arg1: !torch.vtensor<[2],si64>, %arg2: !torch.vtensor<[2],si64>) -> !torch.vtensor<[1,1,6,6],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[INT1_0:.*]] = torch.constant.int 1
// CHECK-DAG: %[[INT5_0:.*]] = torch.constant.int 5
// CHECK-DAG: %[[DILATIONSLIST:.*]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT5_0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK-DAG: %[[INT1_1:.*]] = torch.constant.int 1
// CHECK-DAG: %[[INT1_2:.*]] = torch.constant.int 1
// CHECK-DAG: %[[STRIDESLIST:.*]] = torch.prim.ListConstruct %[[INT1_1]], %[[INT1_2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK-DAG: %[[INT0_0:.*]] = torch.constant.int 0
// CHECK-DAG: %[[INT0_1:.*]] = torch.constant.int 0
// CHECK-DAG: %[[PADLIST:.*]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK-DAG: %[[INT0_2:.*]] = torch.constant.int 0
// CHECK-DAG: %[[INT0_3:.*]] = torch.constant.int 0
// CHECK-DAG: %[[NUMTOTENSOR_0:.*]] = torch.prim.NumToTensor.Scalar %[[INT0_3]] : !torch.int -> !torch.vtensor<[1],si64>
// CHECK-DAG: %[[INDEXSELECT_0:.*]] = torch.aten.index_select %arg1, %[[INT0_2]], %[[NUMTOTENSOR_0]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
// CHECK-DAG: %[[ITEM0:.*]] = torch.aten.item %[[INDEXSELECT_0]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK-DAG: %[[NUMTOTENSOR_1:.*]] = torch.prim.NumToTensor.Scalar %[[INT0_3]] : !torch.int -> !torch.vtensor<[1],si64>
// CHECK-DAG: %[[INDEXSELECT_1:.*]] = torch.aten.index_select %arg2, %[[INT0_2]], %[[NUMTOTENSOR_1]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
// CHECK-DAG: %[[ITEM_1:.*]] = torch.aten.item %[[INDEXSELECT_1]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK-DAG: %[[INT1_3:.*]] = torch.constant.int 1
// CHECK-DAG: %[[NUMTOTENSOR_2:.*]] = torch.prim.NumToTensor.Scalar %[[INT1_3]] : !torch.int -> !torch.vtensor<[1],si64>
// CHECK-DAG: %[[INDEXSELECT_2:.*]] = torch.aten.index_select %arg1, %[[INT0_2]], %[[NUMTOTENSOR_2]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
// CHECK-DAG: %[[ITEM_2:.*]] = torch.aten.item %[[INDEXSELECT_2]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK-DAG: %[[NUMTOTENSOR_3:.*]] = torch.prim.NumToTensor.Scalar %[[INT1_3]] : !torch.int -> !torch.vtensor<[1],si64>
// CHECK-DAG: %[[INDEXSELECT_3:.*]] = torch.aten.index_select %arg2, %[[INT0_2]], %[[NUMTOTENSOR_3]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
// CHECK-DAG: %[[ITEM_3:.*]] = torch.aten.item %[[INDEXSELECT_3]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK-DAG: %[[IMAGESHAPELIST:.*]] = torch.prim.ListConstruct %[[ITEM0]], %[[ITEM_2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK-DAG: %[[BLOCKSHAPELIST:.*]] = torch.prim.ListConstruct %[[ITEM_1]], %[[ITEM_3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK-DAG: %[[COL2IM:.*]] = torch.aten.col2im %arg0, %[[IMAGESHAPELIST]], %[[BLOCKSHAPELIST]], %[[DILATIONSLIST]], %[[PADLIST]], %[[STRIDESLIST]] : !torch.vtensor<[1,4,5],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,6,6],f32>
// CHECK-DAG: return %[[COL2IM]] : !torch.vtensor<[1,1,6,6],f32>
%0 = torch.operator "onnx.Col2Im"(%arg0, %arg1, %arg2) {torch.onnx.dilations = [1 : si64, 5 : si64]} : (!torch.vtensor<[1,4,5],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[1,1,6,6],f32>
return %0 : !torch.vtensor<[1,1,6,6],f32>
}
// CHECK-LABEL: func.func @test_col2im_strides
func.func @test_col2im_strides(%arg0: !torch.vtensor<[1,9,4],f32>, %arg1: !torch.vtensor<[2],si64>, %arg2: !torch.vtensor<[2],si64>) -> !torch.vtensor<[1,1,5,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[INT1_0:.*]] = torch.constant.int 1
// CHECK-DAG: %[[INT1_1:.*]] = torch.constant.int 1
// CHECK-DAG: %[[DILATIONSLIST:.*]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK-DAG: %[[INT2_0:.*]] = torch.constant.int 2
// CHECK-DAG: %[[INT2_1:.*]] = torch.constant.int 2
// CHECK-DAG: %[[STRIDESLIST:.*]] = torch.prim.ListConstruct %[[INT2_0]], %[[INT2_1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK-DAG: %[[INT0_0:.*]] = torch.constant.int 0
// CHECK-DAG: %[[INT0_1:.*]] = torch.constant.int 0
// CHECK-DAG: %[[PADLIST:.*]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK-DAG: %[[INT0_2:.*]] = torch.constant.int 0
// CHECK-DAG: %[[INT0_3:.*]] = torch.constant.int 0
// CHECK-DAG: %[[NUMTOTENSOR_0:.*]] = torch.prim.NumToTensor.Scalar %[[INT0_3]] : !torch.int -> !torch.vtensor<[1],si64>
// CHECK-DAG: %[[INDEXSELECT_0:.*]] = torch.aten.index_select %arg1, %[[INT0_2]], %[[NUMTOTENSOR_0]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
// CHECK-DAG: %[[ITEM0:.*]] = torch.aten.item %[[INDEXSELECT_0]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK-DAG: %[[NUMTOTENSOR_1:.*]] = torch.prim.NumToTensor.Scalar %[[INT0_3]] : !torch.int -> !torch.vtensor<[1],si64>
// CHECK-DAG: %[[INDEXSELECT_1:.*]] = torch.aten.index_select %arg2, %[[INT0_2]], %[[NUMTOTENSOR_1]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
// CHECK-DAG: %[[ITEM_1:.*]] = torch.aten.item %[[INDEXSELECT_1]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK-DAG: %[[INT1_2:.*]] = torch.constant.int 1
// CHECK-DAG: %[[NUMTOTENSOR_2:.*]] = torch.prim.NumToTensor.Scalar %[[INT1_2]] : !torch.int -> !torch.vtensor<[1],si64>
// CHECK-DAG: %[[INDEXSELECT_2:.*]] = torch.aten.index_select %arg1, %[[INT0_2]], %[[NUMTOTENSOR_2]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
// CHECK-DAG: %[[ITEM_2:.*]] = torch.aten.item %[[INDEXSELECT_2]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK-DAG: %[[NUMTOTENSOR_3:.*]] = torch.prim.NumToTensor.Scalar %[[INT1_2]] : !torch.int -> !torch.vtensor<[1],si64>
// CHECK-DAG: %[[INDEXSELECT_3:.*]] = torch.aten.index_select %arg2, %[[INT0_2]], %[[NUMTOTENSOR_3]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
// CHECK-DAG: %[[ITEM_3:.*]] = torch.aten.item %[[INDEXSELECT_3]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK-DAG: %[[IMAGESHAPELIST:.*]] = torch.prim.ListConstruct %[[ITEM0]], %[[ITEM_2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK-DAG: %[[BLOCKSHAPELIST:.*]] = torch.prim.ListConstruct %[[ITEM_1]], %[[ITEM_3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK-DAG: %[[COL2IM:.*]] = torch.aten.col2im %arg0, %[[IMAGESHAPELIST]], %[[BLOCKSHAPELIST]], %[[DILATIONSLIST]], %[[PADLIST]], %[[STRIDESLIST]] : !torch.vtensor<[1,9,4],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,5,5],f32>
// CHECK-DAG: return %[[COL2IM]] : !torch.vtensor<[1,1,5,5],f32>
%0 = torch.operator "onnx.Col2Im"(%arg0, %arg1, %arg2) {torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,9,4],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[1,1,5,5],f32>
return %0 : !torch.vtensor<[1,1,5,5],f32>
}