[ONNX] Basic Support for DeformConv (#3469)

This adds a torchvision op to torch-mlir and a path from onnx.DeformConv
to torchvision.deform_conv2d.

I'm not implementing the torch->linalg lowering for the torchvision op
yet, but posting this PR to get feedback on some of the choices being
made here and to flesh out the onnx frontend a bit.
pull/3495/head
zjgarvey 2024-06-25 12:16:51 -05:00 committed by GitHub
parent e346c911f7
commit 368fabf0c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 328 additions and 9 deletions

View File

@ -16660,6 +16660,42 @@ def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
}];
}
def Torch_TorchvisionDeformConv2dOp : Torch_Op<"torchvision.deform_conv2d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `torchvision::deform_conv2d : (Tensor, Tensor, Tensor, Tensor, Tensor, int, int, int, int, int, int, int, int, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchTensorType:$weight,
AnyTorchTensorType:$offset,
AnyTorchTensorType:$mask,
AnyTorchTensorType:$bias,
Torch_IntType:$stride_h,
Torch_IntType:$stride_w,
Torch_IntType:$pad_h,
Torch_IntType:$pad_w,
Torch_IntType:$dilation_h,
Torch_IntType:$dilation_w,
Torch_IntType:$groups,
Torch_IntType:$offset_groups,
Torch_BoolType:$use_mask
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult TorchvisionDeformConv2dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 14, 1);
}
void TorchvisionDeformConv2dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 14, 1);
}
}];
}
def Torch_TorchvisionRoiAlignOp : Torch_Op<"torchvision.roi_align", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -1837,6 +1837,141 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.op, resultType, transposedInput, reshapeSizesList);
return success();
});
patterns.onOp(
"DeformConv", 19,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
auto loc = binder.getLoc();
// get operands
llvm::SmallVector<Value> operands;
Torch::ValueTensorType resultType;
if (binder.tensorOperandsList(operands) ||
binder.tensorResultType(resultType))
return failure();
if (operands.size() < 3 || operands.size() > 5)
return failure();
auto inputType =
dyn_cast<Torch::ValueTensorType>(operands[0].getType());
if (!inputType || !inputType.hasSizes() ||
inputType.getSizes().size() != 4)
return rewriter.notifyMatchFailure(
binder.op, "Unsupported: DeformConv with input rank != 4");
unsigned rank = inputType.getSizes().size();
auto weightType =
dyn_cast<Torch::ValueTensorType>(operands[1].getType());
if (!weightType || !weightType.hasSizes())
return failure();
auto offsetType =
dyn_cast<Torch::ValueTensorType>(operands[2].getType());
if (!offsetType || !offsetType.hasSizes())
return failure();
// get attributes
SmallVector<int64_t> dilations, kernelShape, pads, strides;
SmallVector<int64_t> defaultDilations(rank - 2, 0);
SmallVector<int64_t> defaultPads(2 * (rank - 2), 0);
SmallVector<int64_t> defaultStrides(rank - 2, 1);
int64_t group, offsetGroup;
if (binder.s64IntegerArrayAttr(dilations, "dilations",
defaultDilations) ||
binder.s64IntegerArrayAttr(kernelShape, "kernel_shape", {}) ||
binder.s64IntegerArrayAttr(pads, "pads", defaultPads) ||
binder.s64IntegerArrayAttr(strides, "strides", defaultStrides) ||
binder.s64IntegerAttr(group, "group", 1) ||
binder.s64IntegerAttr(offsetGroup, "offset_group", 1))
return failure();
for (unsigned i = 0; i < rank - 2; i++) {
if (pads[i] != pads[rank + i - 2])
return rewriter.notifyMatchFailure(
binder.op, "unsupported: asymmetric padding");
}
// Identify and assign names to operands
Value input, weight, offset, bias, mask;
bool useMask = false;
input = operands[0];
weight = operands[1];
offset = operands[2];
if (operands.size() == 4) {
auto unknownOpdRank = Torch::getTensorRank(operands[3]);
if (!unknownOpdRank)
return failure();
if (*unknownOpdRank == 1)
bias = operands[3];
else if (*unknownOpdRank == rank) {
mask = operands[3];
useMask = true;
} else
llvm_unreachable("onnx.DeformConv: optional 4th operand of "
"unexpected rank encountered");
}
if (operands.size() == 5) {
bias = operands[3];
mask = operands[4];
useMask = true;
}
// assign default operand values if necessary
ArrayRef<int64_t> weightSizes = weightType.getSizes();
ArrayRef<int64_t> offsetSizes = offsetType.getSizes();
if (!bias) {
int64_t outputChannels = weightSizes[0];
SmallVector<int64_t> biasShape(1, outputChannels);
Value biasShapeList = mlir::torch::onnx_c::createConstantIntList(
binder, rewriter, biasShape);
Value cstZero = Torch::getConstantWithGivenDtypeAndValue(
rewriter, loc, 0.0f, inputType.getDtype());
bias =
Torch::createInitTensor(rewriter, loc,
rewriter.getType<Torch::ValueTensorType>(
biasShape, inputType.getDtype()),
cstZero, biasShapeList);
}
if (!mask) {
int64_t batchSize = inputType.getSizes()[0];
int64_t kernelHeight = weightSizes[2];
int64_t kernelWidth = weightSizes[3];
int64_t outputHeight = offsetSizes[2];
int64_t outputWidth = offsetSizes[3];
int64_t maskDimOne = offsetGroup * kernelHeight * kernelWidth;
SmallVector<int64_t> maskShape(
{batchSize, maskDimOne, outputHeight, outputWidth});
Value cstOne = Torch::getConstantWithGivenDtypeAndValue(
rewriter, loc, 1.0f, inputType.getDtype());
Value maskShapeList = mlir::torch::onnx_c::createConstantIntList(
binder, rewriter, maskShape);
mask =
Torch::createInitTensor(rewriter, loc,
rewriter.getType<Torch::ValueTensorType>(
maskShape, inputType.getDtype()),
cstOne, maskShapeList);
}
// get attributes as constant values
SmallVector<Value> dilationValues, padValues, strideValues;
for (auto i : dilations)
dilationValues.push_back(rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i)));
for (auto i : pads)
padValues.push_back(rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i)));
for (auto i : strides)
strideValues.push_back(rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i)));
Value groupValue = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(group));
Value offsetGroupValue = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(offsetGroup));
Value useMaskValue = rewriter.create<Torch::ConstantBoolOp>(
loc, rewriter.getBoolAttr(useMask));
rewriter.replaceOpWithNewOp<Torch::TorchvisionDeformConv2dOp>(
binder.op, resultType, input, weight, offset, mask, bias,
strideValues[0], strideValues[1], padValues[0], padValues[1],
dilationValues[0], dilationValues[1], groupValue, offsetGroupValue,
useMaskValue);
return success();
});
patterns.onOp(
"DequantizeLinear", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {

View File

@ -9492,6 +9492,22 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = torch.prim.ListConstruct %int9, %int10 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.torchvision.deform_conv2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.int, %arg6: !torch.int, %arg7: !torch.int, %arg8: !torch.int, %arg9: !torch.int, %arg10: !torch.int, %arg11: !torch.int, %arg12: !torch.int, %arg13: !torch.bool) -> !torch.list<int> {\n"
" %int0 = torch.constant.int 0\n"
" %int2 = torch.constant.int 2\n"
" %int3 = torch.constant.int 3\n"
" %0 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %1 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %2 = torch.aten.__getitem__.t %arg2, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %3 = torch.aten.__getitem__.t %arg2, %int3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" return %4 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.torchvision.deform_conv2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.tuple<int, int>, %arg4: !torch.tuple<int, int>, %arg5: !torch.int, %arg6: !torch.int, %arg7: !torch.int, %arg8: !torch.int, %arg9: !torch.int, %arg10: !torch.int, %arg11: !torch.int, %arg12: !torch.int, %arg13: !torch.bool) -> !torch.int {\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.conv2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.conv2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"

View File

@ -29,6 +29,9 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
"InterpolateDynamicModule_scales_recompute_bilinear",
"ElementwiseFloatTensorGtIntTensorModule_basic",
"AtenIntMM_basic",
# unimplemented lowering torch -> linalg for torchvision.deform_conv2d
# this is added to check the torch.onnx.export -> import_onnx -> torch path
"DeformConv2D_basic",
}
LINALG_CRASHING_SET = {
@ -383,6 +386,7 @@ FX_IMPORTER_XFAIL_SET = {
"ConvolutionBackwardModule2DStrided_basic",
"ConvolutionBackwardModule2D_basic",
"CumsumModule_basic",
"DeformConv2D_basic",
"DivFloatModule_basic",
"DivIntModule_basic",
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
@ -554,6 +558,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"ConvolutionBackwardModule2DStrided_basic",
"ConvolutionBackwardModule2D_basic",
"CumsumModule_basic",
"DeformConv2D_basic",
"DiagonalModule_basic",
"DiagonalModule_nonsquare",
"DiagonalModule_transposed",
@ -2357,19 +2362,12 @@ ONNX_XFAIL_SET = {
"DivIntModule_basic",
"ElementwiseAcoshIntModule_basic",
"ElementwiseAcoshModule_basic",
"ElementwiseAndScalarModule_basic",
"ElementwiseAndScalarStaticShapeModule_basic",
"ElementwiseAsinhIntModule_basic",
"ElementwiseAsinhModule_basic",
"ElementwiseAtanhIntModule_basic",
"ElementwiseAtanhModule_basic",
"ElementwiseAtenIsneginfOpModule_basic",
"ElementwiseAtenIsposinfOpModule_basic",
"ElementwiseBitwiseAndModule_basic",
"ElementwiseBitwiseAndScalarInt32Module_basic",
"ElementwiseBitwiseAndScalarInt64Module_basic",
"ElementwiseBitwiseAndScalarInt8Module_basic",
"ElementwiseBitwiseAndStaticShapeModule_basic",
"ElementwiseBitwiseNotInt32Module_basic",
"ElementwiseBitwiseNotInt64Module_basic",
"ElementwiseBitwiseOrModule_basic",
@ -2710,6 +2708,8 @@ ONNX_XFAIL_SET = {
"IndexPutHackedTwin3DIntNonAccumulateModule_basic",
# RuntimeError: unsupported input type: Device
"PrimsIotaModule_basic",
# unimplemented torchvision.deform_conv2d torch->linalg
"DeformConv2D_basic",
# Error: 'aten::renorm' to ONNX opset version 17 is not supported.
"RenormModuleFloat16_basic",
"RenormModuleFloat32NegativeDim_basic",
@ -2759,6 +2759,14 @@ if torch_version_for_comparison() < version.parse("2.4.0.dev"):
"ElementwiseBitwiseLeftShiftInt32Module_basic",
"ElementwiseBitwiseLeftShiftInt64Module_basic",
"ElementwiseBitwiseLeftShiftInt8Module_basic",
# bitwise and support has been added in torch nightly
"ElementwiseAndScalarModule_basic",
"ElementwiseAndScalarStaticShapeModule_basic",
"ElementwiseBitwiseAndModule_basic",
"ElementwiseBitwiseAndScalarInt32Module_basic",
"ElementwiseBitwiseAndScalarInt64Module_basic",
"ElementwiseBitwiseAndScalarInt8Module_basic",
"ElementwiseBitwiseAndStaticShapeModule_basic",
}
if torch_version_for_comparison() < version.parse("2.4.0.dev"):
@ -2930,6 +2938,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"CumsumModule_basic",
"CumsumStaticModule_basic",
"CumsumStaticNegativeDimModule_basic",
"DeformConv2D_basic",
"DiagonalModule_basic",
"DiagonalModule_nonsquare",
"DiagonalModule_transposed",
@ -3724,6 +3733,7 @@ ONNX_TOSA_XFAIL_SET = {
"CumsumModule_basic",
"CumsumStaticModule_basic",
"CumsumStaticNegativeDimModule_basic",
"DeformConv2D_basic",
"DiagonalModule_basic",
"DiagonalModule_nonsquare",
"DiagonalModule_transposed",

View File

@ -8,7 +8,6 @@ import argparse
import os
import torch
import torchvision
from torch import device
import torch.jit._shape_functions as upstream_shape_functions
@ -1639,6 +1638,12 @@ def atenview_as_real〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
assert False, "Unsupported dtype"
def torchvisiondeform_conv2d〡shape(input: List[int], weight: List[int], offset: List[int], mask: List[int], bias: List[int], stride_h: int, stride_w: int, pad_h: int, pad_w: int, dilation_h: int, dilation_w: int, groups: int, offset_groups: int, use_mask: bool) -> List[int]:
return [input[0], weight[0], offset[2], offset[3]]
def torchvisiondeform_conv2d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], offset_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], bias_rank_dtype: Tuple[int, int], stride_h: int, stride_w: int, pad_h: int, pad_w: int, dilation_h: int, dilation_w: int, groups: int, offset_groups: int, use_mask: bool) -> int:
return input_rank_dtype[1]
def atenconv2d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1,), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), groups: int = 1) -> List[int]:
return upstream_shape_functions.conv2d(input, weight, bias, stride, padding, dilation, groups)
@ -5117,6 +5122,9 @@ def _maybe_import_op_extensions(args: argparse.Namespace):
def main(args):
_maybe_import_op_extensions(args)
# importing torchvision will register torchvision ops with the JITOperatorRegistry
import torchvision
asm = generate_library(globals())
# We're about to put quotes around the string, so escape the `"` characters.
asm = asm.replace("\"", "\\\"")

