Implement Expand/Collapse Functionality for Aten.View (#1353)

JakopinA 2022-09-27 13:08:14 -05:00 committed by GitHub
parent 78bfbf2474
commit 8ef0c874c2
5 changed files with 368 additions and 37 deletions

View File

@@ -123,6 +123,14 @@ MHLO_PASS_SET = {
"ReduceSumDimIntListEmptyDimModule_basic",
"SqueezeModule_allUnitDim",
"SqueezeDimModule_unitDim",
"ViewDoubleMergeStaticModule_basic",
"ViewCollapseOnesMiddleModule_basic",
"ViewFiveTestStaticModule_basic",
"ViewOffsetTestStaticModule_basic",
"ViewTwoFiveThreeStaticModule_basic",
"ViewTwoToThreeStaticModule_basic",
"ViewExpandOnesMiddleOppModule_basic",
"ViewOffsetBackwardTestStaticModule_basic",
"MeanModule_basic",
"MeanDynamicSizesModule_basic",
"MeanDimEmptyDimModule_basic",
@@ -288,6 +296,14 @@ TOSA_PASS_SET = {
"ElementwiseMinimumIntModule_basic",
"ElementwiseMaximumModule_basic",
"ElementwiseMaximumIntModule_basic",
"ViewDoubleMergeStaticModule_basic",
"ViewCollapseOnesMiddleModule_basic",
"ViewFiveTestStaticModule_basic",
"ViewOffsetTestStaticModule_basic",
"ViewTwoFiveThreeStaticModule_basic",
"ViewTwoToThreeStaticModule_basic",
"ViewExpandOnesMiddleOppModule_basic",
"ViewOffsetBackwardTestStaticModule_basic",
"TanhBackward_basic",
"ElementwiseAddModule_basic",
"ReturnThreeTensorFloat32_basic",

View File

@@ -231,38 +231,111 @@ public:
// Helper to find the minimum set of dims to collapse with the
// same number of elements as that of collapseDim. This function assumes
// the size of the collapsed dim is never dynamic.
static LogicalResult
minimallyCollapseDimHelper(AtenViewOp op, ConversionPatternRewriter &rewriter,
int64_t collapseDim, int64_t maxCollapseDim,
int64_t startExpandDim, int64_t maxExpandDim,
const SmallVector<int64_t> &collapseShape,
const SmallVector<int64_t> &expandShape,
ReassociationIndices &expandIndices) {
static LogicalResult minimallyCollapseDimHelper(
AtenViewOp op, ConversionPatternRewriter &rewriter, int64_t collapseDim,
int64_t maxCollapseDim, int64_t startExpandDim, int64_t maxExpandDim,
SmallVector<int64_t> &collapseShape, SmallVector<int64_t> &expandShape,
ReassociationIndices &collapseIndices,
ReassociationIndices &expandIndices) {
int64_t collapseDimSize = collapseShape[collapseDim];
int64_t expandedSize = 1;
int64_t collapsedSize = collapseDimSize;
int64_t expandIndex = startExpandDim;
int64_t collapseIndex = collapseDim + 1;
if (collapseDimSize == kUnknownSize) {
if (llvm::all_of(collapseShape,
[](int64_t value) { return value == kUnknownSize; }) &&
llvm::all_of(expandShape,
[](int64_t value) { return value == kUnknownSize; })) {
for (int i = 0; i < collapseShape.size(); i++) {
collapseIndices.push_back(i);
}
for (int i = 0; i < expandShape.size(); i++) {
expandIndices.push_back(i);
}
for (auto i : llvm::seq<int64_t>(startExpandDim, maxExpandDim)) {
int64_t expandDimSize = expandShape[i];
if (expandDimSize == kUnknownSize ||
collapseDimSize % (expandedSize *= expandDimSize)) {
return rewriter.notifyMatchFailure(
op, "desired size is not compatible with the input tensor size");
}
expandIndices.push_back(i);
if (expandedSize == collapseDimSize)
return success();
if (expandedSize > collapseDimSize) {
return rewriter.notifyMatchFailure(
op, "unimplemented: only supports expanding and collapsing "
"in view");
}
}
while (expandIndex != maxExpandDim || collapseIndex != maxCollapseDim) {
if (expandIndex != maxExpandDim && expandedSize <= collapsedSize) {
int64_t expandDimSize = expandShape[expandIndex];
if (expandDimSize != kUnknownSize) {
expandedSize *= expandDimSize;
}
expandIndices.push_back(expandIndex);
expandIndex++;
} else if (collapseIndex != maxCollapseDim &&
collapsedSize < expandedSize) {
collapseDimSize = collapseShape[collapseIndex];
if (collapseDimSize != kUnknownSize) {
collapsedSize *= collapseDimSize;
}
collapseIndices.push_back(collapseIndex);
collapseIndex++;
}
if (expandedSize == collapsedSize)
return success();
}
return rewriter.notifyMatchFailure(
op, "total number of elements mismatch in the expansion");
}
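// Editor's illustration (not part of the patch), tracing the greedy pairing
// above for the static case exercised by ViewTwoFiveThreeStaticModule,
// [3, 5, 2] -> [2, 5, 3]: the caller seeds collapseIndices with dim 0, and the
// loop then grows whichever running product is smaller:
//   expandedSize:  1 -> 2 -> 10 -> 30    expandIndices   = [0, 1, 2]
//   collapsedSize: 3 -> 15 -> 30         collapseIndices = [0, 1, 2]
// Once the two products meet (30 == 30) the helper succeeds, and the pattern
// lowers this view as a collapse to tensor<30xf32> followed by an expand to
// tensor<2x5x3xf32>.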
static LogicalResult solveDynamicSize(SmallVector<int64_t> &inputShape,
SmallVector<int64_t> &outputShape) {
int64_t inputProduct = 1;
int64_t outputProduct = 1;
int64_t inputDynamicValues = 0;
int64_t outputDynamicValues = 0;
for (int64_t value : inputShape) {
if (value == -1) {
++inputDynamicValues;
} else {
inputProduct *= value;
}
}
for (int64_t value : outputShape) {
if (value == -1) {
++outputDynamicValues;
} else {
outputProduct *= value;
}
}
if (inputDynamicValues + outputDynamicValues == 1) {
if (inputDynamicValues) {
int64_t missingValue = outputProduct / inputProduct;
for (int i = 0; i < inputShape.size(); i++) {
if (inputShape[i] == -1) {
inputShape[i] = missingValue;
break;
}
}
} else {
int64_t missingValue = inputProduct / outputProduct;
for (int i = 0; i < outputShape.size(); i++) {
if (outputShape[i] == -1) {
outputShape[i] = missingValue;
break;
}
}
}
}
return success();
}
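// Worked example (editor's note): the torch.aten.view$expandInferredDim lit
// test added below views a [2, 6] tensor as (3, 2, -1). Here inputProduct = 12,
// outputProduct = 6, and the single -1 sits on the output side, so
// missingValue = 12 / 6 = 2 and the output shape is rewritten to [3, 2, 2]
// before any reassociation indices are computed.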
LogicalResult
matchAndRewrite(AtenViewOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -372,7 +445,6 @@ public:
"is enough static shape information to determine its size, or when "
"the input tensor is being flattened to a single dimension");
}
auto productReduceKnownSizes = [](const ArrayRef<int64_t> sizes) {
auto knownSizes = llvm::make_filter_range(
sizes, [](int64_t val) { return val != kUnknownSize; });
@@ -411,6 +483,8 @@ public:
SmallVector<int64_t> inputShapeVec = llvm::to_vector(inputShape);
solveDynamicSize(inputShapeVec, outputShape);
// The for loop does the following:
// 1. Attempt to match the indices from inputDim and outputDim to the next
// boundary found from `torch.aten.size.int(inputTensor, inputDim)`, or
@@ -441,11 +515,13 @@ public:
bool hasDynamic = false;
while (inputDim < nextUnchangedInput && outputDim < nextUnchangedOutput) {
inputAssociations.emplace_back();
outputAssociations.emplace_back();
// outputDim is next to the boundary
if (outputDim == nextUnchangedOutput - 1) {
if (hasDynamic && inputDim != nextUnchangedInput - 1) {
return rewriter.notifyMatchFailure(
op, "found ambiguous collapse of dynamic input sizes (e.g. "
@@ -464,6 +540,7 @@ public:
// inputDim is next to the boundary
if (inputDim == nextUnchangedInput - 1) {
if (hasDynamic && inputShape[inputDim] == kUnknownSize) {
return rewriter.notifyMatchFailure(
op, "found ambiguous expand of dynamic sizes (e.g. [-1, -1] -> "
@@ -475,6 +552,7 @@ public:
nextUnchangedOutput, inputShapeVec, outputShape,
outputAssociations.back())))
return failure();
outputDim = nextUnchangedOutput;
inputDim = nextUnchangedInput;
continue;
@@ -485,6 +563,7 @@ public:
// If the input is dynamic, first assume it is not split
if (inputMatchingDimSize == kUnknownSize) {
checkDimEqualHelper(rewriter, loc, inputShapeInt[inputDim],
outputShapeInt[outputDim]);
outputShape[outputDim] = kUnknownSize;
@@ -496,15 +575,17 @@ public:
// inputDim size is larger; try to collapse onto it
if (inputMatchingDimSize >= outputMatchingDimSize) {
inputAssociations.back().push_back(inputDim);
if (failed(minimallyCollapseDimHelper(
op, rewriter, inputDim, nextUnchangedInput, outputDim,
nextUnchangedOutput, inputShapeVec, outputShape,
outputAssociations.back())))
inputAssociations.back(), outputAssociations.back()))) {
return failure();
}
hasDynamic = false;
outputDim = outputAssociations.back().back() + 1;
inputDim++;
inputDim = inputAssociations.back().back() + 1;
continue;
}
@@ -513,18 +594,25 @@ public:
if (failed(minimallyCollapseDimHelper(
op, rewriter, outputDim, nextUnchangedOutput, inputDim,
nextUnchangedInput, outputShape, inputShapeVec,
inputAssociations.back())))
outputAssociations.back(), inputAssociations.back()))) {
return failure();
}
hasDynamic = false;
inputDim = inputAssociations.back().back() + 1;
outputDim++;
outputDim = outputAssociations.back().back() + 1;
continue;
}
if (inputDim != nextUnchangedInput || outputDim != nextUnchangedOutput) {
return rewriter.notifyMatchFailure(
op, "could not match input tensor shape to output shape; "
"potentially unsupported view shape");
if (inputDim != nextUnchangedInput) {
hasDynamic = true;
if (inputAssociations.size() < 1) {
inputAssociations.emplace_back();
outputAssociations.emplace_back();
}
inputAssociations.back().push_back(inputDim++);
outputAssociations.back().push_back(outputDim++);
continue;
}
// Append the associations for the dims matching `aten.size.int`
@@ -537,6 +625,9 @@ public:
}
}
int64_t inputCount = inputAssociations.size();
int64_t outputCount = outputAssociations.size();
// Check if the shapes already match up to dynamic sizes. If so, we can just
// cast as the result type because the previous loop sets up the necessary
// dim checks in case of dynamic sizes.
@@ -547,6 +638,7 @@ public:
return indices.size() == 1;
})) {
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
return success();
}
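// Editor's note: the torch.aten.view$dynamictest lit test added below hits
// this fast path. Its view sizes are rebuilt from torch.aten.size.int, so
// every reassociation group is a singleton on both sides and the op lowers to
// a single tensor.cast with no collapse_shape/expand_shape pair.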
@@ -562,16 +654,25 @@ public:
if (llvm::any_of(inputAssociations, [](ReassociationIndices indices) {
return indices.size() > 1;
})) {
SmallVector<int64_t> intermediateShape;
for (auto i : llvm::seq(0, (int)inputAssociations.size())) {
if (inputAssociations[i].size() > 1) {
intermediateShape.push_back(outputShape[outputAssociations[i][0]]);
} else {
intermediateShape.push_back(inputShapeVec[inputAssociations[i][0]]);
for (auto i : llvm::seq(0, (int)outputAssociations.size())) {
int sum = 1;
for (auto j : llvm::seq(0, (int)outputAssociations[i].size())) {
if (outputShape[outputAssociations[i][j]] < 0) {
sum = kUnknownSize;
break;
}
sum *= outputShape[outputAssociations[i][j]];
}
intermediateShape.push_back(sum);
}
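// Editor's note: with the grouping computed above, the intermediate
// (collapsed) shape is the per-group product of the output extents, with any
// dynamic extent forcing the whole group to kUnknownSize. For the [3, 2] ->
// [2, 3] case in the new lit tests, the single group [0, 1] yields an
// intermediate tensor<6xf32>, matching the collapse_shape/expand_shape pair
// checked in torch.aten.view$twotothree.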
Type intermediateResultType =
RankedTensorType::get(intermediateShape, resultType.getElementType());
expandedInput =
rewriter
.create<tensor::CollapseShapeOp>(loc, intermediateResultType,
@@ -582,6 +683,7 @@ public:
if (llvm::any_of(outputAssociations, [](ReassociationIndices indices) {
return indices.size() > 1;
})) {
collapsedInput = rewriter
.create<tensor::ExpandShapeOp>(
loc, adjustedResultType,
@@ -593,7 +695,9 @@ public:
Value result = collapsedInput.has_value() ? collapsedInput.value()
: expandedInput.value();
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
return success();
}
};
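For readers skimming the diff, here is a minimal standalone sketch (an editor's addition, not part of this patch) of the size bookkeeping described above: it resolves a single inferred -1 extent the same way solveDynamicSize does, using plain std::vector instead of the MLIR types. The name resolveInferredDim is hypothetical.

#include <cstdint>
#include <iostream>
#include <vector>

// Stand-in for kUnknownSize / an inferred dimension.
constexpr int64_t kDyn = -1;

// Resolve a single -1 across the two shapes (mirrors solveDynamicSize).
static void resolveInferredDim(std::vector<int64_t> &in,
                               std::vector<int64_t> &out) {
  int64_t inProd = 1, outProd = 1, inDyn = 0, outDyn = 0;
  for (int64_t v : in) {
    if (v == kDyn) ++inDyn; else inProd *= v;
  }
  for (int64_t v : out) {
    if (v == kDyn) ++outDyn; else outProd *= v;
  }
  // Only a single unknown extent can be inferred from the other shape.
  if (inDyn + outDyn != 1)
    return;
  std::vector<int64_t> &shape = inDyn ? in : out;
  int64_t missing = inDyn ? outProd / inProd : inProd / outProd;
  for (int64_t &v : shape)
    if (v == kDyn) { v = missing; break; }
}

int main() {
  // The aten.view from the $expandInferredDim lit test: [2, 6] viewed as (3, 2, -1).
  std::vector<int64_t> in = {2, 6}, out = {3, 2, kDyn};
  resolveInferredDim(in, out);
  for (int64_t v : out)
    std::cout << v << ' '; // prints: 3 2 2
  std::cout << '\n';
  return 0;
}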

View File

@@ -84,6 +84,25 @@ class ViewExpandOnesMiddleModule(torch.nn.Module):
def ViewExpandOnesMiddleModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 1, 2))
# ==============================================================================
class ViewCollapseOnesMiddleModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3, 1, 1, 1, 1, 2], torch.float32, True),
])
def forward(self, a):
return a.view(3, 1, 2)
@register_test_case(module_factory=lambda: ViewCollapseOnesMiddleModule())
def ViewCollapseOnesMiddleModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 1, 1, 1, 1, 2))
# ==============================================================================
class ViewDynamicExpandModule(torch.nn.Module):
@@ -240,6 +259,82 @@ def ViewDynamicExpandCollapseWithAtenIntModule_basic(module, tu: TestUtils):
# ==============================================================================
class ViewTwoToThreeStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3, 2], torch.float32, True),
])
def forward(self, a):
return a.view(2, 3)
@register_test_case(module_factory=lambda: ViewTwoToThreeStaticModule())
def ViewTwoToThreeStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 2))
# ==============================================================================
class ViewTwoFiveThreeStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3, 5, 2], torch.float32, True),
])
def forward(self, a):
return a.view(2, 5, 3)
@register_test_case(module_factory=lambda: ViewTwoFiveThreeStaticModule())
def ViewTwoFiveThreeStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5, 2))
# ==============================================================================
class ViewOffsetTestStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([2, 3, 2, 2, 5, 6], torch.float32, True),
])
def forward(self, a):
return a.view(2, 3, 4, 6, 5)
@register_test_case(module_factory=lambda: ViewOffsetTestStaticModule())
def ViewOffsetTestStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 2, 2, 5, 6))
# ==============================================================================
class ViewOffsetBackwardTestStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([2, 3, 4, 5, 6], torch.float32, True),
])
def forward(self, a):
return a.view(2, 3, 2, 2, 6, 5)
@register_test_case(module_factory=lambda: ViewOffsetBackwardTestStaticModule())
def ViewOffsetBackwardTestStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4, 5, 6))
# ==============================================================================
class View1DFoldModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -289,7 +384,7 @@ class ViewExpandInferredDimModule(torch.nn.Module):
])
def forward(self, a):
return a.view(2, -1, 2)
return a.view(3, -1, 2)
@register_test_case(module_factory=lambda: ViewExpandInferredDimModule())
def ViewExpandInferredDimModule_basic(module, tu: TestUtils):
@@ -297,6 +392,44 @@ def ViewExpandInferredDimModule_basic(module, tu: TestUtils):
# ==============================================================================
class ViewExpandDynamicDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([1, -1, 128], torch.float32, True),
])
def forward(self, a):
return a.view(16, 1, 128)
@register_test_case(module_factory=lambda: ViewExpandDynamicDimModule())
def ViewExpandDynamicDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 16, 128))
# ==============================================================================
class ViewFlattenAndExpandModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return a.view(a.size(0), a.size(1))
@register_test_case(module_factory=lambda: ViewFlattenAndExpandModule())
def ViewFlattenAndExpandModule_basic(module, tu: TestUtils):
module.forward(tu.rand(64,128))
# ==============================================================================
class UnsafeViewExpandModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -560,4 +693,4 @@ class ReshapeAliasCollapseModule(torch.nn.Module):
@register_test_case(module_factory=lambda: ReshapeAliasCollapseModule())
def ReshapeAliasCollapseModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 4))

View File

@@ -82,3 +82,5 @@ func.func @torch.aten.flatten.using_ints$rank0(%arg0: !torch.vtensor<[],f32>) ->
%0 = torch.aten.flatten.using_ints %arg0, %int0, %int0 : !torch.vtensor<[],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
return %0 : !torch.vtensor<[1],f32>
}

View File

@@ -0,0 +1,76 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
// -----
// CHECK-LABEL: func.func @torch.aten.view$twotothree(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,2],f32>) -> !torch.vtensor<[2,3],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[3,2],f32> -> tensor<3x2xf32>
// CHECK: %[[CASTED:.*]] = tensor.cast %[[BUILTIN_TENSOR]] : tensor<3x2xf32> to tensor<3x2xf32>
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[CASTED]] {{\[\[}}0, 1]] : tensor<3x2xf32> into tensor<6xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1]] : tensor<6xf32> into tensor<2x3xf32>
// CHECK: %[[EXPAND_CAST:.*]] = tensor.cast %[[EXPANDED]] : tensor<2x3xf32> to tensor<2x3xf32>
// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND_CAST]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
// CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[2,3],f32>
func.func @torch.aten.view$twotothree(%arg0: !torch.vtensor<[3,2],f32>) -> !torch.vtensor<[2,3],f32> {
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.view %arg0, %0 : !torch.vtensor<[3,2],f32>, !torch.list<int> -> !torch.vtensor<[2,3],f32>
return %1 : !torch.vtensor<[2,3],f32>
}
// CHECK-LABEL: func.func @torch.aten.view$dynamictest(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[CASTED:.*]] = tensor.cast %[[BUILTIN_TENSOR]] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.view$dynamictest(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
%1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
%2 = torch.prim.ListConstruct %0, %1 : (!torch.int, !torch.int) -> !torch.list<int>
%3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?],f32>
return %3 : !torch.vtensor<[?,?],f32>
}
// CHECK-LABEL: func.func @torch.aten.view$dynamicVal(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[1,?,128],f32>) -> !torch.vtensor<[16,1,128],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[1,?,128],f32> -> tensor<1x?x128xf32>
// CHECK: %[[CASTED:.*]] = tensor.cast %[[BUILTIN_TENSOR]] : tensor<1x?x128xf32> to tensor<1x16x128xf32>
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[CASTED]] {{\[\[}}0, 1], [2]] : tensor<1x16x128xf32> into tensor<16x128xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0], [1, 2]] : tensor<16x128xf32> into tensor<16x1x128xf32>
// CHECK: %[[EXPAND_CAST:.*]] = tensor.cast %[[EXPANDED]] : tensor<16x1x128xf32> to tensor<16x1x128xf32>
// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND_CAST]] : tensor<16x1x128xf32> -> !torch.vtensor<[16,1,128],f32>
// CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[16,1,128],f32>
func.func @torch.aten.view$dynamicVal(%arg0: !torch.vtensor<[1,?,128],f32>) -> !torch.vtensor<[16,1,128],f32> {
%int128 = torch.constant.int 128
%int1 = torch.constant.int 1
%int16 = torch.constant.int 16
%0 = torch.prim.ListConstruct %int16, %int1, %int128 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.view %arg0, %0 : !torch.vtensor<[1,?,128],f32>, !torch.list<int> -> !torch.vtensor<[16,1,128],f32>
return %1 : !torch.vtensor<[16,1,128],f32>
}
// CHECK-LABEL: func.func @torch.aten.view$expandInferredDim(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,6],f32>) -> !torch.vtensor<[3,2,2],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,6],f32> -> tensor<2x6xf32>
// CHECK: %[[CASTED:.*]] = tensor.cast %[[BUILTIN_TENSOR]] : tensor<2x6xf32> to tensor<2x6xf32>
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[CASTED]] {{\[\[}}0, 1]] : tensor<2x6xf32> into tensor<12xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2]] : tensor<12xf32> into tensor<3x2x2xf32>
// CHECK: %[[EXPAND_CAST:.*]] = tensor.cast %[[EXPANDED]] : tensor<3x2x2xf32> to tensor<3x2x2xf32>
// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND_CAST]] : tensor<3x2x2xf32> -> !torch.vtensor<[3,2,2],f32>
// CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[3,2,2],f32>
func.func @torch.aten.view$expandInferredDim(%arg0: !torch.vtensor<[2,6],f32>) -> !torch.vtensor<[3,2,2],f32> {
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%int-1 = torch.constant.int -1
%0 = torch.prim.ListConstruct %int3, %int2, %int-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.view %arg0, %0 : !torch.vtensor<[2,6],f32>, !torch.list<int> -> !torch.vtensor<[3,2,2],f32>
return %1 : !torch.vtensor<[3,2,2],f32>
}