mirror of https://github.com/llvm/torch-mlir
Register dialects in E2E passes
parent
a2fb68059f
commit
843448cde9
|
@ -183,7 +183,11 @@ public:
|
||||||
namespace {
|
namespace {
|
||||||
class ResolveTensorLoadStoreOps
|
class ResolveTensorLoadStoreOps
|
||||||
: public ResolveTensorLoadStoreOpsBase<ResolveTensorLoadStoreOps> {
|
: public ResolveTensorLoadStoreOpsBase<ResolveTensorLoadStoreOps> {
|
||||||
void runOnOperation() {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
|
registry.insert<linalg::LinalgDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
|
void runOnOperation() override {
|
||||||
auto func = getOperation();
|
auto func = getOperation();
|
||||||
auto *context = &getContext();
|
auto *context = &getContext();
|
||||||
|
|
||||||
|
@ -291,7 +295,11 @@ public:
|
||||||
namespace {
|
namespace {
|
||||||
class LowerAllocMemRefOps
|
class LowerAllocMemRefOps
|
||||||
: public LowerAllocMemRefOpsBase<LowerAllocMemRefOps> {
|
: public LowerAllocMemRefOpsBase<LowerAllocMemRefOps> {
|
||||||
void runOnOperation() {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
|
registry.insert<shape::ShapeDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
|
void runOnOperation() override {
|
||||||
auto func = getOperation();
|
auto func = getOperation();
|
||||||
auto *context = &getContext();
|
auto *context = &getContext();
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtDialect.h"
|
||||||
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtOps.h"
|
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtOps.h"
|
||||||
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
||||||
|
|
||||||
|
@ -219,7 +220,11 @@ public:
|
||||||
// step.
|
// step.
|
||||||
namespace {
|
namespace {
|
||||||
class LowerRankedShapes : public LowerRankedShapesBase<LowerRankedShapes> {
|
class LowerRankedShapes : public LowerRankedShapesBase<LowerRankedShapes> {
|
||||||
void runOnOperation() {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
|
registry.insert<npcomprt::NpcomprtDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
|
void runOnOperation() override {
|
||||||
auto func = getOperation();
|
auto func = getOperation();
|
||||||
auto *context = &getContext();
|
auto *context = &getContext();
|
||||||
|
|
||||||
|
|
|
@ -139,7 +139,11 @@ public:
|
||||||
namespace {
|
namespace {
|
||||||
class LowerBroadcastToToLoops
|
class LowerBroadcastToToLoops
|
||||||
: public LowerBroadcastToToLoopsBase<LowerBroadcastToToLoops> {
|
: public LowerBroadcastToToLoopsBase<LowerBroadcastToToLoops> {
|
||||||
void runOnOperation() {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
|
registry.insert<shape::ShapeDialect, tcp::TCPDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
|
void runOnOperation() override {
|
||||||
auto func = getOperation();
|
auto func = getOperation();
|
||||||
MLIRContext *context = &getContext();
|
MLIRContext *context = &getContext();
|
||||||
ConversionTarget target(*context);
|
ConversionTarget target(*context);
|
||||||
|
@ -257,7 +261,11 @@ namespace {
|
||||||
class LowerLinalgOnTensorToLinalgOnMemref
|
class LowerLinalgOnTensorToLinalgOnMemref
|
||||||
: public LowerLinalgOnTensorToLinalgOnMemrefBase<
|
: public LowerLinalgOnTensorToLinalgOnMemrefBase<
|
||||||
LowerLinalgOnTensorToLinalgOnMemref> {
|
LowerLinalgOnTensorToLinalgOnMemref> {
|
||||||
void runOnOperation() {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
|
registry.insert<shape::ShapeDialect, tcp::TCPDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
|
void runOnOperation() override {
|
||||||
auto func = getOperation();
|
auto func = getOperation();
|
||||||
auto *context = &getContext();
|
auto *context = &getContext();
|
||||||
|
|
||||||
|
@ -351,7 +359,11 @@ GlobalCreator::GlobalCreator(ModuleOp module) {
|
||||||
namespace {
|
namespace {
|
||||||
class LowerConstantTensorsToMemrefs
|
class LowerConstantTensorsToMemrefs
|
||||||
: public LowerConstantTensorsToMemrefsBase<LowerConstantTensorsToMemrefs> {
|
: public LowerConstantTensorsToMemrefsBase<LowerConstantTensorsToMemrefs> {
|
||||||
void runOnOperation() {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
|
registry.insert<tcp::TCPDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
|
void runOnOperation() override {
|
||||||
auto module = getOperation();
|
auto module = getOperation();
|
||||||
GlobalCreator globals(module);
|
GlobalCreator globals(module);
|
||||||
|
|
||||||
|
|
|
@ -637,7 +637,11 @@ static LLVMFuncOp createWrapperFunc(LLVMFuncOp func) {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class LowerToLLVM : public LowerToLLVMBase<LowerToLLVM> {
|
class LowerToLLVM : public LowerToLLVMBase<LowerToLLVM> {
|
||||||
void runOnOperation() {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
|
registry.insert<LLVM::LLVMDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
|
void runOnOperation() override {
|
||||||
auto module = getOperation();
|
auto module = getOperation();
|
||||||
auto *context = &getContext();
|
auto *context = &getContext();
|
||||||
|
|
||||||
|
|
|
@ -227,7 +227,11 @@ namespace {
|
||||||
// This pass lowers the public ABI of the module to the primitives exposed by
|
// This pass lowers the public ABI of the module to the primitives exposed by
|
||||||
// the npcomprt dialect.
|
// the npcomprt dialect.
|
||||||
class LowerToNpcomprtABI : public LowerToNpcomprtABIBase<LowerToNpcomprtABI> {
|
class LowerToNpcomprtABI : public LowerToNpcomprtABIBase<LowerToNpcomprtABI> {
|
||||||
void runOnOperation() {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
|
registry.insert<linalg::LinalgDialect, npcomprt::NpcomprtDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
|
void runOnOperation() override {
|
||||||
ModuleOp module = getOperation();
|
ModuleOp module = getOperation();
|
||||||
|
|
||||||
// Before we lower anything, capture any needed metadata about the argument
|
// Before we lower anything, capture any needed metadata about the argument
|
||||||
|
|
Loading…
Reference in New Issue