Bump llvm-project to 444822d77a7fea28aa49edf24533c987efa1b2ee

Fixes:
- renames StandardTypes -> BuiltinTypes
- std.extract_element -> tensor.extract
pull/140/head
Sean Silva 2020-12-11 14:43:38 -08:00
parent 251aa6e435
commit b2077738ca
37 changed files with 62 additions and 53 deletions

@ -1 +1 @@
Subproject commit 774f1d3ffd458d6cb82d5039758ef1cf6370957f Subproject commit 444822d77a7fea28aa49edf24533c987efa1b2ee

View File

@ -9,8 +9,8 @@
#include "debug.h" #include "debug.h"
#include "mlir_utils.h" #include "mlir_utils.h"
#include "mlir-c/StandardAttributes.h" #include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/StandardTypes.h" #include "mlir-c/BuiltinTypes.h"
#include "npcomp-c/Types.h" #include "npcomp-c/Types.h"
#include "npcomp/Python/PybindUtils.h" #include "npcomp/Python/PybindUtils.h"

View File

@ -7,9 +7,9 @@
#include "func_builder.h" #include "func_builder.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h" #include "mlir-c/Diagnostics.h"
#include "mlir-c/StandardAttributes.h"
#include "mlir-c/StandardTypes.h"
#include "npcomp-c/Types.h" #include "npcomp-c/Types.h"
using namespace torch_mlir; using namespace torch_mlir;
@ -32,7 +32,8 @@ KernelCallBuilder::KernelCallBuilder(MlirContext context, MlirLocation loc,
(void)this->context; // Preserve for future. (void)this->context; // Preserve for future.
MlirNamedAttribute kernelNameAttr = mlirNamedAttributeGet( MlirNamedAttribute kernelNameAttr = mlirNamedAttributeGet(
toMlirStringRef("kernelName"), toMlirStringRef("kernelName"),
mlirStringAttrGet(context, kernelName.size(), kernelName.data())); mlirStringAttrGet(
context, mlirStringRefCreate(kernelName.data(), kernelName.size())));
mlirOperationStateAddAttributes(state, 1, &kernelNameAttr); mlirOperationStateAddAttributes(state, 1, &kernelNameAttr);
addSchemaAttrs(); addSchemaAttrs();
} }
@ -59,7 +60,8 @@ void KernelCallBuilder::addSchemaAttrs() {
llvm::SmallVector<MlirAttribute, 4> args; llvm::SmallVector<MlirAttribute, 4> args;
for (auto &arg : schema.arguments()) { for (auto &arg : schema.arguments()) {
const std::string &typeStr = arg.type()->str(); const std::string &typeStr = arg.type()->str();
args.push_back(mlirStringAttrGet(context, typeStr.size(), typeStr.data())); args.push_back(mlirStringAttrGet(
context, mlirStringRefCreate(typeStr.data(), typeStr.size())));
} }
attrs.push_back(mlirNamedAttributeGet( attrs.push_back(mlirNamedAttributeGet(
toMlirStringRef("sigArgTypes"), toMlirStringRef("sigArgTypes"),
@ -69,8 +71,8 @@ void KernelCallBuilder::addSchemaAttrs() {
llvm::SmallVector<MlirAttribute, 4> returns; llvm::SmallVector<MlirAttribute, 4> returns;
for (auto &ret : schema.returns()) { for (auto &ret : schema.returns()) {
const std::string &typeStr = ret.type()->str(); const std::string &typeStr = ret.type()->str();
returns.push_back( returns.push_back(mlirStringAttrGet(
mlirStringAttrGet(context, typeStr.size(), typeStr.data())); context, mlirStringRefCreate(typeStr.data(), typeStr.size())));
} }
attrs.push_back(mlirNamedAttributeGet( attrs.push_back(mlirNamedAttributeGet(
toMlirStringRef("sigRetTypes"), toMlirStringRef("sigRetTypes"),
@ -215,7 +217,8 @@ FuncBuilder::createFunction(FuncBuilder::Inserter &inserter,
/*numResults=*/0, /*results=*/nullptr)))); /*numResults=*/0, /*results=*/nullptr))));
funcAttrs.push_back(mlirNamedAttributeGet( funcAttrs.push_back(mlirNamedAttributeGet(
toMlirStringRef("sym_name"), toMlirStringRef("sym_name"),
mlirStringAttrGet(context, name.size(), name.data()))); mlirStringAttrGet(context,
mlirStringRefCreate(name.data(), name.size()))));
MlirOperationState state = MlirOperationState state =
mlirOperationStateGet(toMlirStringRef("func"), location); mlirOperationStateGet(toMlirStringRef("func"), location);

View File

@ -9,9 +9,9 @@
#include "mlir_utils.h" #include "mlir_utils.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h" #include "mlir-c/Diagnostics.h"
#include "mlir-c/StandardAttributes.h"
#include "mlir-c/StandardTypes.h"
namespace py = pybind11; namespace py = pybind11;
using namespace torch_mlir; using namespace torch_mlir;

View File

@ -10,9 +10,9 @@
#include "graph_importer.h" #include "graph_importer.h"
#include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Registration.h" #include "mlir-c/Registration.h"
#include "mlir-c/StandardAttributes.h"
#include "mlir-c/StandardTypes.h"
#include "npcomp-c/Registration.h" #include "npcomp-c/Registration.h"
namespace py = pybind11; namespace py = pybind11;

View File

@ -17,8 +17,8 @@
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Verifier.h" #include "mlir/IR/Verifier.h"
#include "mlir/Parser.h" #include "mlir/Parser.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"

View File

@ -10,11 +10,11 @@
#define NPCOMP_DIALECT_ATEN_IR_DIALECT_H #define NPCOMP_DIALECT_ATEN_IR_DIALECT_H
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h" #include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeSupport.h" #include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h" #include "mlir/IR/Types.h"
#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h"

View File

@ -11,7 +11,7 @@
#include "npcomp/Dialect/ATen/IR/ATenDialect.h" #include "npcomp/Dialect/ATen/IR/ATenDialect.h"
#include "mlir/IR/StandardTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h" #include "mlir/IR/Types.h"
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"

View File

@ -11,7 +11,7 @@
#include <string> #include <string>
#include "mlir/IR/Module.h" #include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
namespace mlir { namespace mlir {

View File

@ -11,11 +11,11 @@
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionSupport.h" #include "mlir/IR/FunctionSupport.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h" #include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/SymbolTable.h" #include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h"

View File

@ -9,8 +9,8 @@
#ifndef NPCOMP_DIALECT_NUMPY_IR_NUMPY_DIALECT_H #ifndef NPCOMP_DIALECT_NUMPY_IR_NUMPY_DIALECT_H
#define NPCOMP_DIALECT_NUMPY_IR_NUMPY_DIALECT_H #define NPCOMP_DIALECT_NUMPY_IR_NUMPY_DIALECT_H
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "mlir/IR/StandardTypes.h"
#include "npcomp/Typing/Analysis/CPA/Interfaces.h" #include "npcomp/Typing/Analysis/CPA/Interfaces.h"
namespace mlir { namespace mlir {

View File

@ -10,10 +10,10 @@
#define NPCOMP_DIALECT_NUMPY_IR_NUMPY_OPS_H #define NPCOMP_DIALECT_NUMPY_IR_NUMPY_OPS_H
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionSupport.h" #include "mlir/IR/FunctionSupport.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/SymbolTable.h" #include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h"
#include "npcomp/Typing/Analysis/CPA/Interfaces.h" #include "npcomp/Typing/Analysis/CPA/Interfaces.h"

View File

@ -10,9 +10,9 @@
#define NPCOMP_DIALECT_REFBACK_IR_REFBACKOPS_H #define NPCOMP_DIALECT_REFBACK_IR_REFBACKOPS_H
#include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h" #include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/SymbolTable.h" #include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h"

View File

@ -30,7 +30,7 @@ def Refback_AllocMemRefOp : Refback_Op<"alloc_memref", []> {
Allocates a memref of the given shape. Allocates a memref of the given shape.
This op is a convenience for creating a bunch of This op is a convenience for creating a bunch of
extract_element ops + std.alloc. tensor.extract ops + std.alloc.
}]; }];
let arguments = (ins Shape_ExtentTensorType:$shape); let arguments = (ins Shape_ExtentTensorType:$shape);
let results = (outs AnyMemRef:$memref); let results = (outs AnyMemRef:$memref);

View File

@ -9,9 +9,9 @@
#ifndef NPCOMP_DIALECT_REFBACKRT_IR_REFBACKRTOPS_H #ifndef NPCOMP_DIALECT_REFBACKRT_IR_REFBACKRTOPS_H
#define NPCOMP_DIALECT_REFBACKRT_IR_REFBACKRTOPS_H #define NPCOMP_DIALECT_REFBACKRT_IR_REFBACKRTOPS_H
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h" #include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/SymbolTable.h" #include "mlir/IR/SymbolTable.h"
#define GET_OP_CLASSES #define GET_OP_CLASSES

View File

@ -9,9 +9,9 @@
#ifndef NPCOMP_DIALECT_TCF_IR_TCFOPS_H #ifndef NPCOMP_DIALECT_TCF_IR_TCFOPS_H
#define NPCOMP_DIALECT_TCF_IR_TCFOPS_H #define NPCOMP_DIALECT_TCF_IR_TCFOPS_H
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h" #include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "npcomp/Dialect/TCF/IR/TCFOps.h.inc" #include "npcomp/Dialect/TCF/IR/TCFOps.h.inc"

View File

@ -10,9 +10,9 @@
#define NPCOMP_DIALECT_TCP_IR_TCPOPS_H #define NPCOMP_DIALECT_TCP_IR_TCPOPS_H
#include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h" #include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/SymbolTable.h" #include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h"

View File

@ -9,9 +9,9 @@
#ifndef NPCOMP_DIALECT_TORCH_IR_TORCHOPS_H #ifndef NPCOMP_DIALECT_TORCH_IR_TORCHOPS_H
#define NPCOMP_DIALECT_TORCH_IR_TORCHOPS_H #define NPCOMP_DIALECT_TORCH_IR_TORCHOPS_H
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h" #include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "npcomp/Dialect/Torch/IR/OpInterfaces.h" #include "npcomp/Dialect/Torch/IR/OpInterfaces.h"
#define GET_OP_CLASSES #define GET_OP_CLASSES

View File

@ -13,9 +13,9 @@
#include "mlir/IR/Block.h" #include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Identifier.h" #include "mlir/IR/Identifier.h"
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h" #include "mlir/IR/Operation.h"
#include "mlir/IR/Region.h" #include "mlir/IR/Region.h"
#include "mlir/IR/SymbolTable.h" #include "mlir/IR/SymbolTable.h"

View File

@ -10,7 +10,7 @@
#define NPCOMP_JITRUNTIME_JITMODULE_H #define NPCOMP_JITRUNTIME_JITMODULE_H
#include "mlir/ExecutionEngine/ExecutionEngine.h" #include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/IR/Module.h" #include "mlir/IR/BuiltinOps.h"
#include "npcomp/RefBackend/Runtime/UserAPI.h" #include "npcomp/RefBackend/Runtime/UserAPI.h"
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"

View File

@ -29,6 +29,7 @@ def LowerToRefbackrtABI : Pass<"lower-to-refbackrt-abi", "ModuleOp"> {
def LowerAllocMemRefOps : Pass<"lower-alloc-memref-ops", "FuncOp"> { def LowerAllocMemRefOps : Pass<"lower-alloc-memref-ops", "FuncOp"> {
let summary = "Lower AllocMemRefOp's"; let summary = "Lower AllocMemRefOp's";
let constructor = "mlir::NPCOMP::createLowerAllocMemRefOpsPass()"; let constructor = "mlir::NPCOMP::createLowerAllocMemRefOpsPass()";
let dependentDialects = ["tensor::TensorDialect"];
} }
def LowerToLLVM : Pass<"refback-lower-to-llvm", "ModuleOp"> { def LowerToLLVM : Pass<"refback-lower-to-llvm", "ModuleOp"> {

View File

@ -9,7 +9,7 @@
#ifndef NPCOMP_TYPING_SUPPORT_CPA_IR_HELPERS_H #ifndef NPCOMP_TYPING_SUPPORT_CPA_IR_HELPERS_H
#define NPCOMP_TYPING_SUPPORT_CPA_IR_HELPERS_H #define NPCOMP_TYPING_SUPPORT_CPA_IR_HELPERS_H
#include "mlir/IR/StandardTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "npcomp/Typing/Analysis/CPA/Types.h" #include "npcomp/Typing/Analysis/CPA/Types.h"
namespace mlir { namespace mlir {

View File

@ -9,7 +9,7 @@
#include "npcomp-c/Types.h" #include "npcomp-c/Types.h"
#include "mlir/CAPI/IR.h" #include "mlir/CAPI/IR.h"
#include "mlir/IR/StandardTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h" #include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h" #include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"

View File

@ -11,7 +11,7 @@
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
#include "mlir/IR/StandardTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h" #include "mlir/IR/Types.h"
#include <iostream> #include <iostream>

View File

@ -23,8 +23,8 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/EDSC/Builders.h" #include "mlir/EDSC/Builders.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OperationSupport.h" #include "mlir/IR/OperationSupport.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Parser.h" #include "mlir/Parser.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"

View File

@ -14,7 +14,7 @@
#include "mlir/Analysis/Liveness.h" #include "mlir/Analysis/Liveness.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/Module.h" #include "mlir/IR/BuiltinOps.h"
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
#include "llvm/Support/JSON.h" #include "llvm/Support/JSON.h"

View File

@ -8,7 +8,7 @@
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h" #include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/FunctionImplementation.h" #include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/OpImplementation.h" #include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"

View File

@ -10,7 +10,7 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/Module.h" #include "mlir/IR/BuiltinOps.h"
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h" #include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
#include "npcomp/Dialect/Numpy/IR/NumpyOps.h" #include "npcomp/Dialect/Numpy/IR/NumpyOps.h"
#include "npcomp/Dialect/Numpy/Transforms/Passes.h" #include "npcomp/Dialect/Numpy/Transforms/Passes.h"

View File

@ -8,7 +8,7 @@
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOps.h" #include "npcomp/Dialect/Refbackrt/IR/RefbackrtOps.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/SymbolTable.h" #include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/TypeUtilities.h"
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtDialect.h" #include "npcomp/Dialect/Refbackrt/IR/RefbackrtDialect.h"

View File

@ -10,7 +10,7 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/Module.h" #include "mlir/IR/BuiltinOps.h"
#include "npcomp/Dialect/TCF/IR/TCFDialect.h" #include "npcomp/Dialect/TCF/IR/TCFDialect.h"
#include "npcomp/Dialect/TCF/IR/TCFOps.h" #include "npcomp/Dialect/TCF/IR/TCFOps.h"
#include "npcomp/Dialect/TCF/Transforms/Passes.h" #include "npcomp/Dialect/TCF/Transforms/Passes.h"

View File

@ -11,8 +11,9 @@
#include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/Module.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/Bufferize.h" #include "mlir/Transforms/Bufferize.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "npcomp/Dialect/Refback/IR/RefbackDialect.h" #include "npcomp/Dialect/Refback/IR/RefbackDialect.h"
@ -85,7 +86,7 @@ public:
SmallVector<Value, 6> outputExtents; SmallVector<Value, 6> outputExtents;
for (int i = 0, e = resultType.getRank(); i < e; i++) { for (int i = 0, e = resultType.getRank(); i < e; i++) {
Value dimIndex = rewriter.create<ConstantIndexOp>(op.getLoc(), i); Value dimIndex = rewriter.create<ConstantIndexOp>(op.getLoc(), i);
Value outputExtent = rewriter.create<ExtractElementOp>( Value outputExtent = rewriter.create<tensor::ExtractOp>(
op.getLoc(), resultShape, ValueRange({dimIndex})); op.getLoc(), resultShape, ValueRange({dimIndex}));
outputExtents.push_back(outputExtent); outputExtents.push_back(outputExtent);
} }
@ -188,6 +189,7 @@ class TCPBufferizePass : public TCPBufferizeBase<TCPBufferizePass> {
target.addLegalDialect<linalg::LinalgDialect>(); target.addLegalDialect<linalg::LinalgDialect>();
target.addLegalDialect<StandardOpsDialect>(); target.addLegalDialect<StandardOpsDialect>();
target.addLegalDialect<scf::SCFDialect>(); target.addLegalDialect<scf::SCFDialect>();
target.addLegalDialect<tensor::TensorDialect>();
if (failed(applyPartialConversion(func, target, std::move(patterns)))) if (failed(applyPartialConversion(func, target, std::move(patterns))))
return signalPassFailure(); return signalPassFailure();

View File

@ -12,10 +12,10 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Location.h" #include "mlir/IR/Location.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Parser.h" #include "mlir/Parser.h"
#include "llvm/ADT/TypeSwitch.h" #include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/SourceMgr.h" #include "llvm/Support/SourceMgr.h"

View File

@ -10,7 +10,7 @@
#include "npcomp/RefBackend/RefBackend.h" #include "npcomp/RefBackend/RefBackend.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/StandardTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Verifier.h" #include "mlir/IR/Verifier.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"

View File

@ -9,6 +9,7 @@
#ifndef REFBACKEND_PASSDETAIL_H #ifndef REFBACKEND_PASSDETAIL_H
#define REFBACKEND_PASSDETAIL_H #define REFBACKEND_PASSDETAIL_H
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
namespace mlir { namespace mlir {

View File

@ -34,6 +34,8 @@
#include "mlir/Dialect/SCF/Passes.h" #include "mlir/Dialect/SCF/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h" #include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h" #include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
@ -97,7 +99,7 @@ public:
for (int i = 0, e = memrefType.getRank(); i < e; i++) { for (int i = 0, e = memrefType.getRank(); i < e; i++) {
if (memrefType.isDynamicDim(i)) { if (memrefType.isDynamicDim(i)) {
auto ci = rewriter.create<ConstantIndexOp>(op.getLoc(), i); auto ci = rewriter.create<ConstantIndexOp>(op.getLoc(), i);
auto extent = rewriter.create<ExtractElementOp>(op.getLoc(), shape, auto extent = rewriter.create<tensor::ExtractOp>(op.getLoc(), shape,
ValueRange({ci})); ValueRange({ci}));
dynamicExtents.push_back(extent); dynamicExtents.push_back(extent);
} }
@ -119,7 +121,7 @@ class LowerAllocMemRefOps
patterns.insert<LowerAllocMemRefOp>(context); patterns.insert<LowerAllocMemRefOp>(context);
ConversionTarget target(*context); ConversionTarget target(*context);
target.addIllegalOp<refback::AllocMemRefOp>(); target.addIllegalOp<refback::AllocMemRefOp>();
target.addLegalOp<ExtractElementOp>(); target.addLegalOp<tensor::ExtractOp>();
target.addLegalOp<AllocOp>(); target.addLegalOp<AllocOp>();
target.addLegalOp<ConstantOp>(); target.addLegalOp<ConstantOp>();
if (failed(applyPartialConversion(func, target, std::move(patterns)))) { if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
@ -247,6 +249,7 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline(
pm.addNestedPass<FuncOp>(createSCFBufferizePass()); pm.addNestedPass<FuncOp>(createSCFBufferizePass());
pm.addNestedPass<FuncOp>(createLinalgBufferizePass()); pm.addNestedPass<FuncOp>(createLinalgBufferizePass());
pm.addNestedPass<FuncOp>(createStdBufferizePass()); pm.addNestedPass<FuncOp>(createStdBufferizePass());
pm.addNestedPass<FuncOp>(createTensorBufferizePass());
pm.addPass(createFuncBufferizePass()); pm.addPass(createFuncBufferizePass());
pm.addNestedPass<FuncOp>(createFinalizingBufferizePass()); pm.addNestedPass<FuncOp>(createFinalizingBufferizePass());

View File

@ -3,7 +3,7 @@
// CHECK-LABEL: func @basic // CHECK-LABEL: func @basic
func @basic(%arg0: tensor<?xindex>) -> memref<?xf32> { func @basic(%arg0: tensor<?xindex>) -> memref<?xf32> {
// CHECK: %[[I:.*]] = constant 0 : index // CHECK: %[[I:.*]] = constant 0 : index
// CHECK: %[[E:.*]] = extract_element %arg0[%[[I]]] // CHECK: %[[E:.*]] = tensor.extract %arg0[%[[I]]]
// CHECK: alloc(%[[E]]) // CHECK: alloc(%[[E]])
%0 = refback.alloc_memref %arg0 : memref<?xf32> %0 = refback.alloc_memref %arg0 : memref<?xf32>
return %0 : memref<?xf32> return %0 : memref<?xf32>
@ -12,7 +12,7 @@ func @basic(%arg0: tensor<?xindex>) -> memref<?xf32> {
// ----- // -----
// CHECK-LABEL: func @all_static // CHECK-LABEL: func @all_static
func @all_static(%arg0: tensor<?xindex>) -> memref<3x4x5xf32> { func @all_static(%arg0: tensor<?xindex>) -> memref<3x4x5xf32> {
// CHECK-NOT: extract_element // CHECK-NOT: tensor.extract
// CHECK: alloc() // CHECK: alloc()
%0 = refback.alloc_memref %arg0 : memref<3x4x5xf32> %0 = refback.alloc_memref %arg0 : memref<3x4x5xf32>
return %0 : memref<3x4x5xf32> return %0 : memref<3x4x5xf32>
@ -22,9 +22,9 @@ func @all_static(%arg0: tensor<?xindex>) -> memref<3x4x5xf32> {
// CHECK-LABEL: func @some_static // CHECK-LABEL: func @some_static
func @some_static(%arg0: tensor<?xindex>) -> memref<3x?x5x?x7xf32> { func @some_static(%arg0: tensor<?xindex>) -> memref<3x?x5x?x7xf32> {
// CHECK-DAG: %[[I1:.*]] = constant 1 : index // CHECK-DAG: %[[I1:.*]] = constant 1 : index
// CHECK-DAG: %[[E1:.*]] = extract_element %arg0[%[[I1]]] // CHECK-DAG: %[[E1:.*]] = tensor.extract %arg0[%[[I1]]]
// CHECK-DAG: %[[I3:.*]] = constant 3 : index // CHECK-DAG: %[[I3:.*]] = constant 3 : index
// CHECK-DAG: %[[E3:.*]] = extract_element %arg0[%[[I3]]] // CHECK-DAG: %[[E3:.*]] = tensor.extract %arg0[%[[I3]]]
// CHECK: alloc(%[[E1]], %[[E3]]) // CHECK: alloc(%[[E1]], %[[E3]])
%0 = refback.alloc_memref %arg0 : memref<3x?x5x?x7xf32> %0 = refback.alloc_memref %arg0 : memref<3x?x5x?x7xf32>
return %0 : memref<3x?x5x?x7xf32> return %0 : memref<3x?x5x?x7xf32>

View File

@ -14,7 +14,7 @@ func @pow2(%arg0: tensor<f32>) -> tensor<f32> {
// TODO: Allow passing plain integers/floats (not tensors) at // TODO: Allow passing plain integers/floats (not tensors) at
// calling convention boundaries. // calling convention boundaries.
%num_iters_float = extract_element %arg0[] : tensor<f32> %num_iters_float = tensor.extract %arg0[] : tensor<f32>
%num_iters_i32 = fptosi %num_iters_float : f32 to i32 %num_iters_i32 = fptosi %num_iters_float : f32 to i32
%num_iters = index_cast %num_iters_i32 : i32 to index %num_iters = index_cast %num_iters_i32 : i32 to index
@ -26,4 +26,3 @@ func @pow2(%arg0: tensor<f32>) -> tensor<f32> {
} }
return %ret : tensor<f32> return %ret : tensor<f32>
} }