diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index 9ec6a6006..452215b35 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -284,15 +284,25 @@ public:
   // `aten.view`.
   static std::pair<SmallVector<int64_t>, SmallVector<int64_t>>
   getInputAndOutputShape(Value inputTorchTensor,
-                         SmallVector<Value> outputSizeTorchInt) {
+                         SmallVector<Value> outputSizeTorchInt, ArrayRef<int64_t> resultTypeShape) {
     SmallVector<int64_t> inputShape(
         inputTorchTensor.getType().cast<BaseTensorType>().getSizes());
+    // SmallVector<int64_t> outputShape(resultTypeShape.begin(), resultTypeShape.end());
     SmallVector<int64_t> outputShape(outputSizeTorchInt.size(), kUnknownSize);
-    for (auto [outputDim, outputDimSize] :
-         llvm::enumerate(outputSizeTorchInt)) {
+    for (const auto &it : llvm::enumerate(
+             llvm::zip_equal(outputSizeTorchInt, resultTypeShape))) {
+      int64_t outputDim = it.index();
+      Value outputDimSize = std::get<0>(it.value());
+      int64_t resultShapeSize = std::get<1>(it.value());
+      // for (auto [outputDim, outputDimSize] :
+      //      llvm::enumerate(outputSizeTorchInt)) {
       int64_t inputDim;
       int64_t outputDimSizeInt;
       // Match torch.aten.size.int(inputTensor, inputDim) with constant inputDim
+      if (resultShapeSize >= 0) {
+        outputShape[outputDim] = resultShapeSize;
+        continue;
+      }
       if (matchPattern(outputDimSize,
                        m_TorchTensorSizeInt(inputTorchTensor, &inputDim))) {
         outputShape[outputDim] = inputShape[inputDim];
@@ -303,7 +313,6 @@ public:
         }
       }
     }
-    calculateSingleDynamicSize(inputShape, outputShape);
     return std::make_pair(inputShape, outputShape);
   }
 
@@ -356,8 +365,13 @@ public:
     }
 
     auto [inputShape, outputShape] =
-        getInputAndOutputShape(op.getSelf(), outputSizeTorchInt);
-
+        getInputAndOutputShape(op.getSelf(), outputSizeTorchInt, resultType.getShape());
+    llvm::errs() << "inputShape: ";
+    llvm::interleaveComma(inputShape, llvm::errs());
+
+    llvm::errs() << " \noutput shape: ";
+    llvm::interleaveComma(outputShape, llvm::errs());
+
     // Currently, we only handle the cases where each dimension is either
     // being expanded or collapsed. We do not handle cases where it's neither
     // collapsing nor expanding like view of [2,3] for 3x2 tensor.
@@ -380,6 +394,11 @@ public:
       if (matchPattern(outputDimSize,
                        m_TorchTensorSizeInt(op.getSelf(), &inputDim))) {
         unchangedDims.push_back(std::make_pair(inputDim, outputDim));
+
+        llvm::errs() << "inputDim: ";
+        llvm::errs() << inputDim;
+        llvm::errs() << " \noutput Dim: ";
+        llvm::errs() << outputDim;
       }
     }
     // Mark the end of the input/output shapes
@@ -421,10 +440,10 @@ public:
     // Used for ensuring that we don't have an ambiguous expansion
     bool assumedDynamicDimNotSplit = false;
     while (inputDim < nextUnchangedInput && outputDim < nextUnchangedOutput) {
-      auto inputShapeSlice =
+      auto inputShapeSlice = // 1, ?, 4096
           MutableArrayRef<int64_t>(inputShape)
               .slice(inputDim, nextUnchangedInput - inputDim);
-      auto outputShapeSlice =
+      auto outputShapeSlice = // ?, 4096
           MutableArrayRef<int64_t>(outputShape)
               .slice(outputDim, nextUnchangedOutput - outputDim);
       SmallVector<int64_t> inputSliceIndices;
@@ -472,7 +491,27 @@ public:
         inputSliceIndices.push_back(0);
         outputSliceIndices.push_back(0);
         assumedDynamicDimNotSplit = true;
-      } else {
+      } else if (inputShapeSlice[0] == 1 && outputShapeSlice[0] == kUnknownSize) {
+        int64_t idx = 0;
+        while (idx < inputShapeSlice.size() && inputShapeSlice[idx] == 1) {
+          inputSliceIndices.push_back(idx++);
+        }
+        if (idx < inputShapeSlice.size() && inputShapeSlice[idx] == kUnknownSize) {
+          inputSliceIndices.push_back(idx);
+          outputSliceIndices.push_back(0);
+          assumedDynamicDimNotSplit = true;
+        }
+        // inputShape = [2, 1, 1, 1, 1, 3]
+        // outputShape = [2, 3]
+        // inReassociation = [{0}, ...]
+        // outReassociation = [{0}, ...]
+        // inputShape = [1 + 1, 1]
+        // outputShape = [1, 1 + 1]
+        // inputShape = [1, 1, 2, ?, 2]
+        // outputShape = [2, ?, 2]
+      }
+
+      if (inputSliceIndices.empty() || outputSliceIndices.empty()) {
         return rewriter.notifyMatchFailure(
             op, "unimplemented: found unhandled case of expansion/collapse "
                 "in `aten.view`");
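
For reviewers, a minimal standalone sketch of what the first hunk does: seed `outputShape` with the static sizes already present in the result type, so only the dims the result type leaves dynamic still need the `torch.aten.size.int`/constant matching. This is an illustration under stated assumptions, not the patch itself: `seedOutputShape` is a hypothetical name, `kUnknownSize == -1` mirrors the torch-mlir convention, and `std::vector` stands in for `SmallVector`.

#include <cassert>
#include <cstdint>
#include <vector>

constexpr int64_t kUnknownSize = -1; // dynamic-dimension marker, as in torch-mlir

// Prefer static sizes from the result type; leave the remaining dims unknown
// for the later size-value matching to fill in.
std::vector<int64_t> seedOutputShape(const std::vector<int64_t> &resultTypeShape) {
  std::vector<int64_t> outputShape(resultTypeShape.size(), kUnknownSize);
  for (size_t dim = 0; dim < resultTypeShape.size(); ++dim)
    if (resultTypeShape[dim] >= 0) // static size known from the result type
      outputShape[dim] = resultTypeShape[dim];
  return outputShape;
}

int main() {
  // Result type tensor<?x4096xf32>: dim 0 stays unknown, dim 1 gets seeded.
  auto shape = seedOutputShape({kUnknownSize, 4096});
  assert(shape[0] == kUnknownSize && shape[1] == 4096);
  return 0;
}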
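Likewise, a standalone sketch of the new `else if` branch in the last hunk: leading static 1-sized input dims are grouped with a following dynamic input dim and matched against a single dynamic output dim, e.g. an input slice [1, 1, ?] viewed as [?]. Same assumptions as above; `collapseUnitDimsIntoDynamic` is a hypothetical name for illustration only.

#include <cassert>
#include <cstdint>
#include <vector>

constexpr int64_t kUnknownSize = -1; // dynamic-dimension marker, as in torch-mlir

// Returns the grouped input indices (all mapping to output dim 0), or an
// empty vector when the rule does not apply, in which case the pattern
// reports the unhandled expansion/collapse as before.
std::vector<int64_t> collapseUnitDimsIntoDynamic(const std::vector<int64_t> &inputSlice,
                                                 const std::vector<int64_t> &outputSlice) {
  std::vector<int64_t> grouped;
  if (inputSlice.empty() || outputSlice.empty() || inputSlice[0] != 1 ||
      outputSlice[0] != kUnknownSize)
    return grouped;
  size_t idx = 0;
  while (idx < inputSlice.size() && inputSlice[idx] == 1)
    grouped.push_back(idx++); // static 1s fold into the group
  if (idx < inputSlice.size() && inputSlice[idx] == kUnknownSize)
    grouped.push_back(idx);   // the dynamic dim absorbs the leading 1s
  else
    grouped.clear();          // next dim is static and != 1: rule not applicable
  return grouped;
}

int main() {
  // [1, 1, ?] -> [?]: all three input dims map to output dim 0.
  assert(collapseUnitDimsIntoDynamic({1, 1, kUnknownSize}, {kUnknownSize}).size() == 3);
  // [1, 2] -> [?]: left to the existing match-failure path.
  assert(collapseUnitDimsIntoDynamic({1, 2}, {kUnknownSize}).empty());
  return 0;
}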