mirror of https://github.com/llvm/torch-mlir
Add final cast to TorchToLinalg conversions missing it (#692)
In order to make sure that the TorchToLinalg conversions leave the graph in a valid state, the final result of the conversion has to be casted to the result type of the op. This commit adds this cast to ops that did not have it.pull/641/head snapshot-20220323.342
parent
f7c7bb800c
commit
e966112c8d
|
@ -1006,7 +1006,9 @@ public:
|
|||
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
rewriter.replaceOp(op, adaptor.self());
|
||||
|
||||
Type resultType = getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, adaptor.self());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -1074,7 +1076,8 @@ public:
|
|||
})
|
||||
->getResult(0);
|
||||
|
||||
rewriter.replaceOp(op, result);
|
||||
Type resultType = getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -1538,7 +1538,7 @@ public:
|
|||
})
|
||||
.getResult(0);
|
||||
|
||||
rewriter.replaceOp(op, finalRes);
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, finalRes);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue