mirror of https://github.com/llvm/torch-mlir
Implement Expand/Collapse Functionality for Aten.View (#1353)
parent
78bfbf2474
commit
8ef0c874c2
|
@ -123,6 +123,14 @@ MHLO_PASS_SET = {
|
|||
"ReduceSumDimIntListEmptyDimModule_basic",
|
||||
"SqueezeModule_allUnitDim",
|
||||
"SqueezeDimModule_unitDim",
|
||||
"ViewDoubleMergeStaticModule_basic",
|
||||
"ViewCollapseOnesMiddleModule_basic",
|
||||
"ViewFiveTestStaticModule_basic",
|
||||
"ViewOffsetTestStaticModule_basic",
|
||||
"ViewTwoFiveThreeStaticModule_basic",
|
||||
"ViewTwoToThreeStaticModule_basic",
|
||||
"ViewExpandOnesMiddleOppModule_basic",
|
||||
"ViewOffsetBackwardTestStaticModule_basic",
|
||||
"MeanModule_basic",
|
||||
"MeanDynamicSizesModule_basic",
|
||||
"MeanDimEmptyDimModule_basic",
|
||||
|
@ -288,6 +296,14 @@ TOSA_PASS_SET = {
|
|||
"ElementwiseMinimumIntModule_basic",
|
||||
"ElementwiseMaximumModule_basic",
|
||||
"ElementwiseMaximumIntModule_basic",
|
||||
"ViewDoubleMergeStaticModule_basic",
|
||||
"ViewCollapseOnesMiddleModule_basic",
|
||||
"ViewFiveTestStaticModule_basic",
|
||||
"ViewOffsetTestStaticModule_basic",
|
||||
"ViewTwoFiveThreeStaticModule_basic",
|
||||
"ViewTwoToThreeStaticModule_basic",
|
||||
"ViewExpandOnesMiddleOppModule_basic",
|
||||
"ViewOffsetBackwardTestStaticModule_basic",
|
||||
"TanhBackward_basic",
|
||||
"ElementwiseAddModule_basic",
|
||||
"ReturnThreeTensorFloat32_basic",
|
||||
|
|
|
@ -231,38 +231,111 @@ public:
|
|||
// Helper to find the minimum set of dims to collapse with the
|
||||
// same number of elements as that of collapseDim. This function assumes
|
||||
// the size of the collapsed dim is never dynamic.
|
||||
static LogicalResult
|
||||
minimallyCollapseDimHelper(AtenViewOp op, ConversionPatternRewriter &rewriter,
|
||||
int64_t collapseDim, int64_t maxCollapseDim,
|
||||
int64_t startExpandDim, int64_t maxExpandDim,
|
||||
const SmallVector<int64_t> &collapseShape,
|
||||
const SmallVector<int64_t> &expandShape,
|
||||
ReassociationIndices &expandIndices) {
|
||||
static LogicalResult minimallyCollapseDimHelper(
|
||||
AtenViewOp op, ConversionPatternRewriter &rewriter, int64_t collapseDim,
|
||||
int64_t maxCollapseDim, int64_t startExpandDim, int64_t maxExpandDim,
|
||||
SmallVector<int64_t> &collapseShape, SmallVector<int64_t> &expandShape,
|
||||
ReassociationIndices &collapseIndices,
|
||||
ReassociationIndices &expandIndices) {
|
||||
|
||||
int64_t collapseDimSize = collapseShape[collapseDim];
|
||||
|
||||
int64_t expandedSize = 1;
|
||||
int64_t collapsedSize = collapseDimSize;
|
||||
|
||||
int64_t expandIndex = startExpandDim;
|
||||
int64_t collapseIndex = collapseDim + 1;
|
||||
|
||||
if (collapseDimSize == kUnknownSize) {
|
||||
if (llvm::all_of(collapseShape,
|
||||
[](int64_t value) { return value == kUnknownSize; }) &&
|
||||
llvm::all_of(expandShape,
|
||||
[](int64_t value) { return value == kUnknownSize; })) {
|
||||
|
||||
for (int i = 0; i < collapseShape.size(); i++) {
|
||||
collapseIndices.push_back(i);
|
||||
}
|
||||
|
||||
for (int i = 0; i < expandShape.size(); i++) {
|
||||
expandIndices.push_back(i);
|
||||
}
|
||||
|
||||
for (auto i : llvm::seq<int64_t>(startExpandDim, maxExpandDim)) {
|
||||
int64_t expandDimSize = expandShape[i];
|
||||
if (expandDimSize == kUnknownSize ||
|
||||
collapseDimSize % (expandedSize *= expandDimSize)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "desired size is not compatible with the input tensor size");
|
||||
}
|
||||
expandIndices.push_back(i);
|
||||
if (expandedSize == collapseDimSize)
|
||||
return success();
|
||||
|
||||
if (expandedSize > collapseDimSize) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: only supports expanding and collapsing "
|
||||
"in view");
|
||||
}
|
||||
}
|
||||
|
||||
while (expandIndex != maxExpandDim || collapseIndex != maxCollapseDim) {
|
||||
if (expandIndex != maxExpandDim && expandedSize <= collapsedSize) {
|
||||
int64_t expandDimSize = expandShape[expandIndex];
|
||||
if (expandDimSize != kUnknownSize) {
|
||||
expandedSize *= expandDimSize;
|
||||
}
|
||||
expandIndices.push_back(expandIndex);
|
||||
expandIndex++;
|
||||
|
||||
} else if (collapseIndex != maxCollapseDim &&
|
||||
collapsedSize < expandedSize) {
|
||||
collapseDimSize = collapseShape[collapseIndex];
|
||||
if (collapseDimSize != kUnknownSize) {
|
||||
collapsedSize *= collapseDimSize;
|
||||
}
|
||||
collapseIndices.push_back(collapseIndex);
|
||||
collapseIndex++;
|
||||
}
|
||||
|
||||
if (expandedSize == collapsedSize)
|
||||
return success();
|
||||
}
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "total number of elements mismatch in the expansion");
|
||||
}
|
||||
|
||||
static LogicalResult solveDynamicSize(SmallVector<int64_t> &inputShape,
|
||||
SmallVector<int64_t> &outputShape) {
|
||||
int64_t inputProduct = 1;
|
||||
int64_t outputProduct = 1;
|
||||
|
||||
int64_t inputDynamicValues = 0;
|
||||
int64_t outputDynamicValues = 0;
|
||||
|
||||
for (int64_t value : inputShape) {
|
||||
if (value == -1) {
|
||||
++inputDynamicValues;
|
||||
} else {
|
||||
inputProduct *= value;
|
||||
}
|
||||
}
|
||||
for (int64_t value : outputShape) {
|
||||
if (value == -1) {
|
||||
++outputDynamicValues;
|
||||
} else {
|
||||
outputProduct *= value;
|
||||
}
|
||||
}
|
||||
|
||||
if (inputDynamicValues + outputDynamicValues == 1) {
|
||||
if (inputDynamicValues) {
|
||||
int64_t missingValue = outputProduct / inputProduct;
|
||||
for (int i = 0; i < inputShape.size(); i++) {
|
||||
if (inputShape[i] == -1) {
|
||||
inputShape[i] = missingValue;
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
int64_t missingValue = inputProduct / outputProduct;
|
||||
for (int i = 0; i < outputShape.size(); i++) {
|
||||
if (outputShape[i] == -1) {
|
||||
outputShape[i] = missingValue;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenViewOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
@ -372,7 +445,6 @@ public:
|
|||
"is enough static shape information to determine its size, or when "
|
||||
"the input tensor is being flattened to a single dimension");
|
||||
}
|
||||
|
||||
auto productReduceKnownSizes = [](const ArrayRef<int64_t> sizes) {
|
||||
auto knownSizes = llvm::make_filter_range(
|
||||
sizes, [](int64_t val) { return val != kUnknownSize; });
|
||||
|
@ -411,6 +483,8 @@ public:
|
|||
|
||||
SmallVector<int64_t> inputShapeVec = llvm::to_vector(inputShape);
|
||||
|
||||
solveDynamicSize(inputShapeVec, outputShape);
|
||||
|
||||
// The for loop does the following:
|
||||
// 1. Attempt to match the indices from inputDim and outputDim to the next
|
||||
// boundary found from `torch.aten.size.int(inputTensor, inputDim)`, or
|
||||
|
@ -441,11 +515,13 @@ public:
|
|||
|
||||
bool hasDynamic = false;
|
||||
while (inputDim < nextUnchangedInput && outputDim < nextUnchangedOutput) {
|
||||
|
||||
inputAssociations.emplace_back();
|
||||
outputAssociations.emplace_back();
|
||||
|
||||
// outputDim is next to the boundary
|
||||
if (outputDim == nextUnchangedOutput - 1) {
|
||||
|
||||
if (hasDynamic && inputDim != nextUnchangedInput - 1) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "found ambiguous collapse of dynamic input sizes (e.g. "
|
||||
|
@ -464,6 +540,7 @@ public:
|
|||
|
||||
// inputDim is next to the boundary
|
||||
if (inputDim == nextUnchangedInput - 1) {
|
||||
|
||||
if (hasDynamic && inputShape[inputDim] == kUnknownSize) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "found ambiguous expand of dynamic sizes (e.g. [-1, -1] -> "
|
||||
|
@ -475,6 +552,7 @@ public:
|
|||
nextUnchangedOutput, inputShapeVec, outputShape,
|
||||
outputAssociations.back())))
|
||||
return failure();
|
||||
|
||||
outputDim = nextUnchangedOutput;
|
||||
inputDim = nextUnchangedInput;
|
||||
continue;
|
||||
|
@ -485,6 +563,7 @@ public:
|
|||
|
||||
// If the input is dynamic, first assume it is not split
|
||||
if (inputMatchingDimSize == kUnknownSize) {
|
||||
|
||||
checkDimEqualHelper(rewriter, loc, inputShapeInt[inputDim],
|
||||
outputShapeInt[outputDim]);
|
||||
outputShape[outputDim] = kUnknownSize;
|
||||
|
@ -496,15 +575,17 @@ public:
|
|||
|
||||
// inputDim size is larger; try to collapse onto it
|
||||
if (inputMatchingDimSize >= outputMatchingDimSize) {
|
||||
|
||||
inputAssociations.back().push_back(inputDim);
|
||||
if (failed(minimallyCollapseDimHelper(
|
||||
op, rewriter, inputDim, nextUnchangedInput, outputDim,
|
||||
nextUnchangedOutput, inputShapeVec, outputShape,
|
||||
outputAssociations.back())))
|
||||
inputAssociations.back(), outputAssociations.back()))) {
|
||||
return failure();
|
||||
}
|
||||
hasDynamic = false;
|
||||
outputDim = outputAssociations.back().back() + 1;
|
||||
inputDim++;
|
||||
inputDim = inputAssociations.back().back() + 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -513,18 +594,25 @@ public:
|
|||
if (failed(minimallyCollapseDimHelper(
|
||||
op, rewriter, outputDim, nextUnchangedOutput, inputDim,
|
||||
nextUnchangedInput, outputShape, inputShapeVec,
|
||||
inputAssociations.back())))
|
||||
outputAssociations.back(), inputAssociations.back()))) {
|
||||
|
||||
return failure();
|
||||
}
|
||||
hasDynamic = false;
|
||||
inputDim = inputAssociations.back().back() + 1;
|
||||
outputDim++;
|
||||
outputDim = outputAssociations.back().back() + 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (inputDim != nextUnchangedInput || outputDim != nextUnchangedOutput) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "could not match input tensor shape to output shape; "
|
||||
"potentially unsupported view shape");
|
||||
if (inputDim != nextUnchangedInput) {
|
||||
hasDynamic = true;
|
||||
if (inputAssociations.size() < 1) {
|
||||
inputAssociations.emplace_back();
|
||||
outputAssociations.emplace_back();
|
||||
}
|
||||
inputAssociations.back().push_back(inputDim++);
|
||||
outputAssociations.back().push_back(outputDim++);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Append the associations for the dims matching `aten.size.int`
|
||||
|
@ -537,6 +625,9 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
int64_t inputCount = inputAssociations.size();
|
||||
int64_t outputCount = outputAssociations.size();
|
||||
|
||||
// Check if the shapes already match up to dynamic sizes. If so, we can just
|
||||
// cast as the result type because the previous loop sets up the necessary
|
||||
// dim checks in case of dynamic sizes.
|
||||
|
@ -547,6 +638,7 @@ public:
|
|||
return indices.size() == 1;
|
||||
})) {
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -562,16 +654,25 @@ public:
|
|||
if (llvm::any_of(inputAssociations, [](ReassociationIndices indices) {
|
||||
return indices.size() > 1;
|
||||
})) {
|
||||
|
||||
SmallVector<int64_t> intermediateShape;
|
||||
for (auto i : llvm::seq(0, (int)inputAssociations.size())) {
|
||||
if (inputAssociations[i].size() > 1) {
|
||||
intermediateShape.push_back(outputShape[outputAssociations[i][0]]);
|
||||
} else {
|
||||
intermediateShape.push_back(inputShapeVec[inputAssociations[i][0]]);
|
||||
for (auto i : llvm::seq(0, (int)outputAssociations.size())) {
|
||||
int sum = 1;
|
||||
|
||||
for (auto j : llvm::seq(0, (int)outputAssociations[i].size())) {
|
||||
if (outputShape[outputAssociations[i][j]] < 0) {
|
||||
sum = kUnknownSize;
|
||||
break;
|
||||
}
|
||||
sum *= outputShape[outputAssociations[i][j]];
|
||||
}
|
||||
|
||||
intermediateShape.push_back(sum);
|
||||
}
|
||||
|
||||
Type intermediateResultType =
|
||||
RankedTensorType::get(intermediateShape, resultType.getElementType());
|
||||
|
||||
expandedInput =
|
||||
rewriter
|
||||
.create<tensor::CollapseShapeOp>(loc, intermediateResultType,
|
||||
|
@ -582,6 +683,7 @@ public:
|
|||
if (llvm::any_of(outputAssociations, [](ReassociationIndices indices) {
|
||||
return indices.size() > 1;
|
||||
})) {
|
||||
|
||||
collapsedInput = rewriter
|
||||
.create<tensor::ExpandShapeOp>(
|
||||
loc, adjustedResultType,
|
||||
|
@ -593,7 +695,9 @@ public:
|
|||
|
||||
Value result = collapsedInput.has_value() ? collapsedInput.value()
|
||||
: expandedInput.value();
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -84,6 +84,25 @@ class ViewExpandOnesMiddleModule(torch.nn.Module):
|
|||
def ViewExpandOnesMiddleModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 1, 2))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ViewCollapseOnesMiddleModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([3, 1, 1, 1, 1, 2], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(3, 1, 2)
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewCollapseOnesMiddleModule())
|
||||
def ViewCollapseOnesMiddleModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 1, 1, 1, 1, 2))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ViewDynamicExpandModule(torch.nn.Module):
|
||||
|
@ -240,6 +259,82 @@ def ViewDynamicExpandCollapseWithAtenIntModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class ViewTwoToThreeStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([3, 2], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(2, 3)
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewTwoToThreeStaticModule())
|
||||
def ViewTwoToThreeStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 2))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ViewTwoFiveThreeStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([3, 5, 2], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(2, 5, 3)
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewTwoFiveThreeStaticModule())
|
||||
def ViewTwoFiveThreeStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5, 2))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ViewOffsetTestStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([2, 3, 2, 2, 5, 6], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(2, 3, 4, 6, 5)
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewOffsetTestStaticModule())
|
||||
def ViewOffsetTestStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 2, 2, 5, 6))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ViewOffsetBackwardTestStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([2, 3, 4, 5, 6], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(2, 3, 2, 2, 6, 5)
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewOffsetBackwardTestStaticModule())
|
||||
def ViewOffsetBackwardTestStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 4, 5, 6))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class View1DFoldModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -289,7 +384,7 @@ class ViewExpandInferredDimModule(torch.nn.Module):
|
|||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(2, -1, 2)
|
||||
return a.view(3, -1, 2)
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewExpandInferredDimModule())
|
||||
def ViewExpandInferredDimModule_basic(module, tu: TestUtils):
|
||||
|
@ -297,6 +392,44 @@ def ViewExpandInferredDimModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class ViewExpandDynamicDimModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([1, -1, 128], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(16, 1, 128)
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewExpandDynamicDimModule())
|
||||
def ViewExpandDynamicDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 16, 128))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ViewFlattenAndExpandModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(a.size(0), a.size(1))
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewFlattenAndExpandModule())
|
||||
def ViewFlattenAndExpandModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(64,128))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class UnsafeViewExpandModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -560,4 +693,4 @@ class ReshapeAliasCollapseModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ReshapeAliasCollapseModule())
|
||||
def ReshapeAliasCollapseModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4))
|
||||
module.forward(tu.rand(2, 4))
|
|
@ -82,3 +82,5 @@ func.func @torch.aten.flatten.using_ints$rank0(%arg0: !torch.vtensor<[],f32>) ->
|
|||
%0 = torch.aten.flatten.using_ints %arg0, %int0, %int0 : !torch.vtensor<[],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
|
||||
return %0 : !torch.vtensor<[1],f32>
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.view$twotothree(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,2],f32>) -> !torch.vtensor<[2,3],f32> {
|
||||
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[3,2],f32> -> tensor<3x2xf32>
|
||||
// CHECK: %[[CASTED:.*]] = tensor.cast %[[BUILTIN_TENSOR]] : tensor<3x2xf32> to tensor<3x2xf32>
|
||||
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[CASTED]] {{\[\[}}0, 1]] : tensor<3x2xf32> into tensor<6xf32>
|
||||
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1]] : tensor<6xf32> into tensor<2x3xf32>
|
||||
// CHECK: %[[EXPAND_CAST:.*]] = tensor.cast %[[EXPANDED]] : tensor<2x3xf32> to tensor<2x3xf32>
|
||||
// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND_CAST]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
|
||||
// CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[2,3],f32>
|
||||
|
||||
func.func @torch.aten.view$twotothree(%arg0: !torch.vtensor<[3,2],f32>) -> !torch.vtensor<[2,3],f32> {
|
||||
%int3 = torch.constant.int 3
|
||||
%int2 = torch.constant.int 2
|
||||
%0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
%1 = torch.aten.view %arg0, %0 : !torch.vtensor<[3,2],f32>, !torch.list<int> -> !torch.vtensor<[2,3],f32>
|
||||
return %1 : !torch.vtensor<[2,3],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.view$dynamictest(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[CASTED:.*]] = tensor.cast %[[BUILTIN_TENSOR]] : tensor<?x?xf32> to tensor<?x?xf32>
|
||||
// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[?,?],f32>
|
||||
|
||||
func.func @torch.aten.view$dynamictest(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%int1 = torch.constant.int 1
|
||||
%int0 = torch.constant.int 0
|
||||
%0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
|
||||
%1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
|
||||
%2 = torch.prim.ListConstruct %0, %1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
%3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?],f32>
|
||||
return %3 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.view$dynamicVal(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[1,?,128],f32>) -> !torch.vtensor<[16,1,128],f32> {
|
||||
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[1,?,128],f32> -> tensor<1x?x128xf32>
|
||||
// CHECK: %[[CASTED:.*]] = tensor.cast %[[BUILTIN_TENSOR]] : tensor<1x?x128xf32> to tensor<1x16x128xf32>
|
||||
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[CASTED]] {{\[\[}}0, 1], [2]] : tensor<1x16x128xf32> into tensor<16x128xf32>
|
||||
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0], [1, 2]] : tensor<16x128xf32> into tensor<16x1x128xf32>
|
||||
// CHECK: %[[EXPAND_CAST:.*]] = tensor.cast %[[EXPANDED]] : tensor<16x1x128xf32> to tensor<16x1x128xf32>
|
||||
// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND_CAST]] : tensor<16x1x128xf32> -> !torch.vtensor<[16,1,128],f32>
|
||||
// CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[16,1,128],f32>
|
||||
|
||||
func.func @torch.aten.view$dynamicVal(%arg0: !torch.vtensor<[1,?,128],f32>) -> !torch.vtensor<[16,1,128],f32> {
|
||||
%int128 = torch.constant.int 128
|
||||
%int1 = torch.constant.int 1
|
||||
%int16 = torch.constant.int 16
|
||||
%0 = torch.prim.ListConstruct %int16, %int1, %int128 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
%1 = torch.aten.view %arg0, %0 : !torch.vtensor<[1,?,128],f32>, !torch.list<int> -> !torch.vtensor<[16,1,128],f32>
|
||||
return %1 : !torch.vtensor<[16,1,128],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.view$expandInferredDim(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,6],f32>) -> !torch.vtensor<[3,2,2],f32> {
|
||||
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,6],f32> -> tensor<2x6xf32>
|
||||
// CHECK: %[[CASTED:.*]] = tensor.cast %[[BUILTIN_TENSOR]] : tensor<2x6xf32> to tensor<2x6xf32>
|
||||
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[CASTED]] {{\[\[}}0, 1]] : tensor<2x6xf32> into tensor<12xf32>
|
||||
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2]] : tensor<12xf32> into tensor<3x2x2xf32>
|
||||
// CHECK: %[[EXPAND_CAST:.*]] = tensor.cast %[[EXPANDED]] : tensor<3x2x2xf32> to tensor<3x2x2xf32>
|
||||
// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND_CAST]] : tensor<3x2x2xf32> -> !torch.vtensor<[3,2,2],f32>
|
||||
// CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[3,2,2],f32>
|
||||
|
||||
func.func @torch.aten.view$expandInferredDim(%arg0: !torch.vtensor<[2,6],f32>) -> !torch.vtensor<[3,2,2],f32> {
|
||||
%int2 = torch.constant.int 2
|
||||
%int3 = torch.constant.int 3
|
||||
%int-1 = torch.constant.int -1
|
||||
%0 = torch.prim.ListConstruct %int3, %int2, %int-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
%1 = torch.aten.view %arg0, %0 : !torch.vtensor<[2,6],f32>, !torch.list<int> -> !torch.vtensor<[3,2,2],f32>
|
||||
return %1 : !torch.vtensor<[3,2,2],f32>
|
||||
}
|
Loading…
Reference in New Issue