View File

@ -1155,6 +1155,13 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
traits=["HasValueSemantics"],
)
# ==========================================================================
# `torchvision::` namespace.
# ==========================================================================
emit(
"torchvision::deform_conv2d : (Tensor, Tensor, Tensor, Tensor, Tensor, int, int, int, int, int, int, int, int, bool) -> (Tensor)"
)
emit(
"torchvision::roi_align : (Tensor, Tensor, float, int, int, int, bool) -> (Tensor)"
)
@ -1180,6 +1187,7 @@ def _maybe_import_op_extensions(args: argparse.Namespace):
def main(args: argparse.Namespace):
_maybe_import_op_extensions(args)
# importing torchvision will register torchvision ops with the JITOperatorRegistry
import torchvision
registry = Registry.load()

View File

@ -9,6 +9,7 @@ from typing import Any
import io
import onnx
import torch
from torch.onnx._constants import ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET as max_opset_ver
import torch_mlir
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
@ -78,7 +79,12 @@ def convert_onnx(model, inputs):
examples = tuple(examples)
torch.onnx.export(
model, examples, buffer, input_names=input_names, dynamic_axes=dynamic_tensors
model,
examples,
buffer,
input_names=input_names,
dynamic_axes=dynamic_tensors,
opset_version=max_opset_ver,
)
buffer = buffer.getvalue()
return import_onnx(buffer)

