Fix `verifyLinalgCompatibleTypes` which currently doesn't successfully catch `torch.tensor`. (#947)

pull/792/merge snapshot-20220616.505
Maksim Levental 2022-06-15 18:21:36 -05:00 committed by GitHub
parent 77ab31641f
commit a34dad2e07
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 0 deletions

View File

@ -26,6 +26,8 @@ LogicalResult verifyLinalgCompatibleTypes(Operation *op,
// TODO: Remove this check but use a separate verification pass to verify the
// invariants expected by later passes.
auto isValidLinalgType = [](Type type) {
if (type.isa<NonValueTensorType>())
return false;
auto tensor = type.dyn_cast<ValueTensorType>();
return !tensor ||
tensor.toBuiltinTensor().dyn_cast_or_null<RankedTensorType>();