mirror of https://github.com/llvm/torch-mlir
Add an info cast to `prims.squeeze` decomposition (#3844)
The onnx ingest sometimes has poorly propagated shape information. E.g.: ```mlir ... %9020 = torch.prims.squeeze %9010#1, %9019 : !torch.vtensor<[?,384,1],f32>, !torch.list<int> -> !torch.vtensor<[1,384],f32> return %9015, %9020 : !torch.vtensor<[1,384],f32>, !torch.vtensor<[1,384],f32> } } ``` This occurred at the boundary of the onnx model `migraphx_bert__bert-large-uncased`. Evidently, the output value tensor info had more information than could be propagated forward. The `PrimsSqueeze` lowering was returning a `!torch.vtensor<[?,384],f32>` which was causing a type mismatch with the `func.return`.pull/3848/head
parent
a82ba1c422
commit
3cfb7c8df6
|
@ -8958,7 +8958,8 @@ public:
|
|||
}
|
||||
result = *squeezeTensorInfo;
|
||||
}
|
||||
rewriter.replaceOp(op, result);
|
||||
rewriter.replaceOpWithNewOp<Torch::TensorStaticInfoCastOp>(op, op.getType(),
|
||||
result);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue