mirror of https://github.com/llvm/torch-mlir
view hack
parent
bfcf93ea21
commit
484ee82c28
|
@ -741,15 +741,23 @@ public:
|
|||
// `aten.view`.
|
||||
static std::pair<SmallVector<int64_t>, SmallVector<int64_t>>
|
||||
getInputAndOutputShape(Value inputTorchTensor,
|
||||
SmallVector<Value> outputSizeTorchInt) {
|
||||
SmallVector<Value> outputSizeTorchInt, ArrayRef<int64_t> resultTypeShape) {
|
||||
SmallVector<int64_t> inputShape(
|
||||
inputTorchTensor.getType().cast<BaseTensorType>().getSizes());
|
||||
SmallVector<int64_t> outputShape(outputSizeTorchInt.size(), kUnknownSize);
|
||||
for (auto [outputDim, outputDimSize] :
|
||||
llvm::enumerate(outputSizeTorchInt)) {
|
||||
// for (auto [outputDim, outputDimSize] :
|
||||
// llvm::enumerate(outputSizeTorchInt)) {
|
||||
for (const auto &it : llvm::enumerate(llvm::zip_equal(outputSizeTorchInt, resultTypeShape))) {
|
||||
int64_t outputDim = it.index();
|
||||
Value outputDimSize = std::get<0>(it.value());
|
||||
int64_t resultShapeSize = std::get<1>(it.value());
|
||||
int64_t inputDim;
|
||||
int64_t outputDimSizeInt;
|
||||
// Match torch.aten.size.int(inputTensor, inputDim) with constant inputDim
|
||||
if (resultShapeSize >= 0) {
|
||||
outputShape[outputDim] = resultShapeSize;
|
||||
continue;
|
||||
}
|
||||
if (matchPattern(outputDimSize,
|
||||
m_TorchTensorSizeInt(inputTorchTensor, &inputDim))) {
|
||||
outputShape[outputDim] = inputShape[inputDim];
|
||||
|
@ -813,7 +821,7 @@ public:
|
|||
}
|
||||
|
||||
auto [inputShape, outputShape] =
|
||||
getInputAndOutputShape(op.getSelf(), outputSizeTorchInt);
|
||||
getInputAndOutputShape(op.getSelf(), outputSizeTorchInt, resultType.getShape());
|
||||
|
||||
// Currently, we only handle the cases where each dimension is either
|
||||
// being expanded or collapsed. We do not handle cases where it's neither
|
||||
|
|
Loading…
Reference in New Issue