Collapse leading unit dims on input shapes

additionally uses derived output shape as a hint in get input and output
shapes.

Covers some cases like
before:
1,-1
-1,-1
after:
1,-1
1,-1
llama2_wip
dan 2023-09-26 22:38:03 +00:00
parent ff7f8b21dc
commit f70978c034
1 changed files with 48 additions and 9 deletions

View File

@ -284,15 +284,25 @@ public:
// `aten.view`.
static std::pair<SmallVector<int64_t>, SmallVector<int64_t>>
getInputAndOutputShape(Value inputTorchTensor,
SmallVector<Value> outputSizeTorchInt) {
SmallVector<Value> outputSizeTorchInt, ArrayRef<int64_t> resultTypeShape) {
SmallVector<int64_t> inputShape(
inputTorchTensor.getType().cast<BaseTensorType>().getSizes());
//SmallVector<int64_t> outputShape(resultTypeShape.begin(), resultTypeShape.end());
SmallVector<int64_t> outputShape(outputSizeTorchInt.size(), kUnknownSize);
for (auto [outputDim, outputDimSize] :
llvm::enumerate(outputSizeTorchInt)) {
for (const auto &it :
llvm::enumerate(llvm::zip_equal(outputSizeTorchInt, resultTypeShape))) {
int64_t outputDim = it.index();
Value outputDimSize = std::get<0>(it.value());
int64_t resultShapeSize = std::get<1>(it.value());
// for (auto [outputDim, outputDimSize] :
// llvm::enumerate(outputSizeTorchInt)) {
int64_t inputDim;
int64_t outputDimSizeInt;
// Match torch.aten.size.int(inputTensor, inputDim) with constant inputDim
if (resultShapeSize >= 0) {
outputShape[outputDim] = resultShapeSize;
continue;
}
if (matchPattern(outputDimSize,
m_TorchTensorSizeInt(inputTorchTensor, &inputDim))) {
outputShape[outputDim] = inputShape[inputDim];
@ -303,7 +313,6 @@ public:
}
}
}
calculateSingleDynamicSize(inputShape, outputShape);
return std::make_pair(inputShape, outputShape);
}
@ -356,8 +365,13 @@ public:
}
auto [inputShape, outputShape] =
getInputAndOutputShape(op.getSelf(), outputSizeTorchInt);
getInputAndOutputShape(op.getSelf(), outputSizeTorchInt, resultType.getShape());
llvm::errs() << "inputShape: ";
llvm::interleaveComma(inputShape, llvm::errs());
llvm::errs() << " \noutput shape: ";
llvm::interleaveComma(outputShape, llvm::errs());
// Currently, we only handle the cases where each dimension is either
// being expanded or collapsed. We do not handle cases where it's neither
// collapsing nor expanding like view of [2,3] for 3x2 tensor.
@ -380,6 +394,11 @@ public:
if (matchPattern(outputDimSize,
m_TorchTensorSizeInt(op.getSelf(), &inputDim))) {
unchangedDims.push_back(std::make_pair(inputDim, outputDim));
llvm::errs() << "inputDim: ";
llvm::errs() << inputDim;
llvm::errs() << " \noutput Dim: ";
llvm::errs() << outputDim;
}
}
// Mark the end of the input/output shapes
@ -421,10 +440,10 @@ public:
// Used for ensuring that we don't have an ambiguous expansion
bool assumedDynamicDimNotSplit = false;
while (inputDim < nextUnchangedInput && outputDim < nextUnchangedOutput) {
auto inputShapeSlice =
auto inputShapeSlice = // 1, ?, 4096
MutableArrayRef<int64_t>(inputShape)
.slice(inputDim, nextUnchangedInput - inputDim);
auto outputShapeSlice =
auto outputShapeSlice = // ?, 4096
MutableArrayRef<int64_t>(outputShape)
.slice(outputDim, nextUnchangedOutput - outputDim);
SmallVector<int64_t> inputSliceIndices;
@ -472,7 +491,27 @@ public:
inputSliceIndices.push_back(0);
outputSliceIndices.push_back(0);
assumedDynamicDimNotSplit = true;
} else {
} else if (inputShapeSlice[0] == 1 && outputShapeSlice[0] == kUnknownSize) {
int64_t idx = 0;
while (idx < inputShapeSlice.size() && inputShapeSlice[idx] == 1) {
inputSliceIndices.push_back(idx++);
}
if (idx < inputShapeSlice.size() && inputShapeSlice[idx] == kUnknownSize) {
inputSliceIndices.push_back(idx);
outputSliceIndices.push_back(0);
assumedDynamicDimNotSplit = true;
}
// inputShape = [2, 1, 1, 1, 1, 3]
// outputShape = [2,3]
// inReassociation = [{0},...]
// outReassociation = [{0},...]
// inputShape = [1 + 1, 1]
// outputShape = [1, 1 + 1]
// inputShape = [1, 1, 2, ?, 2]
// outputShape = [2, ?, 2]
}
if (inputSliceIndices.empty() || outputSliceIndices.empty()) {
return rewriter.notifyMatchFailure(
op, "unimplemented: found unhandled case of expansion/collapse "
"in `aten.view`");