Register dialects in ATen lowering pass

pull/40/head
Marius Brehler 2020-09-09 15:13:12 +00:00 committed by Stella Laurenzo
parent fb2d1a1559
commit 124bc65a70
1 changed files with 4 additions and 0 deletions

View File

@ -891,6 +891,10 @@ MemRefType convertTensorType(TensorType tensor) {
struct ATenLoweringPass struct ATenLoweringPass
: public PassWrapper<ATenLoweringPass, OperationPass<ModuleOp>> { : public PassWrapper<ATenLoweringPass, OperationPass<ModuleOp>> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<LLVM::LLVMDialect, StandardOpsDialect>();
}
void runOnOperation() override { void runOnOperation() override {
LLVMTypeConverter typeConverter(getOperation().getContext()); LLVMTypeConverter typeConverter(getOperation().getContext());
typeConverter.addConversion([&](Type type) { typeConverter.addConversion([&](Type type) {