Add final cast to TorchToLinalg conversions missing it (#692)

In order to make sure that the TorchToLinalg conversions leave the
graph in a valid state, the final result of the conversion has to be
casted to the result type of the op. This commit adds this cast to ops
that did not have it.
pull/641/head snapshot-20220323.342
Ramiro Leal-Cavazos 2022-03-23 13:52:32 -07:00 committed by GitHub
parent f7c7bb800c
commit e966112c8d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 3 deletions

View File

@ -1006,7 +1006,9 @@ public:
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
rewriter.replaceOp(op, adaptor.self());
Type resultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, adaptor.self());
return success();
}
};
@ -1074,7 +1076,8 @@ public:
})
->getResult(0);
rewriter.replaceOp(op, result);
Type resultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
return success();
}
};

View File

@ -1538,7 +1538,7 @@ public:
})
.getResult(0);
rewriter.replaceOp(op, finalRes);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, finalRes);
return success();
}
};