mirror of https://github.com/llvm/torch-mlir
[torch] Add OnnxToTorch lowering for the Col2Im op (#3424)
Adds OnnxToTorch lowering for the `onnx.Col2Im` op.pull/3406/merge
parent
de7f058a0e
commit
39d882f7c9
|
@ -12398,6 +12398,34 @@ def Torch_AtenLinalgCrossOp : Torch_Op<"aten.linalg_cross", [
|
|||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenCol2imOp : Torch_Op<"aten.col2im", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::col2im : (Tensor, int[], int[], int[], int[], int[]) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchListOfTorchIntType:$output_size,
|
||||
AnyTorchListOfTorchIntType:$kernel_size,
|
||||
AnyTorchListOfTorchIntType:$dilation,
|
||||
AnyTorchListOfTorchIntType:$padding,
|
||||
AnyTorchListOfTorchIntType:$stride
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenCol2imOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 6, 1);
|
||||
}
|
||||
void AtenCol2imOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 6, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -949,6 +949,130 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
|
||||
return failure();
|
||||
});
|
||||
patterns.onOp(
|
||||
"Col2Im", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
Value input, blockShape, imageShape;
|
||||
SmallVector<int64_t> dilations, strides, pads;
|
||||
|
||||
// TODO: The length of dilations should be len(imageShape), and the same
|
||||
// goes for strides. The length of pads should be 2 * len(imageShape).
|
||||
// But, as at the moment we are only supporting 3D or 4D input,
|
||||
// len(imageShape) must necessarily be 2, hence the lengths of the
|
||||
// default values.
|
||||
if (binder.tensorOperandAtIndex(input, 0) ||
|
||||
binder.tensorOperandAtIndex(imageShape, 1) ||
|
||||
binder.tensorOperandAtIndex(blockShape, 2) ||
|
||||
binder.tensorResultType(resultType) ||
|
||||
binder.s64IntegerArrayAttr(dilations, "dilations",
|
||||
SmallVector<int64_t>{1, 1}) ||
|
||||
binder.s64IntegerArrayAttr(strides, "strides",
|
||||
SmallVector<int64_t>{1, 1}) ||
|
||||
binder.s64IntegerArrayAttr(pads, "pads",
|
||||
SmallVector<int64_t>{0, 0, 0, 0}))
|
||||
return failure();
|
||||
|
||||
auto imageShapeTy = cast<Torch::ValueTensorType>(imageShape.getType());
|
||||
auto imageShapeSizes = imageShapeTy.getSizes();
|
||||
|
||||
auto blockShapeTy = cast<Torch::ValueTensorType>(blockShape.getType());
|
||||
auto blockShapeSizes = blockShapeTy.getSizes();
|
||||
|
||||
// Check that neither imageShape nor blockShape have dynamic shapes.
|
||||
if (imageShapeSizes[0] == Torch::kUnknownSize ||
|
||||
blockShapeSizes[0] == Torch::kUnknownSize) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op,
|
||||
"Dynamic shapes are not allowed for imageShape and blockShape");
|
||||
}
|
||||
|
||||
// TODO: Add support for 5D input tensors.
|
||||
if (imageShapeSizes[0] != 2) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "Expected length of imageShape to be equal to 2");
|
||||
}
|
||||
if (blockShapeSizes[0] != 2) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "Expected length of blockShape to be equal to 2");
|
||||
}
|
||||
if (dilations.size() != 2) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "Expected length of dilations to be equal to 2");
|
||||
}
|
||||
if (strides.size() != 2) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "Expected length of strides to be equal to 2");
|
||||
}
|
||||
|
||||
// TODO: Disable this check and add support for different
|
||||
// paddings on lower and higher ends of each axis.
|
||||
// Because we have already checked that imageShape has 2 elements,
|
||||
// we can safely assume that len(padding) will be 4.
|
||||
if (pads[0] != pads[2] || pads[1] != pads[3])
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "padding on the lower end and the higher end "
|
||||
"on each axis should be the same");
|
||||
|
||||
// Since we know that the padding on the lower end and the higher
|
||||
// end on each axis is the same, we can reduce the size of the
|
||||
// padding list, and filter out the duplicate elements.
|
||||
// (Also, Torch::AtenCol2imOp requires len(padding) to be 2).
|
||||
SmallVector<int64_t> padOnEachAxis = {pads[0], pads[1]};
|
||||
Value dilationsList =
|
||||
createConstantIntList(binder, rewriter, dilations);
|
||||
Value stridesList = createConstantIntList(binder, rewriter, strides);
|
||||
Value paddingList =
|
||||
createConstantIntList(binder, rewriter, padOnEachAxis);
|
||||
|
||||
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(0));
|
||||
|
||||
// Index the imageShape and blockShape tensors, as AtenCol2imOp expects
|
||||
// them to be int lists.
|
||||
auto select = [&](Value v, Value k,
|
||||
Torch::ValueTensorType ty) -> Value {
|
||||
Value kTensor = rewriter.create<Torch::PrimNumToTensorScalarOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(
|
||||
binder.op->getContext(), ArrayRef<int64_t>{1},
|
||||
rewriter.getIntegerType(64, /*signed*/ 1)),
|
||||
k);
|
||||
|
||||
auto sel = rewriter.create<Torch::AtenIndexSelectOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ValueTensorType::get(ty.getContext(), ArrayRef<int64_t>{1},
|
||||
ty.getOptionalDtype()),
|
||||
v, zero, kTensor);
|
||||
Value item = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(), sel);
|
||||
return item;
|
||||
};
|
||||
|
||||
SmallVector<Value> imageShapeContainer, blockShapeContainer;
|
||||
for (int64_t i = 0; i < imageShapeSizes[0]; ++i) {
|
||||
Value k = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(i));
|
||||
|
||||
// Passing in the shapeType of each of these tensors avoids
|
||||
// repeated casts, as these have already been calculated.
|
||||
imageShapeContainer.push_back(select(imageShape, k, imageShapeTy));
|
||||
blockShapeContainer.push_back(select(blockShape, k, blockShapeTy));
|
||||
}
|
||||
|
||||
Value imageShapeAsList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
||||
imageShapeContainer);
|
||||
Value blockShapeAsList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
||||
blockShapeContainer);
|
||||
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenCol2imOp>(
|
||||
binder.op, resultType, input, imageShapeAsList, blockShapeAsList,
|
||||
dilationsList, paddingList, stridesList);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp(
|
||||
"Conv", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
std::string autoPad;
|
||||
|
|
|
@ -9016,6 +9016,248 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %1 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %0) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %1 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.col2im\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %str = torch.constant.str \"AssertionError: Expected size of input's dimension 2 to match the calculated number of sliding blocks\"\n"
|
||||
" %str_0 = torch.constant.str \"AssertionError: Expected size of input's dimension 1 to be divisible by the product of kernel_size\"\n"
|
||||
" %int-1 = torch.constant.int -1\n"
|
||||
" %str_1 = torch.constant.str \"AssertionError: stride must be greater than 0\"\n"
|
||||
" %str_2 = torch.constant.str \"AssertionError: padding should be non negative\"\n"
|
||||
" %str_3 = torch.constant.str \"AssertionError: dilation should be greater than 0\"\n"
|
||||
" %str_4 = torch.constant.str \"AssertionError: kernel size should be greater than 0\"\n"
|
||||
" %str_5 = torch.constant.str \"AssertionError: padding is expected to have length 2\"\n"
|
||||
" %str_6 = torch.constant.str \"AssertionError: stride is expected to have length 2\"\n"
|
||||
" %str_7 = torch.constant.str \"AssertionError: dilation is expected to have length 2\"\n"
|
||||
" %str_8 = torch.constant.str \"AssertionError: kernel_size is expected to have length 2\"\n"
|
||||
" %str_9 = torch.constant.str \"AssertionError: output_size is expected to have length 2\"\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str_10 = torch.constant.str \"AssertionError: Expected 2D or 3D (batch mode) tensor for input with possibly 0 batch size and non zero dimensions for input\"\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %false = torch.constant.bool false\n"
|
||||
" %int2 = torch.constant.int 2\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
" %int3 = torch.constant.int 3\n"
|
||||
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %2 = torch.prim.If %1 -> (!torch.bool) {\n"
|
||||
" %75 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %76 = torch.aten.ne.int %75, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %76 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" %3 = torch.prim.If %2 -> (!torch.bool) {\n"
|
||||
" %75 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %76 = torch.aten.ne.int %75, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %76 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %75 = torch.aten.eq.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %76 = torch.prim.If %75 -> (!torch.bool) {\n"
|
||||
" %78 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %79 = torch.aten.ne.int %78, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %79 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" %77 = torch.prim.If %76 -> (!torch.bool) {\n"
|
||||
" %78 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %79 = torch.aten.ne.int %78, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %79 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If.yield %77 : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If %4 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str_10, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %5 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
|
||||
" %6 = torch.aten.eq.int %5, %int2 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %6 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str_9, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %7 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int\n"
|
||||
" %8 = torch.aten.eq.int %7, %int2 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %8 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str_8, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %9 = torch.aten.len.t %arg3 : !torch.list<int> -> !torch.int\n"
|
||||
" %10 = torch.aten.eq.int %9, %int2 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %10 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str_7, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %11 = torch.aten.len.t %arg5 : !torch.list<int> -> !torch.int\n"
|
||||
" %12 = torch.aten.eq.int %11, %int2 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %12 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str_6, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %13 = torch.aten.len.t %arg4 : !torch.list<int> -> !torch.int\n"
|
||||
" %14 = torch.aten.eq.int %13, %int2 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %14 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str_5, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %15 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %16 = torch.aten.gt.int %15, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %17 = torch.prim.If %16 -> (!torch.bool) {\n"
|
||||
" %75 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %76 = torch.aten.gt.int %75, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %76 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If %17 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str_4, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %18 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %19 = torch.aten.gt.int %18, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %20 = torch.prim.If %19 -> (!torch.bool) {\n"
|
||||
" %75 = torch.aten.__getitem__.t %arg3, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %76 = torch.aten.gt.int %75, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %76 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If %20 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %21 = torch.aten.__getitem__.t %arg4, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %22 = torch.aten.ge.int %21, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %23 = torch.prim.If %22 -> (!torch.bool) {\n"
|
||||
" %75 = torch.aten.__getitem__.t %arg4, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %76 = torch.aten.ge.int %75, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %76 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If %23 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %24 = torch.aten.__getitem__.t %arg5, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %25 = torch.aten.gt.int %24, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %26 = torch.prim.If %25 -> (!torch.bool) {\n"
|
||||
" %75 = torch.aten.__getitem__.t %arg5, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %76 = torch.aten.gt.int %75, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %76 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If %26 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %27 = torch.aten.eq.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %28 = torch.prim.If %27 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %int0 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %int-1 : !torch.int\n"
|
||||
" }\n"
|
||||
" %29 = torch.aten.add.int %28, %int1 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %30 = torch.aten.__getitem__.t %arg0, %29 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %31 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %32 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %33 = torch.aten.mul.int %31, %32 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %34 = torch.aten.remainder.int %30, %33 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %35 = torch.aten.eq.int %34, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %35 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %36 = torch.aten.add.int %28, %int2 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %37 = torch.aten.__getitem__.t %arg0, %36 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %38 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %39 = torch.aten.__getitem__.t %arg4, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %40 = torch.aten.mul.int %int2, %39 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %41 = torch.aten.add.int %38, %40 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %42 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %43 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %44 = torch.aten.sub.int %43, %int1 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %45 = torch.aten.mul.int %42, %44 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %46 = torch.aten.sub.int %41, %45 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %47 = torch.aten.sub.int %46, %int1 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %48 = torch.aten.__getitem__.t %arg5, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %49 = torch.aten.floordiv.int %47, %48 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %50 = torch.aten.add.int %49, %int1 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %51 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %52 = torch.aten.__getitem__.t %arg4, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %53 = torch.aten.mul.int %int2, %52 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %54 = torch.aten.add.int %51, %53 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %55 = torch.aten.__getitem__.t %arg3, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %56 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %57 = torch.aten.sub.int %56, %int1 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %58 = torch.aten.mul.int %55, %57 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %59 = torch.aten.sub.int %54, %58 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %60 = torch.aten.sub.int %59, %int1 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %61 = torch.aten.__getitem__.t %arg5, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %62 = torch.aten.floordiv.int %60, %61 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %63 = torch.aten.add.int %62, %int1 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %64 = torch.aten.mul.int %50, %63 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %65 = torch.aten.eq.int %37, %64 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %65 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %66 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %67 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %68 = torch.aten.mul.int %66, %67 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %69 = torch.aten.floordiv.int %30, %68 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %70 = torch.aten.eq.int %28, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %71 = torch.prim.If %70 -> (!torch.list<int>) {\n"
|
||||
" %75 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %76 = torch.prim.ListConstruct %75, %69 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield %76 : !torch.list<int>\n"
|
||||
" } else {\n"
|
||||
" %75 = torch.prim.ListConstruct %69 : (!torch.int) -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield %75 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" %72 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
||||
" %73 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
|
||||
" torch.prim.Loop %73, %true, init() {\n"
|
||||
" ^bb0(%arg6: !torch.int):\n"
|
||||
" %75 = torch.aten.__getitem__.t %arg1, %arg6 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %76 = torch.aten.append.t %72, %75 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
|
||||
" torch.prim.Loop.condition %true, iter()\n"
|
||||
" } : (!torch.int, !torch.bool) -> ()\n"
|
||||
" %74 = torch.operator \"aten.add_.t\"(%71, %72) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int> \n"
|
||||
" return %74 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.topk\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.topk(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.int) -> !torch.tuple<list<int>, list<int>>\n"
|
||||
" return %0 : !torch.tuple<list<int>, list<int>>\n"
|
||||
|
@ -12049,6 +12291,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %4 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.col2im\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.nonzero\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %int4 = torch.constant.int 4\n"
|
||||
" return %int4 : !torch.int\n"
|
||||
|
|
|
@ -1501,6 +1501,46 @@ def aten〇addcmul〡shape(self: List[int], tensor1: List[int], tensor2: List[in
|
|||
def aten〇addcdiv〡shape(self: List[int], tensor1: List[int], tensor2: List[int], value: float = 1) -> List[int]:
|
||||
return upstream_shape_functions.broadcast(self, upstream_shape_functions.broadcast(tensor1, tensor2))
|
||||
|
||||
@check_shape_function([
|
||||
Invocation(TensorOfShape(1,5,5), [5,5], [1,5], [1,1], [0,0], [1,1]), # basic case
|
||||
Invocation(TensorOfShape(1,4,5), [6,6], [2,2], [1,5], [0,0], [1,1]), # dilation
|
||||
Invocation(TensorOfShape(1,5,15), [5,5], [1,5], [1,1], [0,1], [1,1]), # padding
|
||||
Invocation(TensorOfShape(1,9,4), [5,5], [3,3], [1,1], [0,0], [2,2]), # stride
|
||||
ErrorInvocation(TensorOfShape(1,5,5), [5,5], [1,7], [1,1], [0,0], [1,1]), # mismatch of sliding blocks
|
||||
])
|
||||
def aten〇col2im〡shape(self: List[int], output_size: List[int], kernel_size: List[int], dilation: List[int], padding: List[int], stride: List[int]) -> List[int]:
|
||||
ndim = len(self)
|
||||
assert (ndim == 2 and self[0] != 0 and self[1] != 0) or (ndim == 3 and self[1] != 0 and self[2] != 0), "Expected 2D or 3D (batch mode) tensor for input with possibly 0 batch size and non zero dimensions for input"
|
||||
|
||||
assert len(output_size) == 2, "output_size is expected to have length 2"
|
||||
assert len(kernel_size) == 2, "kernel_size is expected to have length 2"
|
||||
assert len(dilation) == 2, "dilation is expected to have length 2"
|
||||
assert len(stride) == 2, "stride is expected to have length 2"
|
||||
assert len(padding) == 2, "padding is expected to have length 2"
|
||||
|
||||
assert kernel_size[0] > 0 and kernel_size[1] > 0, "kernel size should be greater than 0"
|
||||
assert dilation[0] > 0 and dilation[1] > 0, "dilation should be greater than 0"
|
||||
assert padding[0] >= 0 and padding[1] >= 0, "padding should be non negative"
|
||||
assert stride[0] > 0 and stride[1] > 0, "stride must be greater than 0"
|
||||
|
||||
batch_dim = 0 if ndim == 3 else -1
|
||||
n_input_plane = self[batch_dim + 1]
|
||||
|
||||
assert n_input_plane % (kernel_size[0] * kernel_size[1]) == 0, "Expected size of input's dimension 1 to be divisible by the product of kernel_size"
|
||||
|
||||
input_length = self[batch_dim + 2]
|
||||
n_blocks_height = (output_size[0] + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) // stride[0] + 1
|
||||
n_blocks_width = (output_size[1] + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[1] + 1
|
||||
|
||||
assert input_length == n_blocks_height * n_blocks_width, "Expected size of input's dimension 2 to match the calculated number of sliding blocks"
|
||||
|
||||
# compute the shape of the output
|
||||
num_channels = n_input_plane // (kernel_size[0] * kernel_size[1])
|
||||
out: List[int] = [self[0], num_channels] if batch_dim == 0 else [num_channels]
|
||||
out += [elem for elem in output_size]
|
||||
|
||||
return out
|
||||
|
||||
@check_shape_function([
|
||||
Invocation(TensorOfShape(2, 3), 1), # Basic case.
|
||||
Invocation(TensorOfShape(2, 3), 2, dim=0), # Test explicit `dim`.
|
||||
|
@ -3708,6 +3748,16 @@ def aten〇bincount〡dtype(self_rank_dtype: Tuple[int, int], weights_rank_dtype
|
|||
return torch.int64
|
||||
return torch.float64
|
||||
|
||||
@check_dtype_function([
|
||||
Invocation(TensorOfShape(1, 5, 5, dtype=torch.int64), [5,5], [1,5], [1,1], [0,0], [1,1]), # int type
|
||||
Invocation(TensorOfShape(1, 5, 5, dtype=torch.float64), [5,5], [1,5], [1,1], [0,0], [1,1]), # float type
|
||||
Invocation(TensorOfShape(1, 5, 5, dtype=torch.complex64), [5,5], [1,5], [1,1], [0,0], [1,1]), # complex type
|
||||
Invocation(TensorOfShape(1, 5, 5, dtype=torch.bool), [5,5], [1,5], [1,1], [0,0], [1,1]), # boolean type
|
||||
])
|
||||
def aten〇col2im〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int], kernel_size: List[int], dilation: List[int], padding: List[int], stride: List[int]) -> int:
|
||||
_, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, tensor_device=torch.device("cpu")))
|
||||
def aten〇nonzero〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
||||
return torch.int64
|
||||
|
|
|
@ -911,6 +911,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
"aten::linspace : (Scalar, Scalar, int, int?, int?, Device?, bool?) -> (Tensor)"
|
||||
)
|
||||
emit("aten::linalg_cross : (Tensor, Tensor, int) -> (Tensor)", has_verifier=True)
|
||||
emit("aten::col2im : (Tensor, int[], int[], int[], int[], int[]) -> (Tensor)")
|
||||
|
||||
# Functionalization ops
|
||||
emit("aten::alias_copy : (Tensor) -> (Tensor)")
|
||||
|
|
|
@ -2325,3 +2325,145 @@ func.func @test_hammingwindow(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<
|
|||
%0 = torch.operator "onnx.HammingWindow"(%arg0) {torch.onnx.periodic = 1 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
|
||||
return %0 : !torch.vtensor<[10],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_col2im
|
||||
func.func @test_col2im(%arg0: !torch.vtensor<[1,5,5],f32>, %arg1: !torch.vtensor<[2],si64>, %arg2: !torch.vtensor<[2],si64>) -> !torch.vtensor<[1,1,5,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK-DAG: %[[INT1_0:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[INT1_1:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[DILATIONSLIST:.*]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK-DAG: %[[INT1_2:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[INT1_3:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[STRIDESLIST:.*]] = torch.prim.ListConstruct %[[INT1_2]], %[[INT1_3]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK-DAG: %[[INT0_0:.*]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[INT0_1:.*]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[PADLIST:.*]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK-DAG: %[[INT0_2:.*]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[INT0_3:.*]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[NUMTOTENSOR_0:.*]] = torch.prim.NumToTensor.Scalar %[[INT0_3]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK-DAG: %[[INDEXSELECT_0:.*]] = torch.aten.index_select %arg1, %[[INT0_2]], %[[NUMTOTENSOR_0]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK-DAG: %[[ITEM0:.*]] = torch.aten.item %[[INDEXSELECT_0]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK-DAG: %[[NUMTOTENSOR_1:.*]] = torch.prim.NumToTensor.Scalar %[[INT0_3]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK-DAG: %[[INDEXSELECT_1:.*]] = torch.aten.index_select %arg2, %[[INT0_2]], %[[NUMTOTENSOR_1]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK-DAG: %[[ITEM_1:.*]] = torch.aten.item %[[INDEXSELECT_1]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK-DAG: %[[INT1_4:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[NUMTOTENSOR_2:.*]] = torch.prim.NumToTensor.Scalar %[[INT1_4]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK-DAG: %[[INDEXSELECT_2:.*]] = torch.aten.index_select %arg1, %[[INT0_2]], %[[NUMTOTENSOR_2]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK-DAG: %[[ITEM_2:.*]] = torch.aten.item %[[INDEXSELECT_2]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK-DAG: %[[NUMTOTENSOR_3:.*]] = torch.prim.NumToTensor.Scalar %[[INT1_4]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK-DAG: %[[INDEXSELECT_3:.*]] = torch.aten.index_select %arg2, %[[INT0_2]], %[[NUMTOTENSOR_3]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK-DAG: %[[ITEM_3:.*]] = torch.aten.item %[[INDEXSELECT_3]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK-DAG: %[[IMAGESHAPELIST:.*]] = torch.prim.ListConstruct %[[ITEM0]], %[[ITEM_2]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK-DAG: %[[BLOCKSHAPELIST:.*]] = torch.prim.ListConstruct %[[ITEM_1]], %[[ITEM_3]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK-DAG: %[[COL2IM:.*]] = torch.aten.col2im %arg0, %[[IMAGESHAPELIST]], %[[BLOCKSHAPELIST]], %[[DILATIONSLIST]], %[[PADLIST]], %[[STRIDESLIST]] : !torch.vtensor<[1,5,5],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,5,5],f32>
|
||||
// CHECK-DAG: return %[[COL2IM]] : !torch.vtensor<[1,1,5,5],f32>
|
||||
%0 = torch.operator "onnx.Col2Im"(%arg0, %arg1, %arg2) : (!torch.vtensor<[1,5,5],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[1,1,5,5],f32>
|
||||
return %0 : !torch.vtensor<[1,1,5,5],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_col2im_pads
|
||||
func.func @test_col2im_pads(%arg0: !torch.vtensor<[1,5,15],f32>, %arg1: !torch.vtensor<[2],si64>, %arg2: !torch.vtensor<[2],si64>) -> !torch.vtensor<[1,1,5,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK-DAG: %[[INT1_0:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[INT1_1:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[DILATIONSLIST:.*]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK-DAG: %[[INT1_2:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[INT1_3:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[STRIDESLIST:.*]] = torch.prim.ListConstruct %[[INT1_2]], %[[INT1_3]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK-DAG: %[[INT0_0:.*]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[INT1_4:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[PADLIST:.*]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT1_4]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK-DAG: %[[INT0_2:.*]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[INT0_3:.*]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[NUMTOTENSOR_0:.*]] = torch.prim.NumToTensor.Scalar %[[INT0_3]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK-DAG: %[[INDEXSELECT_0:.*]] = torch.aten.index_select %arg1, %[[INT0_2]], %[[NUMTOTENSOR_0]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK-DAG: %[[ITEM0:.*]] = torch.aten.item %[[INDEXSELECT_0]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK-DAG: %[[NUMTOTENSOR_1:.*]] = torch.prim.NumToTensor.Scalar %[[INT0_3]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK-DAG: %[[INDEXSELECT_1:.*]] = torch.aten.index_select %arg2, %[[INT0_2]], %[[NUMTOTENSOR_1]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK-DAG: %[[ITEM_1:.*]] = torch.aten.item %[[INDEXSELECT_1]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK-DAG: %[[INT1_5:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[NUMTOTENSOR_2:.*]] = torch.prim.NumToTensor.Scalar %[[INT1_5]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK-DAG: %[[INDEXSELECT_2:.*]] = torch.aten.index_select %arg1, %[[INT0_2]], %[[NUMTOTENSOR_2]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK-DAG: %[[ITEM_2:.*]] = torch.aten.item %[[INDEXSELECT_2]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK-DAG: %[[NUMTOTENSOR_3:.*]] = torch.prim.NumToTensor.Scalar %[[INT1_5]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK-DAG: %[[INDEXSELECT_3:.*]] = torch.aten.index_select %arg2, %[[INT0_2]], %[[NUMTOTENSOR_3]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK-DAG: %[[ITEM_3:.*]] = torch.aten.item %[[INDEXSELECT_3]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK-DAG: %[[IMAGESHAPELIST:.*]] = torch.prim.ListConstruct %[[ITEM0]], %[[ITEM_2]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK-DAG: %[[BLOCKSHAPELIST:.*]] = torch.prim.ListConstruct %[[ITEM_1]], %[[ITEM_3]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[COL2IM:.*]] = torch.aten.col2im %arg0, %[[IMAGESHAPELIST]], %[[BLOCKSHAPELIST]], %[[DILATIONSLIST]], %[[PADLIST]], %[[STRIDESLIST]] : !torch.vtensor<[1,5,15],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,5,5],f32>
|
||||
// CHECK: return %[[COL2IM]] : !torch.vtensor<[1,1,5,5],f32>
|
||||
%0 = torch.operator "onnx.Col2Im"(%arg0, %arg1, %arg2) {torch.onnx.pads = [0 : si64, 1 : si64, 0 : si64, 1 : si64]} : (!torch.vtensor<[1,5,15],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[1,1,5,5],f32>
|
||||
return %0 : !torch.vtensor<[1,1,5,5],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_col2im_dilations
|
||||
func.func @test_col2im_dilations(%arg0: !torch.vtensor<[1,4,5],f32>, %arg1: !torch.vtensor<[2],si64>, %arg2: !torch.vtensor<[2],si64>) -> !torch.vtensor<[1,1,6,6],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK-DAG: %[[INT1_0:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[INT5_0:.*]] = torch.constant.int 5
|
||||
// CHECK-DAG: %[[DILATIONSLIST:.*]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT5_0]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK-DAG: %[[INT1_1:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[INT1_2:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[STRIDESLIST:.*]] = torch.prim.ListConstruct %[[INT1_1]], %[[INT1_2]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK-DAG: %[[INT0_0:.*]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[INT0_1:.*]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[PADLIST:.*]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK-DAG: %[[INT0_2:.*]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[INT0_3:.*]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[NUMTOTENSOR_0:.*]] = torch.prim.NumToTensor.Scalar %[[INT0_3]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK-DAG: %[[INDEXSELECT_0:.*]] = torch.aten.index_select %arg1, %[[INT0_2]], %[[NUMTOTENSOR_0]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK-DAG: %[[ITEM0:.*]] = torch.aten.item %[[INDEXSELECT_0]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK-DAG: %[[NUMTOTENSOR_1:.*]] = torch.prim.NumToTensor.Scalar %[[INT0_3]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK-DAG: %[[INDEXSELECT_1:.*]] = torch.aten.index_select %arg2, %[[INT0_2]], %[[NUMTOTENSOR_1]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK-DAG: %[[ITEM_1:.*]] = torch.aten.item %[[INDEXSELECT_1]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK-DAG: %[[INT1_3:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[NUMTOTENSOR_2:.*]] = torch.prim.NumToTensor.Scalar %[[INT1_3]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK-DAG: %[[INDEXSELECT_2:.*]] = torch.aten.index_select %arg1, %[[INT0_2]], %[[NUMTOTENSOR_2]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK-DAG: %[[ITEM_2:.*]] = torch.aten.item %[[INDEXSELECT_2]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK-DAG: %[[NUMTOTENSOR_3:.*]] = torch.prim.NumToTensor.Scalar %[[INT1_3]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK-DAG: %[[INDEXSELECT_3:.*]] = torch.aten.index_select %arg2, %[[INT0_2]], %[[NUMTOTENSOR_3]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK-DAG: %[[ITEM_3:.*]] = torch.aten.item %[[INDEXSELECT_3]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK-DAG: %[[IMAGESHAPELIST:.*]] = torch.prim.ListConstruct %[[ITEM0]], %[[ITEM_2]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK-DAG: %[[BLOCKSHAPELIST:.*]] = torch.prim.ListConstruct %[[ITEM_1]], %[[ITEM_3]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK-DAG: %[[COL2IM:.*]] = torch.aten.col2im %arg0, %[[IMAGESHAPELIST]], %[[BLOCKSHAPELIST]], %[[DILATIONSLIST]], %[[PADLIST]], %[[STRIDESLIST]] : !torch.vtensor<[1,4,5],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,6,6],f32>
|
||||
// CHECK-DAG: return %[[COL2IM]] : !torch.vtensor<[1,1,6,6],f32>
|
||||
%0 = torch.operator "onnx.Col2Im"(%arg0, %arg1, %arg2) {torch.onnx.dilations = [1 : si64, 5 : si64]} : (!torch.vtensor<[1,4,5],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[1,1,6,6],f32>
|
||||
return %0 : !torch.vtensor<[1,1,6,6],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @test_col2im_strides
|
||||
func.func @test_col2im_strides(%arg0: !torch.vtensor<[1,9,4],f32>, %arg1: !torch.vtensor<[2],si64>, %arg2: !torch.vtensor<[2],si64>) -> !torch.vtensor<[1,1,5,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK-DAG: %[[INT1_0:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[INT1_1:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[DILATIONSLIST:.*]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK-DAG: %[[INT2_0:.*]] = torch.constant.int 2
|
||||
// CHECK-DAG: %[[INT2_1:.*]] = torch.constant.int 2
|
||||
// CHECK-DAG: %[[STRIDESLIST:.*]] = torch.prim.ListConstruct %[[INT2_0]], %[[INT2_1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK-DAG: %[[INT0_0:.*]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[INT0_1:.*]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[PADLIST:.*]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK-DAG: %[[INT0_2:.*]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[INT0_3:.*]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[NUMTOTENSOR_0:.*]] = torch.prim.NumToTensor.Scalar %[[INT0_3]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK-DAG: %[[INDEXSELECT_0:.*]] = torch.aten.index_select %arg1, %[[INT0_2]], %[[NUMTOTENSOR_0]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK-DAG: %[[ITEM0:.*]] = torch.aten.item %[[INDEXSELECT_0]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK-DAG: %[[NUMTOTENSOR_1:.*]] = torch.prim.NumToTensor.Scalar %[[INT0_3]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK-DAG: %[[INDEXSELECT_1:.*]] = torch.aten.index_select %arg2, %[[INT0_2]], %[[NUMTOTENSOR_1]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK-DAG: %[[ITEM_1:.*]] = torch.aten.item %[[INDEXSELECT_1]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK-DAG: %[[INT1_2:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[NUMTOTENSOR_2:.*]] = torch.prim.NumToTensor.Scalar %[[INT1_2]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK-DAG: %[[INDEXSELECT_2:.*]] = torch.aten.index_select %arg1, %[[INT0_2]], %[[NUMTOTENSOR_2]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK-DAG: %[[ITEM_2:.*]] = torch.aten.item %[[INDEXSELECT_2]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK-DAG: %[[NUMTOTENSOR_3:.*]] = torch.prim.NumToTensor.Scalar %[[INT1_2]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK-DAG: %[[INDEXSELECT_3:.*]] = torch.aten.index_select %arg2, %[[INT0_2]], %[[NUMTOTENSOR_3]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK-DAG: %[[ITEM_3:.*]] = torch.aten.item %[[INDEXSELECT_3]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK-DAG: %[[IMAGESHAPELIST:.*]] = torch.prim.ListConstruct %[[ITEM0]], %[[ITEM_2]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK-DAG: %[[BLOCKSHAPELIST:.*]] = torch.prim.ListConstruct %[[ITEM_1]], %[[ITEM_3]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK-DAG: %[[COL2IM:.*]] = torch.aten.col2im %arg0, %[[IMAGESHAPELIST]], %[[BLOCKSHAPELIST]], %[[DILATIONSLIST]], %[[PADLIST]], %[[STRIDESLIST]] : !torch.vtensor<[1,9,4],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,5,5],f32>
|
||||
// CHECK-DAG: return %[[COL2IM]] : !torch.vtensor<[1,1,5,5],f32>
|
||||
%0 = torch.operator "onnx.Col2Im"(%arg0, %arg1, %arg2) {torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,9,4],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[1,1,5,5],f32>
|
||||
return %0 : !torch.vtensor<[1,1,5,5],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue