mirror of https://github.com/llvm/torch-mlir
Register dialects in E2E passes
parent
a2fb68059f
commit
843448cde9
|
@ -183,7 +183,11 @@ public:
|
|||
namespace {
|
||||
class ResolveTensorLoadStoreOps
|
||||
: public ResolveTensorLoadStoreOpsBase<ResolveTensorLoadStoreOps> {
|
||||
void runOnOperation() {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<linalg::LinalgDialect>();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
auto func = getOperation();
|
||||
auto *context = &getContext();
|
||||
|
||||
|
@ -291,7 +295,11 @@ public:
|
|||
namespace {
|
||||
class LowerAllocMemRefOps
|
||||
: public LowerAllocMemRefOpsBase<LowerAllocMemRefOps> {
|
||||
void runOnOperation() {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<shape::ShapeDialect>();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
auto func = getOperation();
|
||||
auto *context = &getContext();
|
||||
OwningRewritePatternList patterns;
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtDialect.h"
|
||||
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtOps.h"
|
||||
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
||||
|
||||
|
@ -219,7 +220,11 @@ public:
|
|||
// step.
|
||||
namespace {
|
||||
class LowerRankedShapes : public LowerRankedShapesBase<LowerRankedShapes> {
|
||||
void runOnOperation() {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<npcomprt::NpcomprtDialect>();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
auto func = getOperation();
|
||||
auto *context = &getContext();
|
||||
|
||||
|
|
|
@ -139,7 +139,11 @@ public:
|
|||
namespace {
|
||||
class LowerBroadcastToToLoops
|
||||
: public LowerBroadcastToToLoopsBase<LowerBroadcastToToLoops> {
|
||||
void runOnOperation() {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<shape::ShapeDialect, tcp::TCPDialect>();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
auto func = getOperation();
|
||||
MLIRContext *context = &getContext();
|
||||
ConversionTarget target(*context);
|
||||
|
@ -257,7 +261,11 @@ namespace {
|
|||
class LowerLinalgOnTensorToLinalgOnMemref
|
||||
: public LowerLinalgOnTensorToLinalgOnMemrefBase<
|
||||
LowerLinalgOnTensorToLinalgOnMemref> {
|
||||
void runOnOperation() {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<shape::ShapeDialect, tcp::TCPDialect>();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
auto func = getOperation();
|
||||
auto *context = &getContext();
|
||||
|
||||
|
@ -351,7 +359,11 @@ GlobalCreator::GlobalCreator(ModuleOp module) {
|
|||
namespace {
|
||||
class LowerConstantTensorsToMemrefs
|
||||
: public LowerConstantTensorsToMemrefsBase<LowerConstantTensorsToMemrefs> {
|
||||
void runOnOperation() {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<tcp::TCPDialect>();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
GlobalCreator globals(module);
|
||||
|
||||
|
|
|
@ -637,7 +637,11 @@ static LLVMFuncOp createWrapperFunc(LLVMFuncOp func) {
|
|||
|
||||
namespace {
|
||||
class LowerToLLVM : public LowerToLLVMBase<LowerToLLVM> {
|
||||
void runOnOperation() {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<LLVM::LLVMDialect>();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
auto *context = &getContext();
|
||||
|
||||
|
|
|
@ -227,7 +227,11 @@ namespace {
|
|||
// This pass lowers the public ABI of the module to the primitives exposed by
|
||||
// the npcomprt dialect.
|
||||
class LowerToNpcomprtABI : public LowerToNpcomprtABIBase<LowerToNpcomprtABI> {
|
||||
void runOnOperation() {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<linalg::LinalgDialect, npcomprt::NpcomprtDialect>();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp module = getOperation();
|
||||
|
||||
// Before we lower anything, capture any needed metadata about the argument
|
||||
|
|
Loading…
Reference in New Issue