[LINALG] Add handling of unknown dimension in size list of `view` op (#633)

The view op allows for the new shape argument to have a -1 value for
one of the dimensions, and the op is expected to deduce the size of
that dimension by looking at the sizes of the other dimensions and
comparing it to the total number of elements in the original
tensor. This commit adds this functionality.
pull/630/head
Ramiro Leal-Cavazos 2022-03-02 13:35:01 -08:00 committed by GitHub
parent 1d285f0153
commit 298eeb79ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 88 additions and 3 deletions

View File

@ -126,6 +126,44 @@ def View1DFoldModule_basic(module, tu: TestUtils):
# ==============================================================================
class ViewCollapseInferredDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([2, 3, 4], torch.float32, True),
])
def forward(self, a):
return a.view(-1, 4)
@register_test_case(module_factory=lambda: ViewCollapseInferredDimModule())
def ViewCollapseInferredDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4))
# ==============================================================================
class ViewExpandInferredDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([2, 6], torch.float32, True),
])
def forward(self, a):
return a.view(2, -1, 2)
@register_test_case(module_factory=lambda: ViewExpandInferredDimModule())
def ViewExpandInferredDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 6))
# ==============================================================================
class UnsafeViewExpandModule(torch.nn.Module):
def __init__(self):
super().__init__()

View File

@ -28,6 +28,8 @@
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include <numeric>
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
@ -3122,8 +3124,10 @@ public:
// is violated.
SmallVector<int64_t> outputShape(resultRank, kUnknownSize);
SmallVector<ReassociationIndices> reassociation(collapsedRank);
llvm::Optional<int64_t> inferredDimension;
for (auto en : llvm::enumerate(outputSizeTorchInt)) {
int64_t inputDim;
int64_t size;
int64_t outputDim = en.index();
// Match torch.aten.size.int(inputTensor, inputDim) with constant inputDim
if (matchPattern(en.value(),
@ -3135,11 +3139,54 @@ public:
outputShape[outputDim] = inputShape[inputDim];
continue;
}
} else if (matchPattern(en.value(), m_TorchConstantInt(&size))) {
if (size != -1) {
outputShape[outputDim] = size;
continue;
}
if (inferredDimension.hasValue()) {
return rewriter.notifyMatchFailure(
op, "at most one element in size list is allowed to be -1");
}
inferredDimension = outputDim;
}
}
// Use static information of input tensor to determine size of inferred
// dimension in output shape.
//
// If there is an inferred dimension and that is the only dimension
// in the output shape (i.e. the tensor is getting fully flattened),
// then we don't need to analyze the static information of the input
// shape since the reassociation of dimensions only requires rank
// information.
if (inferredDimension.hasValue() && outputShape.size() > 1) {
if (llvm::count(outputShape, kUnknownSize) != 1 ||
llvm::count(inputShape, kUnknownSize) != 0) {
return rewriter.notifyMatchFailure(
op,
"unimplemented: an inferred dimension is only supported when there "
"is enough static shape information to determine its size, or when "
"the input tensor is being flattened to a single dimension");
}
int64_t size;
if (matchPattern(en.value(), m_TorchConstantInt(&size)))
outputShape[outputDim] = size;
auto productReduceKnownSizes = [](const ArrayRef<int64_t> sizes) {
auto knownSizes = llvm::make_filter_range(
sizes, [](int64_t val) { return val != kUnknownSize; });
return std::accumulate(knownSizes.begin(), knownSizes.end(), /*init=*/1,
std::multiplies<int64_t>());
};
int64_t numOfElements = productReduceKnownSizes(inputShape);
int64_t outputKnownNumOfElements = productReduceKnownSizes(outputShape);
if (numOfElements % outputKnownNumOfElements != 0) {
return rewriter.notifyMatchFailure(
op, "number of elements in input tensor must be divisible by "
"product of non-inferred dimensions in size list");
}
outputShape[*inferredDimension] =
numOfElements / outputKnownNumOfElements;
}
SmallVector<int64_t> collapsedShape =