View File

@ -1256,3 +1256,90 @@ def ConvTranspose2DQInt8_basic(module, tu: TestUtils):
tu.randint(Cin, Cout, Hker, Wker, low=-128, high=127).to(torch.int8),
torch.rand(Cout),
)
# torchvision.deform_conv2d
import torchvision
# This section defines a torch->onnx path for this torchvision op so we can test the onnx paths e2e.
# Create symbolic function
from torch.onnx.symbolic_helper import parse_args, _get_tensor_sizes
@parse_args("v", "v", "v", "v", "v", "i", "i", "i", "i", "i", "i", "i", "i", "b")
def symbolic_deform_conv2d_forward(
g,
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups,
use_mask,
):
args = [input, weight, offset, bias]
if use_mask:
args.append(mask)
weight_size = _get_tensor_sizes(weight)
kwargs = {
"dilations_i": [dilation_h, dilation_w],
"group_i": groups,
"kernel_shape_i": weight_size[2:],
"offset_group_i": offset_groups,
# NB: ONNX supports asymmetric padding, whereas PyTorch supports only
# symmetric padding
"pads_i": [pad_h, pad_w, pad_h, pad_w],
"strides_i": [stride_h, stride_w],
}
return g.op("DeformConv", *args, **kwargs)
# Register symbolic function
from torch.onnx import register_custom_op_symbolic
register_custom_op_symbolic(
"torchvision::deform_conv2d", symbolic_deform_conv2d_forward, 19
)
N = 1
Cin = 1
Hin = 7
Win = 6
Cout = 1
Hker = 2
Wker = 2
offset_groups = 1
Hout = 6
Wout = 5
offset_dim1 = 2 * offset_groups * Hker * Wker
class DeformableConvModule(torch.nn.Module):
@export
@annotate_args(
[
None,
([N, Cin, Hin, Win], torch.float32, True),
([N, offset_dim1, Hout, Wout], torch.float32, True),
([Cout, Cin, Hker, Wker], torch.float32, True),
]
)
def forward(self, input, offset, weight):
return torchvision.ops.deform_conv2d(input, offset, weight)
@register_test_case(module_factory=lambda: DeformableConvModule())
def DeformConv2D_basic(module, tu: TestUtils):
input = tu.rand(N, Cin, Hin, Win)
offset = tu.rand(N, offset_dim1, Hout, Wout)
weight = tu.rand(Cout, Cin, Hker, Wker)
module.forward(input, offset, weight)

