mirror of https://github.com/llvm/torch-mlir
Add an info cast to `prims.squeeze` decomposition (#3844)
The onnx ingest sometimes has poorly propagated shape information. E.g.: ```mlir ... %9020 = torch.prims.squeeze %9010#1, %9019 : !torch.vtensor<[?,384,1],f32>, !torch.list<int> -> !torch.vtensor<[1,384],f32> return %9015, %9020 : !torch.vtensor<[1,384],f32>, !torch.vtensor<[1,384],f32> } } ``` This occurred at the boundary of the onnx model `migraphx_bert__bert-large-uncased`. Evidently, the output value tensor info had more information than could be propagated forward. The `PrimsSqueeze` lowering was returning a `!torch.vtensor<[?,384],f32>` which was causing a type mismatch with the `func.return`.pull/3848/head
parent
a82ba1c422
commit
3cfb7c8df6
|
@ -8958,7 +8958,8 @@ public:
|
||||||
}
|
}
|
||||||
result = *squeezeTensorInfo;
|
result = *squeezeTensorInfo;
|
||||||
}
|
}
|
||||||
rewriter.replaceOp(op, result);
|
rewriter.replaceOpWithNewOp<Torch::TensorStaticInfoCastOp>(op, op.getType(),
|
||||||
|
result);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
Loading…
Reference in New Issue