mirror of https://github.com/llvm/torch-mlir
[LINALG] Add handling of unknown dimension in size list of `view` op (#633)
The view op allows for the new shape argument to have a -1 value for one of the dimensions, and the op is expected to deduce the size of that dimension by looking at the sizes of the other dimensions and comparing it to the total number of elements in the original tensor. This commit adds this functionality.pull/630/head
parent
1d285f0153
commit
298eeb79ca
|
@ -126,6 +126,44 @@ def View1DFoldModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class ViewCollapseInferredDimModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([2, 3, 4], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(-1, 4)
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewCollapseInferredDimModule())
|
||||
def ViewCollapseInferredDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ViewExpandInferredDimModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([2, 6], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(2, -1, 2)
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewExpandInferredDimModule())
|
||||
def ViewExpandInferredDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 6))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class UnsafeViewExpandModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
|
@ -28,6 +28,8 @@
|
|||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
|
||||
|
||||
#include <numeric>
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
@ -3122,8 +3124,10 @@ public:
|
|||
// is violated.
|
||||
SmallVector<int64_t> outputShape(resultRank, kUnknownSize);
|
||||
SmallVector<ReassociationIndices> reassociation(collapsedRank);
|
||||
llvm::Optional<int64_t> inferredDimension;
|
||||
for (auto en : llvm::enumerate(outputSizeTorchInt)) {
|
||||
int64_t inputDim;
|
||||
int64_t size;
|
||||
int64_t outputDim = en.index();
|
||||
// Match torch.aten.size.int(inputTensor, inputDim) with constant inputDim
|
||||
if (matchPattern(en.value(),
|
||||
|
@ -3135,11 +3139,54 @@ public:
|
|||
outputShape[outputDim] = inputShape[inputDim];
|
||||
continue;
|
||||
}
|
||||
} else if (matchPattern(en.value(), m_TorchConstantInt(&size))) {
|
||||
if (size != -1) {
|
||||
outputShape[outputDim] = size;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (inferredDimension.hasValue()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "at most one element in size list is allowed to be -1");
|
||||
}
|
||||
inferredDimension = outputDim;
|
||||
}
|
||||
}
|
||||
|
||||
// Use static information of input tensor to determine size of inferred
|
||||
// dimension in output shape.
|
||||
//
|
||||
// If there is an inferred dimension and that is the only dimension
|
||||
// in the output shape (i.e. the tensor is getting fully flattened),
|
||||
// then we don't need to analyze the static information of the input
|
||||
// shape since the reassociation of dimensions only requires rank
|
||||
// information.
|
||||
if (inferredDimension.hasValue() && outputShape.size() > 1) {
|
||||
if (llvm::count(outputShape, kUnknownSize) != 1 ||
|
||||
llvm::count(inputShape, kUnknownSize) != 0) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op,
|
||||
"unimplemented: an inferred dimension is only supported when there "
|
||||
"is enough static shape information to determine its size, or when "
|
||||
"the input tensor is being flattened to a single dimension");
|
||||
}
|
||||
|
||||
int64_t size;
|
||||
if (matchPattern(en.value(), m_TorchConstantInt(&size)))
|
||||
outputShape[outputDim] = size;
|
||||
auto productReduceKnownSizes = [](const ArrayRef<int64_t> sizes) {
|
||||
auto knownSizes = llvm::make_filter_range(
|
||||
sizes, [](int64_t val) { return val != kUnknownSize; });
|
||||
return std::accumulate(knownSizes.begin(), knownSizes.end(), /*init=*/1,
|
||||
std::multiplies<int64_t>());
|
||||
};
|
||||
|
||||
int64_t numOfElements = productReduceKnownSizes(inputShape);
|
||||
int64_t outputKnownNumOfElements = productReduceKnownSizes(outputShape);
|
||||
if (numOfElements % outputKnownNumOfElements != 0) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "number of elements in input tensor must be divisible by "
|
||||
"product of non-inferred dimensions in size list");
|
||||
}
|
||||
outputShape[*inferredDimension] =
|
||||
numOfElements / outputKnownNumOfElements;
|
||||
}
|
||||
|
||||
SmallVector<int64_t> collapsedShape =
|
||||
|
|
Loading…
Reference in New Issue