View File

@ -735,6 +735,19 @@ func.func @test_asinh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4
// -----
// CHECK-LABEL: @test_deform_conv
func.func @test_deform_conv(%arg0: !torch.vtensor<[1,1,7,6],f32>, %arg1: !torch.vtensor<[1,8,6,5],f32>, %arg2: !torch.vtensor<[1,1,2,2],f32>, %arg3: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,1,6,5],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
// CHECK: %[[cstOne:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[mask:.*]] = torch.aten.full %[[sizeList:.*]], %[[cstOne]]
// CHECK-SAME: -> !torch.vtensor<[1,4,6,5],f32>
// CHECK: torch.torchvision.deform_conv2d %arg0, %arg2, %arg1, %[[mask]], %arg3
// CHECK-SAME: : !torch.vtensor<[1,1,7,6],f32>, !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,8,6,5],f32>, !torch.vtensor<[1,4,6,5],f32>, !torch.vtensor<[1],f32>, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[1,1,6,5],f32>
%1 = torch.operator "onnx.DeformConv"(%arg0, %arg2, %arg1, %arg3) {torch.onnx.dilations = [1 : si64, 1 : si64], torch.onnx.group = 1 : si64, torch.onnx.kernel_shape = [2 : si64, 2 : si64], torch.onnx.offset_group = 1 : si64, torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 0 : si64], torch.onnx.strides = [1 : si64, 1 : si64]} : (!torch.vtensor<[1,1,7,6],f32>, !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,8,6,5],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,1,6,5],f32>
return %1 : !torch.vtensor<[1,1,6,5],f32>
}
// -----
// CHECK-LABEL: @test_dequantizelinear_si8
func.func @test_dequantizelinear_si8(%arg0: !torch.vtensor<[6],si8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],si8>) -> !torch.vtensor<[6],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} {
%0 = torch.operator "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],si8>, !torch.vtensor<[],f32>, !torch.vtensor<[],si8>) -> !torch.vtensor<[6],f32>