diff --git a/e2e_testing/torchscript/reshape_like.py b/e2e_testing/torchscript/reshape_like.py index 67ca98d93..e998b806e 100644 --- a/e2e_testing/torchscript/reshape_like.py +++ b/e2e_testing/torchscript/reshape_like.py @@ -126,6 +126,44 @@ def View1DFoldModule_basic(module, tu: TestUtils): # ============================================================================== +class ViewCollapseInferredDimModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2, 3, 4], torch.float32, True), + ]) + + def forward(self, a): + return a.view(-1, 4) + +@register_test_case(module_factory=lambda: ViewCollapseInferredDimModule()) +def ViewCollapseInferredDimModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 4)) + +# ============================================================================== + +class ViewExpandInferredDimModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2, 6], torch.float32, True), + ]) + + def forward(self, a): + return a.view(2, -1, 2) + +@register_test_case(module_factory=lambda: ViewExpandInferredDimModule()) +def ViewExpandInferredDimModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 6)) + +# ============================================================================== + class UnsafeViewExpandModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp index 21558a19b..a4719bdd5 100644 --- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp +++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp @@ -28,6 +28,8 @@ #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" +#include + using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; @@ -3122,8 +3124,10 @@ public: // is violated. SmallVector outputShape(resultRank, kUnknownSize); SmallVector reassociation(collapsedRank); + llvm::Optional inferredDimension; for (auto en : llvm::enumerate(outputSizeTorchInt)) { int64_t inputDim; + int64_t size; int64_t outputDim = en.index(); // Match torch.aten.size.int(inputTensor, inputDim) with constant inputDim if (matchPattern(en.value(), @@ -3135,11 +3139,54 @@ public: outputShape[outputDim] = inputShape[inputDim]; continue; } + } else if (matchPattern(en.value(), m_TorchConstantInt(&size))) { + if (size != -1) { + outputShape[outputDim] = size; + continue; + } + + if (inferredDimension.hasValue()) { + return rewriter.notifyMatchFailure( + op, "at most one element in size list is allowed to be -1"); + } + inferredDimension = outputDim; + } + } + + // Use static information of input tensor to determine size of inferred + // dimension in output shape. + // + // If there is an inferred dimension and that is the only dimension + // in the output shape (i.e. the tensor is getting fully flattened), + // then we don't need to analyze the static information of the input + // shape since the reassociation of dimensions only requires rank + // information. + if (inferredDimension.hasValue() && outputShape.size() > 1) { + if (llvm::count(outputShape, kUnknownSize) != 1 || + llvm::count(inputShape, kUnknownSize) != 0) { + return rewriter.notifyMatchFailure( + op, + "unimplemented: an inferred dimension is only supported when there " + "is enough static shape information to determine its size, or when " + "the input tensor is being flattened to a single dimension"); } - int64_t size; - if (matchPattern(en.value(), m_TorchConstantInt(&size))) - outputShape[outputDim] = size; + auto productReduceKnownSizes = [](const ArrayRef sizes) { + auto knownSizes = llvm::make_filter_range( + sizes, [](int64_t val) { return val != kUnknownSize; }); + return std::accumulate(knownSizes.begin(), knownSizes.end(), /*init=*/1, + std::multiplies()); + }; + + int64_t numOfElements = productReduceKnownSizes(inputShape); + int64_t outputKnownNumOfElements = productReduceKnownSizes(outputShape); + if (numOfElements % outputKnownNumOfElements != 0) { + return rewriter.notifyMatchFailure( + op, "number of elements in input tensor must be divisible by " + "product of non-inferred dimensions in size list"); + } + outputShape[*inferredDimension] = + numOfElements / outputKnownNumOfElements; } SmallVector collapsedShape =