mirror of https://github.com/llvm/torch-mlir
Collapse leading unit dims on input shapes
Additionally uses the derived output shape (the op's result type) as a hint in getInputAndOutputShape. Covers some cases, e.g. shapes before: 1,-1 / -1,-1; after: 1,-1 / 1,-1.
branch: llama2_wip
parent ff7f8b21dc
commit f70978c034
@@ -284,15 +284,25 @@ public:
   // `aten.view`.
   static std::pair<SmallVector<int64_t>, SmallVector<int64_t>>
   getInputAndOutputShape(Value inputTorchTensor,
-                         SmallVector<Value> outputSizeTorchInt) {
+                         SmallVector<Value> outputSizeTorchInt, ArrayRef<int64_t> resultTypeShape) {
     SmallVector<int64_t> inputShape(
         inputTorchTensor.getType().cast<BaseTensorType>().getSizes());
+    //SmallVector<int64_t> outputShape(resultTypeShape.begin(), resultTypeShape.end());
     SmallVector<int64_t> outputShape(outputSizeTorchInt.size(), kUnknownSize);
-    for (auto [outputDim, outputDimSize] :
-         llvm::enumerate(outputSizeTorchInt)) {
+    for (const auto &it :
+         llvm::enumerate(llvm::zip_equal(outputSizeTorchInt, resultTypeShape))) {
+      int64_t outputDim = it.index();
+      Value outputDimSize = std::get<0>(it.value());
+      int64_t resultShapeSize = std::get<1>(it.value());
+      // for (auto [outputDim, outputDimSize] :
+      //      llvm::enumerate(outputSizeTorchInt)) {
       int64_t inputDim;
       int64_t outputDimSizeInt;
       // Match torch.aten.size.int(inputTensor, inputDim) with constant inputDim
+      if (resultShapeSize >= 0) {
+        outputShape[outputDim] = resultShapeSize;
+        continue;
+      }
       if (matchPattern(outputDimSize,
                        m_TorchTensorSizeInt(inputTorchTensor, &inputDim))) {
         outputShape[outputDim] = inputShape[inputDim];
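The hunk above is the core of the change: when the derived result type already records a dimension statically (resultShapeSize >= 0), that size is taken directly instead of leaving kUnknownSize. A minimal standalone sketch of that idea, using plain std::vector rather than torch-mlir's MLIR types (the helper name and kUnknownSize sentinel here are illustrative only):

#include <cstdint>
#include <iostream>
#include <vector>

constexpr int64_t kUnknownSize = -1;

// Start every output dim as unknown, then take any dim the result type already
// knows statically -- the same role as the added `resultShapeSize >= 0` branch.
std::vector<int64_t> refineOutputShape(const std::vector<int64_t> &resultTypeShape) {
  std::vector<int64_t> outputShape(resultTypeShape.size(), kUnknownSize);
  for (size_t dim = 0; dim < resultTypeShape.size(); ++dim)
    if (resultTypeShape[dim] >= 0)
      outputShape[dim] = resultTypeShape[dim];
  return outputShape;
}

int main() {
  // The case from the commit message: a result type of [1, ?] refines the
  // output shape from [-1, -1] to [1, -1].
  for (int64_t d : refineOutputShape({1, kUnknownSize}))
    std::cout << d << " "; // prints: 1 -1
  std::cout << "\n";
}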
@@ -303,7 +313,6 @@ public:
         }
       }
     }
 
     calculateSingleDynamicSize(inputShape, outputShape);
     return std::make_pair(inputShape, outputShape);
   }
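For context, calculateSingleDynamicSize is an existing helper; judging from its name and this call site it appears to fill in a lone unknown dimension from the element counts of the two shapes. A hedged, standalone sketch of that idea over plain vectors (an assumption about the helper's behavior, not torch-mlir's actual implementation):

#include <cstdint>
#include <vector>

constexpr int64_t kUnknownSize = -1;

// If exactly one dim across both shapes is unknown, recover it as the ratio of
// the two element counts (skipping the unknown slot when multiplying).
void fillSingleDynamicSize(std::vector<int64_t> &a, std::vector<int64_t> &b) {
  auto product = [](const std::vector<int64_t> &s) {
    int64_t p = 1;
    for (int64_t d : s)
      if (d != kUnknownSize)
        p *= d;
    return p;
  };
  auto unknowns = [](const std::vector<int64_t> &s) {
    int64_t n = 0;
    for (int64_t d : s)
      n += (d == kUnknownSize);
    return n;
  };
  if (unknowns(a) + unknowns(b) != 1)
    return; // ambiguous or fully static: nothing to do
  std::vector<int64_t> &withHole = unknowns(a) ? a : b;
  const std::vector<int64_t> &full = unknowns(a) ? b : a;
  for (int64_t &d : withHole)
    if (d == kUnknownSize)
      d = product(full) / product(withHole); // e.g. [2, ?, 4] vs [24] -> ? = 3
}

int main() {
  std::vector<int64_t> in{2, kUnknownSize, 4}, out{24};
  fillSingleDynamicSize(in, out); // in becomes [2, 3, 4]
}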
@@ -356,8 +365,13 @@ public:
     }
 
     auto [inputShape, outputShape] =
-        getInputAndOutputShape(op.getSelf(), outputSizeTorchInt);
+        getInputAndOutputShape(op.getSelf(), outputSizeTorchInt, resultType.getShape());
+    llvm::errs() << "inputShape: ";
+    llvm::interleaveComma(inputShape, llvm::errs());
+
+    llvm::errs() << " \noutput shape: ";
+    llvm::interleaveComma(outputShape, llvm::errs());
 
     // Currently, we only handle the cases where each dimension is either
     // being expanded or collapsed. We do not handle cases where it's neither
     // collapsing nor expanding like view of [2,3] for 3x2 tensor.
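The restriction in the comment above (every dimension must be purely expanded or collapsed) can be illustrated with a small check. This is not the pattern's actual algorithm, just a way to see why [2, 3] -> [6] is in scope while a [2, 3] view of a 3x2 tensor is not (only the collapse direction is shown, for brevity):

#include <cstdint>
#include <cstdio>
#include <vector>

// Illustration only: a pure collapse means every output dim is the product of a
// run of adjacent input dims, consumed left to right.
bool isPureCollapse(const std::vector<int64_t> &in, const std::vector<int64_t> &out) {
  size_t i = 0;
  for (int64_t target : out) {
    int64_t prod = 1;
    while (i < in.size() && prod < target)
      prod *= in[i++];
    if (prod != target)
      return false;
  }
  return i == in.size();
}

int main() {
  std::printf("%d\n", isPureCollapse({2, 3}, {6}));    // 1: a plain collapse, handled
  std::printf("%d\n", isPureCollapse({3, 2}, {2, 3})); // 0: the rejected case from the comment
}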
@@ -380,6 +394,11 @@ public:
       if (matchPattern(outputDimSize,
                        m_TorchTensorSizeInt(op.getSelf(), &inputDim))) {
         unchangedDims.push_back(std::make_pair(inputDim, outputDim));
+
+        llvm::errs() << "inputDim: ";
+        llvm::errs() << inputDim;
+        llvm::errs() << " \noutput Dim: ";
+        llvm::errs() << outputDim;
       }
     }
     // Mark the end of the input/output shapes
@@ -421,10 +440,10 @@ public:
       // Used for ensuring that we don't have an ambiguous expansion
       bool assumedDynamicDimNotSplit = false;
       while (inputDim < nextUnchangedInput && outputDim < nextUnchangedOutput) {
-        auto inputShapeSlice =
+        auto inputShapeSlice = // 1, ?, 4096
            MutableArrayRef<int64_t>(inputShape)
                .slice(inputDim, nextUnchangedInput - inputDim);
-        auto outputShapeSlice =
+        auto outputShapeSlice = // ?, 4096
            MutableArrayRef<int64_t>(outputShape)
                .slice(outputDim, nextUnchangedOutput - outputDim);
         SmallVector<int64_t> inputSliceIndices;
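The two added annotations (// 1, ?, 4096 and // ?, 4096) mark the concrete llama2-style shapes being debugged. A small standalone sketch of how those slices are carved out between unchanged dims, assuming inputShape = [1, ?, 4096] and outputShape = [?, 4096] with the trailing 4096 matched as the next unchanged pair (the dim indices below are assumptions for the example; requires LLVM's ADT/Support headers):

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  constexpr int64_t kUnknownSize = -1;
  llvm::SmallVector<int64_t> inputShape{1, kUnknownSize, 4096}; // 1, ?, 4096
  llvm::SmallVector<int64_t> outputShape{kUnknownSize, 4096};   // ?, 4096
  // Assumed unchanged pair: input dim 2 <-> output dim 1, so each slice runs
  // from the current dim up to (not including) the next unchanged one.
  int64_t inputDim = 0, nextUnchangedInput = 2;
  int64_t outputDim = 0, nextUnchangedOutput = 1;
  auto inputShapeSlice = llvm::MutableArrayRef<int64_t>(inputShape)
                             .slice(inputDim, nextUnchangedInput - inputDim);
  auto outputShapeSlice = llvm::MutableArrayRef<int64_t>(outputShape)
                              .slice(outputDim, nextUnchangedOutput - outputDim);
  llvm::interleaveComma(inputShapeSlice, llvm::errs());  // 1, -1
  llvm::errs() << "\n";
  llvm::interleaveComma(outputShapeSlice, llvm::errs()); // -1
  llvm::errs() << "\n";
}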
@@ -472,7 +491,27 @@ public:
           inputSliceIndices.push_back(0);
           outputSliceIndices.push_back(0);
           assumedDynamicDimNotSplit = true;
-        } else {
+        } else if (inputShapeSlice[0] == 1 && outputShapeSlice[0] == kUnknownSize) {
+          int64_t idx = 0;
+          while (idx < inputShapeSlice.size() && inputShapeSlice[idx] == 1) {
+            inputSliceIndices.push_back(idx++);
+          }
+          if (idx < inputShapeSlice.size() && inputShapeSlice[idx] == kUnknownSize) {
+            inputSliceIndices.push_back(idx);
+            outputSliceIndices.push_back(0);
+            assumedDynamicDimNotSplit = true;
+          }
+          // inputShape = [2, 1, 1, 1, 1, 3]
+          // outputShape = [2,3]
+          // inReassociation = [{0},...]
+          // outReassociation = [{0},...]
+          // inputShape = [1 + 1, 1]
+          // outputShape = [1, 1 + 1]
+          // inputShape = [1, 1, 2, ?, 2]
+          // outputShape = [2, ?, 2]
+        }
+
+        if (inputSliceIndices.empty() || outputSliceIndices.empty()) {
           return rewriter.notifyMatchFailure(
               op, "unimplemented: found unhandled case of expansion/collapse "
                   "in `aten.view`");
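The new else-if branch above is what the commit title refers to: when the input slice starts with unit dims followed by a dynamic dim and the output slice starts with a dynamic dim, the leading 1s are folded together with that dynamic dim into the single dynamic output dim. A standalone sketch of the same logic over plain vectors (illustrative names, not torch-mlir code):

#include <cstdint>
#include <cstdio>
#include <vector>

constexpr int64_t kUnknownSize = -1;

// Returns the input indices that collapse into output index 0, or an empty
// vector when the pattern does not apply (mirroring the later bail-out when
// the slice index lists stay empty).
std::vector<size_t> collapseLeadingUnitDims(const std::vector<int64_t> &inputSlice,
                                            const std::vector<int64_t> &outputSlice) {
  if (inputSlice.empty() || outputSlice.empty() ||
      inputSlice[0] != 1 || outputSlice[0] != kUnknownSize)
    return {};
  std::vector<size_t> inputIndices;
  size_t idx = 0;
  while (idx < inputSlice.size() && inputSlice[idx] == 1)
    inputIndices.push_back(idx++); // leading unit dims
  if (idx < inputSlice.size() && inputSlice[idx] == kUnknownSize) {
    inputIndices.push_back(idx);   // the dynamic dim they fold into
    return inputIndices;
  }
  return {}; // e.g. [1, 1, 2] vs [?]: still ambiguous, give up
}

int main() {
  // Input slice [1, 1, ?] collapsing into output slice [?]: input dims 0, 1, 2
  // all map to output dim 0.
  for (size_t i : collapseLeadingUnitDims({1, 1, kUnknownSize}, {kUnknownSize}))
    std::printf("%zu ", i); // prints: 0 1 2
  std::printf("\n");
}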