diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index add32ff05..a609b1791 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -741,15 +741,23 @@ public: // `aten.view`. static std::pair, SmallVector> getInputAndOutputShape(Value inputTorchTensor, - SmallVector outputSizeTorchInt) { + SmallVector outputSizeTorchInt, ArrayRef resultTypeShape) { SmallVector inputShape( inputTorchTensor.getType().cast().getSizes()); SmallVector outputShape(outputSizeTorchInt.size(), kUnknownSize); - for (auto [outputDim, outputDimSize] : - llvm::enumerate(outputSizeTorchInt)) { +// for (auto [outputDim, outputDimSize] : +// llvm::enumerate(outputSizeTorchInt)) { + for (const auto &it : llvm::enumerate(llvm::zip_equal(outputSizeTorchInt, resultTypeShape))) { + int64_t outputDim = it.index(); + Value outputDimSize = std::get<0>(it.value()); + int64_t resultShapeSize = std::get<1>(it.value()); int64_t inputDim; int64_t outputDimSizeInt; // Match torch.aten.size.int(inputTensor, inputDim) with constant inputDim + if (resultShapeSize >= 0) { + outputShape[outputDim] = resultShapeSize; + continue; + } if (matchPattern(outputDimSize, m_TorchTensorSizeInt(inputTorchTensor, &inputDim))) { outputShape[outputDim] = inputShape[inputDim]; @@ -813,7 +821,7 @@ public: } auto [inputShape, outputShape] = - getInputAndOutputShape(op.getSelf(), outputSizeTorchInt); + getInputAndOutputShape(op.getSelf(), outputSizeTorchInt, resultType.getShape()); // Currently, we only handle the cases where each dimension is either // being expanded or collapsed. We do not handle cases where it's neither