Register dialects in E2E passes

pull/41/head
Marius Brehler 2020-09-10 07:23:46 +00:00 committed by Marius Brehler
parent a2fb68059f
commit 843448cde9
5 changed files with 41 additions and 8 deletions

View File

@ -183,7 +183,11 @@ public:
namespace {
class ResolveTensorLoadStoreOps
: public ResolveTensorLoadStoreOpsBase<ResolveTensorLoadStoreOps> {
void runOnOperation() {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<linalg::LinalgDialect>();
}
void runOnOperation() override {
auto func = getOperation();
auto *context = &getContext();
@ -291,7 +295,11 @@ public:
namespace {
class LowerAllocMemRefOps
: public LowerAllocMemRefOpsBase<LowerAllocMemRefOps> {
void runOnOperation() {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<shape::ShapeDialect>();
}
void runOnOperation() override {
auto func = getOperation();
auto *context = &getContext();
OwningRewritePatternList patterns;

View File

@ -12,6 +12,7 @@
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Transforms/DialectConversion.h"
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtDialect.h"
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtOps.h"
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
@ -219,7 +220,11 @@ public:
// step.
namespace {
class LowerRankedShapes : public LowerRankedShapesBase<LowerRankedShapes> {
void runOnOperation() {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<npcomprt::NpcomprtDialect>();
}
void runOnOperation() override {
auto func = getOperation();
auto *context = &getContext();

View File

@ -139,7 +139,11 @@ public:
namespace {
class LowerBroadcastToToLoops
: public LowerBroadcastToToLoopsBase<LowerBroadcastToToLoops> {
void runOnOperation() {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<shape::ShapeDialect, tcp::TCPDialect>();
}
void runOnOperation() override {
auto func = getOperation();
MLIRContext *context = &getContext();
ConversionTarget target(*context);
@ -257,7 +261,11 @@ namespace {
class LowerLinalgOnTensorToLinalgOnMemref
: public LowerLinalgOnTensorToLinalgOnMemrefBase<
LowerLinalgOnTensorToLinalgOnMemref> {
void runOnOperation() {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<shape::ShapeDialect, tcp::TCPDialect>();
}
void runOnOperation() override {
auto func = getOperation();
auto *context = &getContext();
@ -351,7 +359,11 @@ GlobalCreator::GlobalCreator(ModuleOp module) {
namespace {
class LowerConstantTensorsToMemrefs
: public LowerConstantTensorsToMemrefsBase<LowerConstantTensorsToMemrefs> {
void runOnOperation() {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<tcp::TCPDialect>();
}
void runOnOperation() override {
auto module = getOperation();
GlobalCreator globals(module);

View File

@ -637,7 +637,11 @@ static LLVMFuncOp createWrapperFunc(LLVMFuncOp func) {
namespace {
class LowerToLLVM : public LowerToLLVMBase<LowerToLLVM> {
void runOnOperation() {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<LLVM::LLVMDialect>();
}
void runOnOperation() override {
auto module = getOperation();
auto *context = &getContext();

View File

@ -227,7 +227,11 @@ namespace {
// This pass lowers the public ABI of the module to the primitives exposed by
// the npcomprt dialect.
class LowerToNpcomprtABI : public LowerToNpcomprtABIBase<LowerToNpcomprtABI> {
void runOnOperation() {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<linalg::LinalgDialect, npcomprt::NpcomprtDialect>();
}
void runOnOperation() override {
ModuleOp module = getOperation();
// Before we lower anything, capture any needed metadata about the argument