mirror of https://github.com/llvm/torch-mlir
build: update llvm tag to 6f46ff37 (#1448)
Summary of changes: - Updated references to the Arith dialect (https://reviews.llvm.org/D134762) - Switched to prefixed accessors for MemRef dialect (https://reviews.llvm.org/D134995) - Fixed warnings about signed/unsigned comparisons, ignored return values, and unused variablespull/1214/head
parent
708fa346a6
commit
faa9a78e38
|
@ -7,8 +7,8 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
|
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
|
#include "mlir/Dialect/Arith/Utils/Utils.h"
|
||||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/Func/Transforms/Passes.h"
|
#include "mlir/Dialect/Func/Transforms/Passes.h"
|
||||||
|
@ -134,7 +134,7 @@ struct TMTensorBufferizePass
|
||||||
bufferization::BufferizeTypeConverter typeConverter;
|
bufferization::BufferizeTypeConverter typeConverter;
|
||||||
|
|
||||||
// Mark all Standard operations legal.
|
// Mark all Standard operations legal.
|
||||||
target.addLegalDialect<arith::ArithmeticDialect, func::FuncDialect,
|
target.addLegalDialect<arith::ArithDialect, func::FuncDialect,
|
||||||
memref::MemRefDialect, tensor::TensorDialect>();
|
memref::MemRefDialect, tensor::TensorDialect>();
|
||||||
|
|
||||||
// Mark all TMTensor operations illegal as long as they work on tensors.
|
// Mark all TMTensor operations illegal as long as they work on tensors.
|
||||||
|
|
|
@ -7,8 +7,8 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
|
#include "mlir/Dialect/Arith/Utils/Utils.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
#include "mlir/Dialect/Math/IR/Math.h"
|
#include "mlir/Dialect/Math/IR/Math.h"
|
||||||
|
@ -101,7 +101,7 @@ namespace {
|
||||||
struct TMTensorToLoopsPass : public TMTensorToLoopsBase<TMTensorToLoopsPass> {
|
struct TMTensorToLoopsPass : public TMTensorToLoopsBase<TMTensorToLoopsPass> {
|
||||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
registry.insert<linalg::LinalgDialect, func::FuncDialect,
|
registry.insert<linalg::LinalgDialect, func::FuncDialect,
|
||||||
mlir::arith::ArithmeticDialect, math::MathDialect,
|
mlir::arith::ArithDialect, math::MathDialect,
|
||||||
memref::MemRefDialect, scf::SCFDialect>();
|
memref::MemRefDialect, scf::SCFDialect>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
set(LIBS
|
set(LIBS
|
||||||
MLIRArithmeticDialect
|
MLIRArithDialect
|
||||||
MLIRDialect
|
MLIRDialect
|
||||||
MLIRLinalgDialect
|
MLIRLinalgDialect
|
||||||
MLIRMemRefDialect
|
MLIRMemRefDialect
|
||||||
|
|
|
@ -7,12 +7,12 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/Dialect/SCF/Transforms/Passes.h"
|
|
||||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
|
#include "mlir/Dialect/SCF/Transforms/Passes.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/AsmState.h"
|
#include "mlir/IR/AsmState.h"
|
||||||
#include "mlir/IR/Dialect.h"
|
#include "mlir/IR/Dialect.h"
|
||||||
|
@ -39,7 +39,7 @@ int main(int argc, char **argv) {
|
||||||
// Local dialects
|
// Local dialects
|
||||||
mlir::torch::TMTensor::TMTensorDialect,
|
mlir::torch::TMTensor::TMTensorDialect,
|
||||||
// Upstream dialects
|
// Upstream dialects
|
||||||
mlir::arith::ArithmeticDialect, mlir::linalg::LinalgDialect,
|
mlir::arith::ArithDialect, mlir::linalg::LinalgDialect,
|
||||||
mlir::func::FuncDialect, mlir::memref::MemRefDialect,
|
mlir::func::FuncDialect, mlir::memref::MemRefDialect,
|
||||||
mlir::scf::SCFDialect, mlir::tensor::TensorDialect>();
|
mlir::scf::SCFDialect, mlir::tensor::TensorDialect>();
|
||||||
|
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
Subproject commit bebc96956b76bdbc36f1d82a788c810e5b12e2c5
|
Subproject commit 6f46ff3765dcdc178b9cf52ebd8c03437806798a
|
|
@ -1 +1 @@
|
||||||
Subproject commit 7b0ecf7827e3fc07d2af90e147bcedc165bc78ac
|
Subproject commit 2f7c1454bbe4c4ad0ae1c86c5539ac58b6053b6a
|
|
@ -10,7 +10,7 @@
|
||||||
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
|
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "../PassDetail.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/Math/IR/Math.h"
|
#include "mlir/Dialect/Math/IR/Math.h"
|
||||||
|
@ -300,7 +300,7 @@ class ConvertTorchToArith : public ConvertTorchToArithBase<ConvertTorchToArith>
|
||||||
public:
|
public:
|
||||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
registry.insert<func::FuncDialect>();
|
registry.insert<func::FuncDialect>();
|
||||||
registry.insert<arith::ArithmeticDialect>();
|
registry.insert<arith::ArithDialect>();
|
||||||
registry.insert<tensor::TensorDialect>();
|
registry.insert<tensor::TensorDialect>();
|
||||||
registry.insert<cf::ControlFlowDialect>();
|
registry.insert<cf::ControlFlowDialect>();
|
||||||
registry.insert<math::MathDialect>();
|
registry.insert<math::MathDialect>();
|
||||||
|
@ -311,7 +311,7 @@ public:
|
||||||
MLIRContext *context = &getContext();
|
MLIRContext *context = &getContext();
|
||||||
ConversionTarget target(*context);
|
ConversionTarget target(*context);
|
||||||
target.addLegalDialect<Torch::TorchDialect, func::FuncDialect,
|
target.addLegalDialect<Torch::TorchDialect, func::FuncDialect,
|
||||||
arith::ArithmeticDialect, tensor::TensorDialect,
|
arith::ArithDialect, tensor::TensorDialect,
|
||||||
cf::ControlFlowDialect, math::MathDialect>();
|
cf::ControlFlowDialect, math::MathDialect>();
|
||||||
|
|
||||||
TypeConverter typeConverter;
|
TypeConverter typeConverter;
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
#include "../PassDetail.h"
|
#include "../PassDetail.h"
|
||||||
#include "PopulatePatterns.h"
|
#include "PopulatePatterns.h"
|
||||||
#include "Utils.h"
|
#include "Utils.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
@ -252,11 +252,11 @@ public:
|
||||||
llvm::all_of(expandShape,
|
llvm::all_of(expandShape,
|
||||||
[](int64_t value) { return value == kUnknownSize; })) {
|
[](int64_t value) { return value == kUnknownSize; })) {
|
||||||
|
|
||||||
for (int i = 0; i < collapseShape.size(); i++) {
|
for (size_t i = 0; i < collapseShape.size(); i++) {
|
||||||
collapseIndices.push_back(i);
|
collapseIndices.push_back(i);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < expandShape.size(); i++) {
|
for (size_t i = 0; i < expandShape.size(); i++) {
|
||||||
expandIndices.push_back(i);
|
expandIndices.push_back(i);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -290,8 +290,8 @@ public:
|
||||||
op, "total number of elements mismatch in the expansion");
|
op, "total number of elements mismatch in the expansion");
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult solveDynamicSize(SmallVector<int64_t> &inputShape,
|
static void solveDynamicSize(SmallVector<int64_t> &inputShape,
|
||||||
SmallVector<int64_t> &outputShape) {
|
SmallVector<int64_t> &outputShape) {
|
||||||
int64_t inputProduct = 1;
|
int64_t inputProduct = 1;
|
||||||
int64_t outputProduct = 1;
|
int64_t outputProduct = 1;
|
||||||
|
|
||||||
|
@ -316,7 +316,7 @@ public:
|
||||||
if (inputDynamicValues + outputDynamicValues == 1) {
|
if (inputDynamicValues + outputDynamicValues == 1) {
|
||||||
if (inputDynamicValues) {
|
if (inputDynamicValues) {
|
||||||
int64_t missingValue = outputProduct / inputProduct;
|
int64_t missingValue = outputProduct / inputProduct;
|
||||||
for (int i = 0; i < inputShape.size(); i++) {
|
for (size_t i = 0; i < inputShape.size(); i++) {
|
||||||
if (inputShape[i] == -1) {
|
if (inputShape[i] == -1) {
|
||||||
inputShape[i] = missingValue;
|
inputShape[i] = missingValue;
|
||||||
break;
|
break;
|
||||||
|
@ -324,7 +324,7 @@ public:
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
int64_t missingValue = inputProduct / outputProduct;
|
int64_t missingValue = inputProduct / outputProduct;
|
||||||
for (int i = 0; i < outputShape.size(); i++) {
|
for (size_t i = 0; i < outputShape.size(); i++) {
|
||||||
if (outputShape[i] == -1) {
|
if (outputShape[i] == -1) {
|
||||||
outputShape[i] = missingValue;
|
outputShape[i] = missingValue;
|
||||||
break;
|
break;
|
||||||
|
@ -332,8 +332,6 @@ public:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
|
@ -625,9 +623,6 @@ public:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t inputCount = inputAssociations.size();
|
|
||||||
int64_t outputCount = outputAssociations.size();
|
|
||||||
|
|
||||||
// Check if the shapes already match up to dynamic sizes. If so, we can just
|
// Check if the shapes already match up to dynamic sizes. If so, we can just
|
||||||
// cast as the result type because the previous loop sets up the necessary
|
// cast as the result type because the previous loop sets up the necessary
|
||||||
// dim checks in case of dynamic sizes.
|
// dim checks in case of dynamic sizes.
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
#include "../PassDetail.h"
|
#include "../PassDetail.h"
|
||||||
#include "PopulatePatterns.h"
|
#include "PopulatePatterns.h"
|
||||||
#include "Utils.h"
|
#include "Utils.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
#include "mlir/Dialect/Math/IR/Math.h"
|
#include "mlir/Dialect/Math/IR/Math.h"
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
#include "../PassDetail.h"
|
#include "../PassDetail.h"
|
||||||
#include "PopulatePatterns.h"
|
#include "PopulatePatterns.h"
|
||||||
#include "Utils.h"
|
#include "Utils.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
#include "../PassDetail.h"
|
#include "../PassDetail.h"
|
||||||
#include "PopulatePatterns.h"
|
#include "PopulatePatterns.h"
|
||||||
#include "Utils.h"
|
#include "Utils.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
#include "../PassDetail.h"
|
#include "../PassDetail.h"
|
||||||
#include "PopulatePatterns.h"
|
#include "PopulatePatterns.h"
|
||||||
#include "Utils.h"
|
#include "Utils.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
#include "../PassDetail.h"
|
#include "../PassDetail.h"
|
||||||
#include "PopulatePatterns.h"
|
#include "PopulatePatterns.h"
|
||||||
#include "Utils.h"
|
#include "Utils.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
#include "mlir/Dialect/Math/IR/Math.h"
|
#include "mlir/Dialect/Math/IR/Math.h"
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
#include "../PassDetail.h"
|
#include "../PassDetail.h"
|
||||||
#include "PopulatePatterns.h"
|
#include "PopulatePatterns.h"
|
||||||
#include "Utils.h"
|
#include "Utils.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
#include "mlir/Dialect/Math/IR/Math.h"
|
#include "mlir/Dialect/Math/IR/Math.h"
|
||||||
|
|
|
@ -11,7 +11,7 @@
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "../PassDetail.h"
|
||||||
#include "PopulatePatterns.h"
|
#include "PopulatePatterns.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
|
|
@ -11,7 +11,7 @@
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "../PassDetail.h"
|
||||||
#include "PopulatePatterns.h"
|
#include "PopulatePatterns.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
|
@ -43,7 +43,7 @@ public:
|
||||||
registry.insert<math::MathDialect>();
|
registry.insert<math::MathDialect>();
|
||||||
registry.insert<func::FuncDialect>();
|
registry.insert<func::FuncDialect>();
|
||||||
registry.insert<tensor::TensorDialect>();
|
registry.insert<tensor::TensorDialect>();
|
||||||
registry.insert<arith::ArithmeticDialect>();
|
registry.insert<arith::ArithDialect>();
|
||||||
registry.insert<cf::ControlFlowDialect>();
|
registry.insert<cf::ControlFlowDialect>();
|
||||||
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
||||||
}
|
}
|
||||||
|
@ -53,7 +53,7 @@ public:
|
||||||
ConversionTarget target(*context);
|
ConversionTarget target(*context);
|
||||||
target.addLegalDialect<linalg::LinalgDialect, func::FuncDialect,
|
target.addLegalDialect<linalg::LinalgDialect, func::FuncDialect,
|
||||||
cf::ControlFlowDialect, math::MathDialect,
|
cf::ControlFlowDialect, math::MathDialect,
|
||||||
tensor::TensorDialect, arith::ArithmeticDialect>();
|
tensor::TensorDialect, arith::ArithDialect>();
|
||||||
target.addLegalOp<TorchConversion::GetNextSeedOp>();
|
target.addLegalOp<TorchConversion::GetNextSeedOp>();
|
||||||
|
|
||||||
TypeConverter typeConverter;
|
TypeConverter typeConverter;
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
#include "../PassDetail.h"
|
#include "../PassDetail.h"
|
||||||
#include "PopulatePatterns.h"
|
#include "PopulatePatterns.h"
|
||||||
#include "Utils.h"
|
#include "Utils.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
#include "mlir/Dialect/Math/IR/Math.h"
|
#include "mlir/Dialect/Math/IR/Math.h"
|
||||||
|
|
|
@ -11,7 +11,7 @@
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "../PassDetail.h"
|
||||||
#include "PopulatePatterns.h"
|
#include "PopulatePatterns.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
#include "./PopulatePatterns.h"
|
#include "./PopulatePatterns.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
#include "mlir-hlo/utils/hlo_utils.h"
|
#include "mlir-hlo/utils/hlo_utils.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "stablehlo/dialect/ChloOps.h"
|
#include "stablehlo/dialect/ChloOps.h"
|
||||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
#include "./MhloLegalizeUtils.h"
|
#include "./MhloLegalizeUtils.h"
|
||||||
#include "./PopulatePatterns.h"
|
#include "./PopulatePatterns.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
#include "./MhloLegalizeUtils.h"
|
#include "./MhloLegalizeUtils.h"
|
||||||
#include "./PopulatePatterns.h"
|
#include "./PopulatePatterns.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "stablehlo/dialect/ChloOps.h"
|
#include "stablehlo/dialect/ChloOps.h"
|
||||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||||
|
|
|
@ -9,7 +9,7 @@
|
||||||
|
|
||||||
#include "./MhloLegalizeUtils.h"
|
#include "./MhloLegalizeUtils.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
#include "./MhloLegalizeUtils.h"
|
#include "./MhloLegalizeUtils.h"
|
||||||
#include "./PopulatePatterns.h"
|
#include "./PopulatePatterns.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "stablehlo/dialect/ChloOps.h"
|
#include "stablehlo/dialect/ChloOps.h"
|
||||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
#include "./MhloLegalizeUtils.h"
|
#include "./MhloLegalizeUtils.h"
|
||||||
#include "./PopulatePatterns.h"
|
#include "./PopulatePatterns.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
#include "../PassDetail.h"
|
#include "../PassDetail.h"
|
||||||
#include "./PopulatePatterns.h"
|
#include "./PopulatePatterns.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Dialect/Traits.h"
|
#include "mlir/Dialect/Traits.h"
|
||||||
#include "mlir/IR/Matchers.h"
|
#include "mlir/IR/Matchers.h"
|
||||||
|
@ -42,14 +42,14 @@ public:
|
||||||
registry.insert<chlo::ChloDialect>();
|
registry.insert<chlo::ChloDialect>();
|
||||||
registry.insert<mhlo::MhloDialect>();
|
registry.insert<mhlo::MhloDialect>();
|
||||||
registry.insert<tensor::TensorDialect>();
|
registry.insert<tensor::TensorDialect>();
|
||||||
registry.insert<arith::ArithmeticDialect>();
|
registry.insert<arith::ArithDialect>();
|
||||||
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
||||||
}
|
}
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
MLIRContext *context = &getContext();
|
MLIRContext *context = &getContext();
|
||||||
ConversionTarget target(*context);
|
ConversionTarget target(*context);
|
||||||
target.addLegalDialect<chlo::ChloDialect, mhlo::MhloDialect,
|
target.addLegalDialect<chlo::ChloDialect, mhlo::MhloDialect,
|
||||||
tensor::TensorDialect, arith::ArithmeticDialect>();
|
tensor::TensorDialect, arith::ArithDialect>();
|
||||||
|
|
||||||
TypeConverter typeConverter;
|
TypeConverter typeConverter;
|
||||||
typeConverter.addConversion([](Type type) { return type; });
|
typeConverter.addConversion([](Type type) { return type; });
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
#include "./MhloLegalizeUtils.h"
|
#include "./MhloLegalizeUtils.h"
|
||||||
#include "./PopulatePatterns.h"
|
#include "./PopulatePatterns.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||||
|
|
|
@ -10,7 +10,7 @@
|
||||||
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
|
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "../PassDetail.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
@ -321,7 +321,7 @@ namespace {
|
||||||
class ConvertTorchToSCF : public ConvertTorchToSCFBase<ConvertTorchToSCF> {
|
class ConvertTorchToSCF : public ConvertTorchToSCFBase<ConvertTorchToSCF> {
|
||||||
public:
|
public:
|
||||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
registry.insert<scf::SCFDialect, arith::ArithmeticDialect>();
|
registry.insert<scf::SCFDialect, arith::ArithDialect>();
|
||||||
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -329,7 +329,7 @@ public:
|
||||||
MLIRContext *context = &getContext();
|
MLIRContext *context = &getContext();
|
||||||
ConversionTarget target(*context);
|
ConversionTarget target(*context);
|
||||||
target.addLegalDialect<Torch::TorchDialect, scf::SCFDialect,
|
target.addLegalDialect<Torch::TorchDialect, scf::SCFDialect,
|
||||||
arith::ArithmeticDialect>();
|
arith::ArithDialect>();
|
||||||
|
|
||||||
TypeConverter typeConverter;
|
TypeConverter typeConverter;
|
||||||
typeConverter.addConversion([](Type type) { return type; });
|
typeConverter.addConversion([](Type type) { return type; });
|
||||||
|
|
|
@ -10,7 +10,7 @@
|
||||||
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
|
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "../PassDetail.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
@ -614,7 +614,7 @@ public:
|
||||||
registry.insert<linalg::LinalgDialect>();
|
registry.insert<linalg::LinalgDialect>();
|
||||||
registry.insert<func::FuncDialect>();
|
registry.insert<func::FuncDialect>();
|
||||||
registry.insert<tensor::TensorDialect>();
|
registry.insert<tensor::TensorDialect>();
|
||||||
registry.insert<arith::ArithmeticDialect>();
|
registry.insert<arith::ArithDialect>();
|
||||||
registry.insert<TMTensorDialect>();
|
registry.insert<TMTensorDialect>();
|
||||||
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
||||||
}
|
}
|
||||||
|
@ -623,7 +623,7 @@ public:
|
||||||
MLIRContext *context = &getContext();
|
MLIRContext *context = &getContext();
|
||||||
ConversionTarget target(*context);
|
ConversionTarget target(*context);
|
||||||
target.addLegalDialect<linalg::LinalgDialect, func::FuncDialect,
|
target.addLegalDialect<linalg::LinalgDialect, func::FuncDialect,
|
||||||
tensor::TensorDialect, arith::ArithmeticDialect,
|
tensor::TensorDialect, arith::ArithDialect,
|
||||||
Torch::TorchDialect, TMTensorDialect>();
|
Torch::TorchDialect, TMTensorDialect>();
|
||||||
|
|
||||||
TypeConverter typeConverter;
|
TypeConverter typeConverter;
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
|
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "../PassDetail.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||||
#include "mlir/Dialect/Traits.h"
|
#include "mlir/Dialect/Traits.h"
|
||||||
|
@ -3511,7 +3511,7 @@ public:
|
||||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
registry.insert<tosa::TosaDialect>();
|
registry.insert<tosa::TosaDialect>();
|
||||||
registry.insert<tensor::TensorDialect>();
|
registry.insert<tensor::TensorDialect>();
|
||||||
registry.insert<arith::ArithmeticDialect>();
|
registry.insert<arith::ArithDialect>();
|
||||||
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3519,7 +3519,7 @@ public:
|
||||||
MLIRContext *context = &getContext();
|
MLIRContext *context = &getContext();
|
||||||
ConversionTarget target(*context);
|
ConversionTarget target(*context);
|
||||||
target.addLegalDialect<tosa::TosaDialect, tensor::TensorDialect,
|
target.addLegalDialect<tosa::TosaDialect, tensor::TensorDialect,
|
||||||
arith::ArithmeticDialect>();
|
arith::ArithDialect>();
|
||||||
|
|
||||||
TypeConverter typeConverter;
|
TypeConverter typeConverter;
|
||||||
typeConverter.addConversion([](Type type) { return type; });
|
typeConverter.addConversion([](Type type) { return type; });
|
||||||
|
|
|
@ -5,7 +5,7 @@ add_mlir_conversion_library(TorchMLIRConversionUtils
|
||||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/Utils
|
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/Utils
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
MLIRArithmeticDialect
|
MLIRArithDialect
|
||||||
MLIRLinalgDialect
|
MLIRLinalgDialect
|
||||||
TorchMLIRTorchDialect
|
TorchMLIRTorchDialect
|
||||||
)
|
)
|
||||||
|
|
|
@ -9,7 +9,7 @@
|
||||||
|
|
||||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||||
|
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
|
|
@ -8,13 +8,13 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
|
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "mlir/IR/DialectImplementation.h"
|
#include "mlir/IR/DialectImplementation.h"
|
||||||
#include "mlir/Transforms/InliningUtils.h"
|
#include "mlir/Transforms/InliningUtils.h"
|
||||||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
||||||
|
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
||||||
#include "llvm/ADT/StringExtras.h"
|
#include "llvm/ADT/StringExtras.h"
|
||||||
#include "llvm/ADT/TypeSwitch.h"
|
#include "llvm/ADT/TypeSwitch.h"
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,7 @@
|
||||||
#include "PassDetail.h"
|
#include "PassDetail.h"
|
||||||
|
|
||||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
|
@ -71,8 +71,7 @@ class VerifyLinalgOnTensorsBackendContractPass
|
||||||
// Basic scalar operations.
|
// Basic scalar operations.
|
||||||
target.addDynamicallyLegalDialect<func::FuncDialect>(isLegalScalarOp);
|
target.addDynamicallyLegalDialect<func::FuncDialect>(isLegalScalarOp);
|
||||||
target.addDynamicallyLegalDialect<math::MathDialect>(isLegalScalarOp);
|
target.addDynamicallyLegalDialect<math::MathDialect>(isLegalScalarOp);
|
||||||
target.addDynamicallyLegalDialect<arith::ArithmeticDialect>(
|
target.addDynamicallyLegalDialect<arith::ArithDialect>(isLegalScalarOp);
|
||||||
isLegalScalarOp);
|
|
||||||
|
|
||||||
// Tensor operations should go through linalg and the tensor dialect.
|
// Tensor operations should go through linalg and the tensor dialect.
|
||||||
target.addDynamicallyLegalDialect<linalg::LinalgDialect>(opHasLegalTypes);
|
target.addDynamicallyLegalDialect<linalg::LinalgDialect>(opHasLegalTypes);
|
||||||
|
|
|
@ -10,7 +10,7 @@
|
||||||
#include "PassDetail.h"
|
#include "PassDetail.h"
|
||||||
|
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
@ -53,7 +53,7 @@ class VerifyMhloBackendContractPass
|
||||||
target.addLegalDialect<mhlo::MhloDialect>();
|
target.addLegalDialect<mhlo::MhloDialect>();
|
||||||
target.addLegalDialect<chlo::ChloDialect>();
|
target.addLegalDialect<chlo::ChloDialect>();
|
||||||
target.addLegalDialect<tensor::TensorDialect>();
|
target.addLegalDialect<tensor::TensorDialect>();
|
||||||
target.addLegalDialect<arith::ArithmeticDialect>();
|
target.addLegalDialect<arith::ArithDialect>();
|
||||||
|
|
||||||
RewritePatternSet patterns(context);
|
RewritePatternSet patterns(context);
|
||||||
if (failed(applyFullConversion(module, target, std::move(patterns)))) {
|
if (failed(applyFullConversion(module, target, std::move(patterns)))) {
|
||||||
|
|
|
@ -9,7 +9,7 @@
|
||||||
|
|
||||||
#include "PassDetail.h"
|
#include "PassDetail.h"
|
||||||
|
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "PassDetail.h"
|
#include "PassDetail.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||||
|
@ -304,7 +304,7 @@ class ExpandOpsForLLVM : public ExpandOpsForLLVMBase<ExpandOpsForLLVM> {
|
||||||
ConversionTarget target(*context);
|
ConversionTarget target(*context);
|
||||||
target.addLegalDialect<func::FuncDialect>();
|
target.addLegalDialect<func::FuncDialect>();
|
||||||
target.addLegalDialect<math::MathDialect>();
|
target.addLegalDialect<math::MathDialect>();
|
||||||
target.addLegalDialect<arith::ArithmeticDialect>();
|
target.addLegalDialect<arith::ArithDialect>();
|
||||||
target.addIllegalOp<math::TanhOp>();
|
target.addIllegalOp<math::TanhOp>();
|
||||||
target.addIllegalOp<math::ErfOp>();
|
target.addIllegalOp<math::ErfOp>();
|
||||||
if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
|
if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
|
||||||
|
@ -352,7 +352,7 @@ class MemrefCopyOpToLinalg : public OpRewritePattern<memref::CopyOp> {
|
||||||
LogicalResult matchAndRewrite(memref::CopyOp copyOp,
|
LogicalResult matchAndRewrite(memref::CopyOp copyOp,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Operation *linalgCopy = createLinalgCopyOp(
|
Operation *linalgCopy = createLinalgCopyOp(
|
||||||
rewriter, copyOp.getLoc(), copyOp.source(), copyOp.target());
|
rewriter, copyOp.getLoc(), copyOp.getSource(), copyOp.getTarget());
|
||||||
rewriter.replaceOp(copyOp, linalgCopy->getResults());
|
rewriter.replaceOp(copyOp, linalgCopy->getResults());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue