mirror of https://github.com/llvm/torch-mlir
Format sources.
parent
de38caa547
commit
fc4f374345
|
@ -14,7 +14,7 @@
|
|||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
#include "npcomp/Dialect/ATen/ATenOpInterfaces.h.inc"
|
||||
} // namespace aten
|
||||
} // namespace NPCOMP
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
|
|
|
@ -11,9 +11,9 @@
|
|||
|
||||
#include "npcomp/Dialect/ATen/ATenDialect.h"
|
||||
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#define DEBUG_TYPE "aten-op-stats"
|
||||
|
||||
|
@ -25,7 +25,8 @@ namespace NPCOMP {
|
|||
namespace aten {
|
||||
|
||||
// Return the op statistics for conv2d-like operations.
|
||||
template <class T> std::map<std::string, uint64_t> getConv2dStatistics(T *o, uint64_t groups) {
|
||||
template <class T>
|
||||
std::map<std::string, uint64_t> getConv2dStatistics(T *o, uint64_t groups) {
|
||||
|
||||
std::map<std::string, uint64_t> toReturn;
|
||||
|
||||
|
@ -66,11 +67,13 @@ template <class T> std::map<std::string, uint64_t> getConv2dStatistics(T *o, uin
|
|||
}
|
||||
|
||||
// Return the op statistics for conv2dBackward-like operations.
|
||||
template<typename T>
|
||||
std::map<std::string, uint64_t> getConv2dBackwardStatistics(T op, uint64_t groups) {
|
||||
template <typename T>
|
||||
std::map<std::string, uint64_t> getConv2dBackwardStatistics(T op,
|
||||
uint64_t groups) {
|
||||
|
||||
std::map<std::string, uint64_t> toReturn;
|
||||
TensorType dx_out_resultTy = op.getResult(0).getType().template cast<TensorType>();
|
||||
TensorType dx_out_resultTy =
|
||||
op.getResult(0).getType().template cast<TensorType>();
|
||||
uint64_t dx_out_volume = getTensorVolume(dx_out_resultTy);
|
||||
|
||||
TensorType weightTy = op.getOperand(2).getType().template cast<TensorType>();
|
||||
|
@ -79,7 +82,7 @@ std::map<std::string, uint64_t> getConv2dBackwardStatistics(T op, uint64_t group
|
|||
uint64_t kernel_width = weightTy.getShape()[2];
|
||||
uint64_t kernel_height = weightTy.getShape()[3];
|
||||
|
||||
uint64_t MACs_per_loss =
|
||||
uint64_t MACs_per_loss =
|
||||
(loss_in_depth / groups) * kernel_height * kernel_width;
|
||||
|
||||
uint64_t total_MACs = dx_out_volume * MACs_per_loss;
|
||||
|
@ -119,7 +122,6 @@ std::map<std::string, uint64_t> getConv2dBackwardStatistics(T op, uint64_t group
|
|||
return toReturn;
|
||||
}
|
||||
|
||||
|
||||
// Return a model of the number of bytes needed to represent the operand of
|
||||
// the given convolution-like operation with the given index. The shape is
|
||||
// assumed to be in NCHW order with a simple tiled model of data reuse. TODO:
|
||||
|
@ -222,8 +224,7 @@ uint64_t getConv2dResultTransferVolume(T *o, unsigned int idx, bool write) {
|
|||
}
|
||||
|
||||
// Return the op statistics for matrixmultiply-like operations.
|
||||
template<typename T>
|
||||
std::map<std::string, uint64_t> getMMOpStatistics(T op) {
|
||||
template <typename T> std::map<std::string, uint64_t> getMMOpStatistics(T op) {
|
||||
|
||||
std::map<std::string, uint64_t> toReturn;
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ namespace aten {
|
|||
// #define GEN_PASS_CLASSES
|
||||
// #include "npcomp/Dialect/ATen/ATenPasses.h.inc"
|
||||
|
||||
void registerATenPasses();
|
||||
void registerATenPasses();
|
||||
} // namespace aten
|
||||
} // namespace NPCOMP
|
||||
} // namespace mlir
|
||||
|
|
|
@ -44,33 +44,36 @@ void mlir::npcomp::python::defineBackendIREEModule(py::module m) {
|
|||
);
|
||||
});
|
||||
|
||||
m.def("build_flow_transform_pass_pipeline",
|
||||
[](PyPassManager &pm) {
|
||||
mlir::iree_compiler::IREE::Flow::buildFlowTransformPassPipeline(
|
||||
pm.passManager);
|
||||
},
|
||||
py::arg("pm"),
|
||||
py::doc("Builds a pass pipeline for top-level Flow import"));
|
||||
m.def("build_hal_transform_pass_pipeline",
|
||||
[](PyPassManager &pm, std::vector<std::string> targetBackends) {
|
||||
mlir::iree_compiler::IREE::HAL::TargetOptions options;
|
||||
if (targetBackends.empty()) {
|
||||
options.targets =
|
||||
mlir::iree_compiler::IREE::HAL::getRegisteredTargetBackends();
|
||||
} else {
|
||||
options.targets = std::move(targetBackends);
|
||||
}
|
||||
iree_compiler::IREE::HAL::buildHALTransformPassPipeline(
|
||||
pm.passManager, options);
|
||||
},
|
||||
py::arg("pm"), py::arg("target_backends") = std::vector<std::string>(),
|
||||
py::doc("Builds a pass pipeline for top-level Flow import"));
|
||||
m.def("build_vm_transform_pass_pipeline",
|
||||
[](PyPassManager &pm) {
|
||||
mlir::iree_compiler::IREE::VM::buildVMTransformPassPipeline(
|
||||
pm.passManager);
|
||||
},
|
||||
py::arg("pm"), py::doc("Builds the VM transformation pipeline"));
|
||||
m.def(
|
||||
"build_flow_transform_pass_pipeline",
|
||||
[](PyPassManager &pm) {
|
||||
mlir::iree_compiler::IREE::Flow::buildFlowTransformPassPipeline(
|
||||
pm.passManager);
|
||||
},
|
||||
py::arg("pm"),
|
||||
py::doc("Builds a pass pipeline for top-level Flow import"));
|
||||
m.def(
|
||||
"build_hal_transform_pass_pipeline",
|
||||
[](PyPassManager &pm, std::vector<std::string> targetBackends) {
|
||||
mlir::iree_compiler::IREE::HAL::TargetOptions options;
|
||||
if (targetBackends.empty()) {
|
||||
options.targets =
|
||||
mlir::iree_compiler::IREE::HAL::getRegisteredTargetBackends();
|
||||
} else {
|
||||
options.targets = std::move(targetBackends);
|
||||
}
|
||||
iree_compiler::IREE::HAL::buildHALTransformPassPipeline(pm.passManager,
|
||||
options);
|
||||
},
|
||||
py::arg("pm"), py::arg("target_backends") = std::vector<std::string>(),
|
||||
py::doc("Builds a pass pipeline for top-level Flow import"));
|
||||
m.def(
|
||||
"build_vm_transform_pass_pipeline",
|
||||
[](PyPassManager &pm) {
|
||||
mlir::iree_compiler::IREE::VM::buildVMTransformPassPipeline(
|
||||
pm.passManager);
|
||||
},
|
||||
py::arg("pm"), py::doc("Builds the VM transformation pipeline"));
|
||||
m.def("translate_to_vm_bytecode", [](PyModuleOp &module) {
|
||||
// TODO: Make the options parameterizable.
|
||||
mlir::iree_compiler::IREE::VM::BytecodeTargetOptions options;
|
||||
|
|
|
@ -89,38 +89,39 @@ void npcomp::python::defineBackendRefJitModule(py::module m) {
|
|||
JITModule::buildBackendCompilationPipeline(pm.passManager);
|
||||
});
|
||||
py::class_<JITModule>(m, "JITModule")
|
||||
.def_static("from_compiled_module",
|
||||
[](PyModuleOp module, std::vector<std::string> pySharedLibs)
|
||||
-> std::unique_ptr<JITModule> {
|
||||
SmallVector<StringRef, 4> sharedLibs(pySharedLibs.begin(),
|
||||
pySharedLibs.end());
|
||||
auto jitModule =
|
||||
checkError(JITModule::fromCompiledModule(
|
||||
module.moduleOp, sharedLibs),
|
||||
"error creating JITModule: ");
|
||||
return jitModule;
|
||||
},
|
||||
py::arg("module"), py::arg("shared_libs"))
|
||||
.def("invoke",
|
||||
[](JITModule &self, std::string functionName,
|
||||
std::vector<py::buffer> inputs) {
|
||||
// Prepare inputs.
|
||||
llvm::SmallVector<Ref<Tensor>, 4> inputTensors;
|
||||
inputTensors.reserve(inputs.size());
|
||||
for (py::buffer &inputBuffer : inputs) {
|
||||
inputTensors.push_back(copyBufferToTensor(inputBuffer));
|
||||
}
|
||||
.def_static(
|
||||
"from_compiled_module",
|
||||
[](PyModuleOp module, std::vector<std::string> pySharedLibs)
|
||||
-> std::unique_ptr<JITModule> {
|
||||
SmallVector<StringRef, 4> sharedLibs(pySharedLibs.begin(),
|
||||
pySharedLibs.end());
|
||||
auto jitModule = checkError(
|
||||
JITModule::fromCompiledModule(module.moduleOp, sharedLibs),
|
||||
"error creating JITModule: ");
|
||||
return jitModule;
|
||||
},
|
||||
py::arg("module"), py::arg("shared_libs"))
|
||||
.def(
|
||||
"invoke",
|
||||
[](JITModule &self, std::string functionName,
|
||||
std::vector<py::buffer> inputs) {
|
||||
// Prepare inputs.
|
||||
llvm::SmallVector<Ref<Tensor>, 4> inputTensors;
|
||||
inputTensors.reserve(inputs.size());
|
||||
for (py::buffer &inputBuffer : inputs) {
|
||||
inputTensors.push_back(copyBufferToTensor(inputBuffer));
|
||||
}
|
||||
|
||||
auto outputs = checkError(self.invoke(functionName, inputTensors),
|
||||
"error invoking JIT function: ");
|
||||
std::vector<py::array> outputArrays;
|
||||
outputArrays.reserve(outputs.size());
|
||||
for (Ref<Tensor> &outputTensor : outputs) {
|
||||
outputArrays.push_back(wrapTensorAsArray(outputTensor));
|
||||
}
|
||||
return outputArrays;
|
||||
},
|
||||
py::arg("function_name"), py::arg("inputs"));
|
||||
auto outputs = checkError(self.invoke(functionName, inputTensors),
|
||||
"error invoking JIT function: ");
|
||||
std::vector<py::array> outputArrays;
|
||||
outputArrays.reserve(outputs.size());
|
||||
for (Ref<Tensor> &outputTensor : outputs) {
|
||||
outputArrays.push_back(wrapTensorAsArray(outputTensor));
|
||||
}
|
||||
return outputArrays;
|
||||
},
|
||||
py::arg("function_name"), py::arg("inputs"));
|
||||
|
||||
// A Ref<Tensor> needs to be bound because we use it as a base for the
|
||||
// ndarray (the array retains a reference to it). Users should not encounter
|
||||
|
|
|
@ -55,7 +55,7 @@ std::map<std::string, uint64_t> AdaptiveAvgPool2dBackwardOp::getStatistics() {
|
|||
return toReturn;
|
||||
}
|
||||
|
||||
// add
|
||||
// add
|
||||
std::map<std::string, uint64_t> AddOp::getStatistics() {
|
||||
|
||||
std::map<std::string, uint64_t> toReturn;
|
||||
|
@ -222,7 +222,8 @@ uint64_t ConvolutionOp::getResultTransferVolume(unsigned int idx, bool write) {
|
|||
std::map<std::string, uint64_t> ConvolutionBackwardOp::getStatistics() {
|
||||
return getConv2dBackwardStatistics(*this, 1);
|
||||
}
|
||||
std::map<std::string, uint64_t> ConvolutionBackwardOverrideableOp::getStatistics() {
|
||||
std::map<std::string, uint64_t>
|
||||
ConvolutionBackwardOverrideableOp::getStatistics() {
|
||||
auto co = cast<mlir::NPCOMP::aten::ConstantOp>(groups().getDefiningOp());
|
||||
auto ia = co.template getAttrOfType<IntegerAttr>("value");
|
||||
uint64_t groups = ia.getValue().getZExtValue();
|
||||
|
@ -463,7 +464,7 @@ std::map<std::string, uint64_t> MeanOp::getStatistics() {
|
|||
// getMMOpStatistics(*this);
|
||||
// }
|
||||
std::map<std::string, uint64_t> MmOp::getStatistics() {
|
||||
return getMMOpStatistics(*this );
|
||||
return getMMOpStatistics(*this);
|
||||
}
|
||||
|
||||
// mul
|
||||
|
|
|
@ -61,7 +61,8 @@ namespace {
|
|||
static Value typeCast(PatternRewriter &builder, Value val, Type destTy) {
|
||||
if (val.getType() == destTy)
|
||||
return val;
|
||||
return builder.create<mlir::NPCOMP::aten::TypeCastOp>(val.getLoc(), destTy, val)
|
||||
return builder
|
||||
.create<mlir::NPCOMP::aten::TypeCastOp>(val.getLoc(), destTy, val)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
|
@ -69,11 +70,11 @@ static Value typeCast(PatternRewriter &builder, Value val, Type destTy) {
|
|||
/// unknown shape.
|
||||
static MemRefType getShapeErasedMemRefType(MemRefType type) {
|
||||
std::vector<int64_t> shape = type.getShape();
|
||||
for(int i = 0; i < shape.size(); i++) {
|
||||
for (int i = 0; i < shape.size(); i++) {
|
||||
shape[i] = -1;
|
||||
}
|
||||
return MemRefType::get(shape, type.getElementType(),
|
||||
type.getAffineMaps(), type.getMemorySpace());
|
||||
return MemRefType::get(shape, type.getElementType(), type.getAffineMaps(),
|
||||
type.getMemorySpace());
|
||||
}
|
||||
|
||||
/// Create a type cast to memref
|
||||
|
@ -82,14 +83,12 @@ static Value memRefTypeCast(PatternRewriter &builder, Value val) {
|
|||
|
||||
if (auto memrefTy = type.dyn_cast<MemRefType>()) {
|
||||
MemRefType newType = getShapeErasedMemRefType(memrefTy);
|
||||
return builder.create<MemRefCastOp>(val.getLoc(),
|
||||
val, newType)
|
||||
.getResult();
|
||||
return builder.create<MemRefCastOp>(val.getLoc(), val, newType).getResult();
|
||||
}
|
||||
if (auto tensorTy = type.dyn_cast<TensorType>()) {
|
||||
auto memRefType = mlir::MemRefType::get(tensorTy.getShape(),
|
||||
tensorTy.getElementType(), {}, 0);
|
||||
return typeCast(builder, val, memRefType);
|
||||
auto memRefType = mlir::MemRefType::get(tensorTy.getShape(),
|
||||
tensorTy.getElementType(), {}, 0);
|
||||
return typeCast(builder, val, memRefType);
|
||||
}
|
||||
return val;
|
||||
}
|
||||
|
@ -186,7 +185,7 @@ static std::string getSimplyMangledFuncName(std::string prefix,
|
|||
ret = ret + sep + getSimplyMangledType(t);
|
||||
for (const Type t : operTy) {
|
||||
std::string s = getSimplyMangledType(t);
|
||||
if(s.size() > 0)
|
||||
if (s.size() > 0)
|
||||
ret = ret + sep + getSimplyMangledType(t);
|
||||
}
|
||||
ret += "_out";
|
||||
|
@ -194,25 +193,22 @@ static std::string getSimplyMangledFuncName(std::string prefix,
|
|||
return ret;
|
||||
}
|
||||
static std::string getSimplyMangledFuncName(std::string prefix,
|
||||
FunctionType fnTy) {
|
||||
FunctionType fnTy) {
|
||||
|
||||
return getSimplyMangledFuncName(prefix, fnTy.getInputs(), fnTy.getResults());
|
||||
}
|
||||
|
||||
std::string getMangledFuncName(std::string prefix,
|
||||
FunctionType fnTy) {
|
||||
std::string getMangledFuncName(std::string prefix, FunctionType fnTy) {
|
||||
return getSimplyMangledFuncName(prefix, fnTy);
|
||||
}
|
||||
|
||||
std::string getMangledFuncName(std::string prefix,
|
||||
ArrayRef<Type> opTys,
|
||||
std::string getMangledFuncName(std::string prefix, ArrayRef<Type> opTys,
|
||||
ArrayRef<Type> retTys) {
|
||||
return getSimplyMangledFuncName(prefix, opTys, retTys);
|
||||
}
|
||||
|
||||
static FuncOp getATenFn(ModuleOp module, std::string mangledFunctionName,
|
||||
ArrayRef<Value> operands,
|
||||
ArrayRef<Type> retTys) {
|
||||
ArrayRef<Value> operands, ArrayRef<Type> retTys) {
|
||||
Builder builder(module);
|
||||
|
||||
SmallVector<Type, 8> tys;
|
||||
|
@ -242,8 +238,8 @@ static FuncOp getATenFn(ModuleOp module, std::string mangledFunctionName,
|
|||
class AddOpConversion_affine : public ConversionPattern {
|
||||
public:
|
||||
explicit AddOpConversion_affine(MLIRContext *context)
|
||||
: ConversionPattern(mlir::NPCOMP::aten::AddOp::getOperationName(), 1, context) {
|
||||
}
|
||||
: ConversionPattern(mlir::NPCOMP::aten::AddOp::getOperationName(), 1,
|
||||
context) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
|
@ -310,78 +306,72 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
|
||||
// Replace the given operation with a call to the given function.
|
||||
// The function is assumed to accept memrefs and scalar types and return
|
||||
// Memrefs. Here the result types are converted back to the result types of op,
|
||||
// but operands are NOT converted. This allows non-standard mappings from
|
||||
// operand types to function types.
|
||||
LogicalResult
|
||||
rewriteWithVoidFunctionCallExplicit(Operation *op,
|
||||
ArrayRef<Value> callops,
|
||||
ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
std::string functionName) {
|
||||
LogicalResult rewriteWithVoidFunctionCallExplicit(
|
||||
Operation *op, ArrayRef<Value> callops, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter, std::string functionName) {
|
||||
|
||||
auto loc = op->getLoc();
|
||||
edsc::ScopedContext scope(rewriter, loc);
|
||||
auto loc = op->getLoc();
|
||||
edsc::ScopedContext scope(rewriter, loc);
|
||||
|
||||
// The original operation types.
|
||||
SmallVector<Type, 8> opTys;
|
||||
// Shape erased versions of the original operation types.
|
||||
SmallVector<Type, 8> erasedOpTys;
|
||||
for (const Value &o: callops) {
|
||||
Type t = o.getType();
|
||||
opTys.push_back(t);
|
||||
if (t.isa<MemRefType>())
|
||||
erasedOpTys.push_back(getShapeErasedMemRefType(t.cast<MemRefType>()));
|
||||
else
|
||||
erasedOpTys.push_back(t);
|
||||
// The original operation types.
|
||||
SmallVector<Type, 8> opTys;
|
||||
// Shape erased versions of the original operation types.
|
||||
SmallVector<Type, 8> erasedOpTys;
|
||||
for (const Value &o : callops) {
|
||||
Type t = o.getType();
|
||||
opTys.push_back(t);
|
||||
if (t.isa<MemRefType>())
|
||||
erasedOpTys.push_back(getShapeErasedMemRefType(t.cast<MemRefType>()));
|
||||
else
|
||||
erasedOpTys.push_back(t);
|
||||
}
|
||||
|
||||
std::vector<Value> newOps = callops;
|
||||
SmallVector<Value, 8> newResults;
|
||||
|
||||
// Result types of the original operation, converted to memrefs.
|
||||
SmallVector<Type, 8> retTys;
|
||||
// Erased version of the return type. This is the return types of the
|
||||
// generated function call.
|
||||
SmallVector<Type, 8> erasedRetTys;
|
||||
for (const auto &o : op->getResults()) {
|
||||
Type t = o.getType();
|
||||
if (t.isa<TensorType>()) {
|
||||
TensorType tensorResultTy = t.cast<TensorType>();
|
||||
MemRefType memRefResultTy = mlir::MemRefType::get(
|
||||
tensorResultTy.getShape(), tensorResultTy.getElementType(), {}, 0);
|
||||
MemRefType erasedMemRefResultTy =
|
||||
getShapeErasedMemRefType(memRefResultTy);
|
||||
retTys.push_back(memRefResultTy);
|
||||
|
||||
// assume memRefResultTy has known shape, so we don't need any
|
||||
// dynamic dimensions for the alloc.
|
||||
assert(memRefResultTy.hasStaticShape());
|
||||
Value allocVal = rewriter.create<AllocOp>(op->getLoc(), memRefResultTy);
|
||||
Value castVal = memRefTypeCast(rewriter, allocVal);
|
||||
newOps.push_back(castVal);
|
||||
newResults.push_back(allocVal);
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<Value> newOps = callops;
|
||||
SmallVector<Value, 8> newResults;
|
||||
SmallVector<Type, 8> empty;
|
||||
std::string mangledFunctionName =
|
||||
getMangledFuncName(functionName, opTys, retTys);
|
||||
FuncOp funcOp = getATenFn(op->getParentOfType<ModuleOp>(),
|
||||
mangledFunctionName, newOps, empty);
|
||||
|
||||
// Result types of the original operation, converted to memrefs.
|
||||
SmallVector<Type, 8> retTys;
|
||||
// Erased version of the return type. This is the return types of the
|
||||
// generated function call.
|
||||
SmallVector<Type, 8> erasedRetTys;
|
||||
for (const auto &o: op->getResults()) {
|
||||
Type t = o.getType();
|
||||
if (t.isa<TensorType>()) {
|
||||
TensorType tensorResultTy = t.cast<TensorType>();
|
||||
MemRefType memRefResultTy =
|
||||
mlir::MemRefType::get(tensorResultTy.getShape(),
|
||||
tensorResultTy.getElementType(), {}, 0);
|
||||
MemRefType erasedMemRefResultTy = getShapeErasedMemRefType(memRefResultTy);
|
||||
retTys.push_back(memRefResultTy);
|
||||
auto new_call =
|
||||
callOperation(empty, rewriter.getSymbolRefAttr(funcOp), newOps);
|
||||
|
||||
// assume memRefResultTy has known shape, so we don't need any
|
||||
// dynamic dimensions for the alloc.
|
||||
assert(memRefResultTy.hasStaticShape());
|
||||
Value allocVal = rewriter.create<AllocOp>(op->getLoc(),
|
||||
memRefResultTy);
|
||||
Value castVal = memRefTypeCast(rewriter, allocVal);
|
||||
newOps.push_back(castVal);
|
||||
newResults.push_back(allocVal);
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<Type, 8> empty;
|
||||
std::string mangledFunctionName = getMangledFuncName(functionName, opTys, retTys);
|
||||
FuncOp funcOp = getATenFn(op->getParentOfType<ModuleOp>(),
|
||||
mangledFunctionName,
|
||||
newOps,
|
||||
empty);
|
||||
|
||||
auto new_call = callOperation(empty,
|
||||
rewriter.getSymbolRefAttr(funcOp), newOps);
|
||||
|
||||
rewriter.replaceOp(op, newResults);
|
||||
return success();
|
||||
rewriter.replaceOp(op, newResults);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Replace the given operation with a call to the given function.
|
||||
|
@ -389,54 +379,53 @@ rewriteWithVoidFunctionCallExplicit(Operation *op,
|
|||
// Memrefs. Other operand types (e.g. aten.list and tensor<> are converted
|
||||
// appropriately. The called function passes results of the original function
|
||||
// as memref arguments at the end of the original set of operands.
|
||||
LogicalResult
|
||||
rewriteWithFunctionCall(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
std::string functionName) {
|
||||
auto loc = op->getLoc();
|
||||
edsc::ScopedContext scope(rewriter, loc);
|
||||
LogicalResult rewriteWithFunctionCall(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
std::string functionName) {
|
||||
auto loc = op->getLoc();
|
||||
edsc::ScopedContext scope(rewriter, loc);
|
||||
|
||||
// Convert the arguments to the original call.
|
||||
SmallVector<Value, 8> callops;
|
||||
for (auto &o: operands) {
|
||||
Type t = o.getType();
|
||||
if (t.isa<MemRefType>()) {
|
||||
// Cast it to some memref type that we accept
|
||||
callops.push_back(memRefTypeCast(rewriter, o));
|
||||
} else if (t.isa<IntegerType>() || t.isa<FloatType>()) {
|
||||
callops.push_back(o);
|
||||
} else if (t.isa<ATenListType>()) {
|
||||
// FIXME: lots of assumptions here.
|
||||
auto unpack = [](auto &op, auto &v) -> void {
|
||||
auto co = cast<mlir::NPCOMP::aten::ConstantOp>(op.getDefiningOp());
|
||||
DenseElementsAttr a =
|
||||
co.template getAttrOfType<DenseElementsAttr>("value");
|
||||
for (auto i : a.getIntValues())
|
||||
v.push_back(i.getSExtValue());
|
||||
};
|
||||
std::vector<uint64_t> values;
|
||||
unpack(o, values);
|
||||
callops.push_back(constInt(values[0], 32));
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
// Convert the arguments to the original call.
|
||||
SmallVector<Value, 8> callops;
|
||||
for (auto &o : operands) {
|
||||
Type t = o.getType();
|
||||
if (t.isa<MemRefType>()) {
|
||||
// Cast it to some memref type that we accept
|
||||
callops.push_back(memRefTypeCast(rewriter, o));
|
||||
} else if (t.isa<IntegerType>() || t.isa<FloatType>()) {
|
||||
callops.push_back(o);
|
||||
} else if (t.isa<ATenListType>()) {
|
||||
// FIXME: lots of assumptions here.
|
||||
auto unpack = [](auto &op, auto &v) -> void {
|
||||
auto co = cast<mlir::NPCOMP::aten::ConstantOp>(op.getDefiningOp());
|
||||
DenseElementsAttr a =
|
||||
co.template getAttrOfType<DenseElementsAttr>("value");
|
||||
for (auto i : a.getIntValues())
|
||||
v.push_back(i.getSExtValue());
|
||||
};
|
||||
std::vector<uint64_t> values;
|
||||
unpack(o, values);
|
||||
callops.push_back(constInt(values[0], 32));
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
return rewriteWithVoidFunctionCallExplicit(op, callops, operands, rewriter, functionName);
|
||||
}
|
||||
return rewriteWithVoidFunctionCallExplicit(op, callops, operands, rewriter,
|
||||
functionName);
|
||||
}
|
||||
|
||||
|
||||
/// Lower Add
|
||||
template<typename Op>
|
||||
template <typename Op>
|
||||
class ATenFunctionCallConversion : public ConversionPattern {
|
||||
public:
|
||||
explicit ATenFunctionCallConversion(MLIRContext *context)
|
||||
: ConversionPattern(Op::getOperationName(), 1, context) {
|
||||
}
|
||||
: ConversionPattern(Op::getOperationName(), 1, context) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
return rewriteWithFunctionCall(op, operands, rewriter, Op::getFunctionConversionName());
|
||||
return rewriteWithFunctionCall(op, operands, rewriter,
|
||||
Op::getFunctionConversionName());
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -444,8 +433,8 @@ public:
|
|||
class ConstantOpConversion : public ConversionPattern {
|
||||
public:
|
||||
explicit ConstantOpConversion(MLIRContext *context)
|
||||
: ConversionPattern(mlir::NPCOMP::aten::ConstantOp::getOperationName(), 1, context) {
|
||||
}
|
||||
: ConversionPattern(mlir::NPCOMP::aten::ConstantOp::getOperationName(), 1,
|
||||
context) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
|
@ -459,14 +448,15 @@ public:
|
|||
Type t = result.getType();
|
||||
if (t.isa<IntegerType>()) {
|
||||
auto it = t.cast<IntegerType>();
|
||||
if(it.getWidth() > 1) {
|
||||
if (it.getWidth() > 1) {
|
||||
auto a = op->getAttrOfType<IntegerAttr>("value");
|
||||
SmallVector<Value, 8> newValues {rewriter.create<mlir::ConstantOp>(loc, a)};
|
||||
SmallVector<Value, 8> newValues{
|
||||
rewriter.create<mlir::ConstantOp>(loc, a)};
|
||||
rewriter.replaceOp(op, newValues);
|
||||
return success();
|
||||
} else {
|
||||
auto a = op->getAttrOfType<BoolAttr>("value");
|
||||
SmallVector<Value, 8> newValues {constInt(a.getValue(), it.getWidth())};
|
||||
SmallVector<Value, 8> newValues{constInt(a.getValue(), it.getWidth())};
|
||||
rewriter.replaceOp(op, newValues);
|
||||
return success();
|
||||
}
|
||||
|
@ -485,8 +475,8 @@ public:
|
|||
class AddOpConversion : public ConversionPattern {
|
||||
public:
|
||||
explicit AddOpConversion(MLIRContext *context)
|
||||
: ConversionPattern(mlir::NPCOMP::aten::AddOp::getOperationName(), 1, context) {
|
||||
}
|
||||
: ConversionPattern(mlir::NPCOMP::aten::AddOp::getOperationName(), 1,
|
||||
context) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
|
@ -513,8 +503,8 @@ public:
|
|||
class AsStridedOpConversion : public ConversionPattern {
|
||||
public:
|
||||
explicit AsStridedOpConversion(MLIRContext *context)
|
||||
: ConversionPattern(mlir::NPCOMP::aten::AsStridedOp::getOperationName(), 1,
|
||||
context) {}
|
||||
: ConversionPattern(mlir::NPCOMP::aten::AsStridedOp::getOperationName(),
|
||||
1, context) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
|
@ -527,7 +517,8 @@ public:
|
|||
// construct the shape argument
|
||||
std::vector<Value> shape;
|
||||
std::vector<int64_t> result_shape;
|
||||
auto co0 = cast<mlir::NPCOMP::aten::ConstantOp>(operands[1].getDefiningOp());
|
||||
auto co0 =
|
||||
cast<mlir::NPCOMP::aten::ConstantOp>(operands[1].getDefiningOp());
|
||||
DenseElementsAttr a0 =
|
||||
co0.template getAttrOfType<DenseElementsAttr>("value");
|
||||
for (auto i : a0.getAttributeValues())
|
||||
|
@ -539,7 +530,8 @@ public:
|
|||
|
||||
// construct the stride argument
|
||||
std::vector<Value> stride;
|
||||
auto co1 = cast<mlir::NPCOMP::aten::ConstantOp>(operands[2].getDefiningOp());
|
||||
auto co1 =
|
||||
cast<mlir::NPCOMP::aten::ConstantOp>(operands[2].getDefiningOp());
|
||||
DenseElementsAttr a1 =
|
||||
co1.template getAttrOfType<DenseElementsAttr>("value");
|
||||
for (auto i : a1.getAttributeValues())
|
||||
|
@ -551,19 +543,21 @@ public:
|
|||
|
||||
APInt offset(32, 0);
|
||||
if (operands.size() > 3) {
|
||||
auto co2 = cast<mlir::NPCOMP::aten::ConstantOp>(operands[3].getDefiningOp());
|
||||
auto co2 =
|
||||
cast<mlir::NPCOMP::aten::ConstantOp>(operands[3].getDefiningOp());
|
||||
auto ia2 = co2.getAttrOfType<IntegerAttr>("value");
|
||||
offset = ia2.getValue();
|
||||
}
|
||||
|
||||
SmallVector<Value, 8> callops{xVal, shape[0],
|
||||
shape[1], shape[2],
|
||||
shape[3], stride[0],
|
||||
stride[1], stride[2],
|
||||
stride[3], constInt(offset.getSExtValue(), 32)};
|
||||
SmallVector<Value, 8> callops{
|
||||
xVal, shape[0],
|
||||
shape[1], shape[2],
|
||||
shape[3], stride[0],
|
||||
stride[1], stride[2],
|
||||
stride[3], constInt(offset.getSExtValue(), 32)};
|
||||
|
||||
|
||||
return rewriteWithVoidFunctionCallExplicit(op, callops, operands, rewriter, "as_strided");
|
||||
return rewriteWithVoidFunctionCallExplicit(op, callops, operands, rewriter,
|
||||
"as_strided");
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -571,8 +565,8 @@ public:
|
|||
class BatchNormOpConversion : public ConversionPattern {
|
||||
public:
|
||||
explicit BatchNormOpConversion(MLIRContext *context)
|
||||
: ConversionPattern(mlir::NPCOMP::aten::BatchNormOp::getOperationName(), 1,
|
||||
context) {}
|
||||
: ConversionPattern(mlir::NPCOMP::aten::BatchNormOp::getOperationName(),
|
||||
1, context) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
|
@ -585,8 +579,8 @@ public:
|
|||
class ConvolutionOpConversion : public ConversionPattern {
|
||||
public:
|
||||
explicit ConvolutionOpConversion(MLIRContext *context)
|
||||
: ConversionPattern(mlir::NPCOMP::aten::ConvolutionOp::getOperationName(), 1,
|
||||
context) {}
|
||||
: ConversionPattern(mlir::NPCOMP::aten::ConvolutionOp::getOperationName(),
|
||||
1, context) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
|
@ -614,8 +608,8 @@ public:
|
|||
class DivOpConversion : public ConversionPattern {
|
||||
public:
|
||||
explicit DivOpConversion(MLIRContext *context)
|
||||
: ConversionPattern(mlir::NPCOMP::aten::DivOp::getOperationName(), 1, context) {
|
||||
}
|
||||
: ConversionPattern(mlir::NPCOMP::aten::DivOp::getOperationName(), 1,
|
||||
context) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
|
@ -627,8 +621,8 @@ public:
|
|||
class LogSoftmaxOpConversion : public ConversionPattern {
|
||||
public:
|
||||
explicit LogSoftmaxOpConversion(MLIRContext *context)
|
||||
: ConversionPattern(mlir::NPCOMP::aten::LogSoftmaxOp::getOperationName(), 1,
|
||||
context) {}
|
||||
: ConversionPattern(mlir::NPCOMP::aten::LogSoftmaxOp::getOperationName(),
|
||||
1, context) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
|
@ -657,8 +651,8 @@ public:
|
|||
class MaxPoolOpConversion : public ConversionPattern {
|
||||
public:
|
||||
explicit MaxPoolOpConversion(MLIRContext *context)
|
||||
: ConversionPattern(mlir::NPCOMP::aten::MaxPool2dOp::getOperationName(), 1,
|
||||
context) {}
|
||||
: ConversionPattern(mlir::NPCOMP::aten::MaxPool2dOp::getOperationName(),
|
||||
1, context) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
|
@ -678,7 +672,8 @@ public:
|
|||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
return rewriteWithFunctionCall(op, operands, rewriter, "max_pool2d_with_indices");
|
||||
return rewriteWithFunctionCall(op, operands, rewriter,
|
||||
"max_pool2d_with_indices");
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -686,14 +681,15 @@ public:
|
|||
class MaxPool2dWithIndicesBackwardOpConversion : public ConversionPattern {
|
||||
public:
|
||||
explicit MaxPool2dWithIndicesBackwardOpConversion(MLIRContext *context)
|
||||
: ConversionPattern(
|
||||
mlir::NPCOMP::aten::MaxPool2dWithIndicesBackwardOp::getOperationName(), 1,
|
||||
context) {}
|
||||
: ConversionPattern(mlir::NPCOMP::aten::MaxPool2dWithIndicesBackwardOp::
|
||||
getOperationName(),
|
||||
1, context) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
return rewriteWithFunctionCall(op, operands, rewriter, "max_pool2d_with_indices_backward");
|
||||
return rewriteWithFunctionCall(op, operands, rewriter,
|
||||
"max_pool2d_with_indices_backward");
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -701,7 +697,8 @@ public:
|
|||
class MMOpConversion : public ConversionPattern {
|
||||
public:
|
||||
explicit MMOpConversion(MLIRContext *context)
|
||||
: ConversionPattern(mlir::NPCOMP::aten::MmOp::getOperationName(), 1, context) {}
|
||||
: ConversionPattern(mlir::NPCOMP::aten::MmOp::getOperationName(), 1,
|
||||
context) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
|
@ -714,8 +711,8 @@ public:
|
|||
class MulOpConversion : public ConversionPattern {
|
||||
public:
|
||||
explicit MulOpConversion(MLIRContext *context)
|
||||
: ConversionPattern(mlir::NPCOMP::aten::MulOp::getOperationName(), 1, context) {
|
||||
}
|
||||
: ConversionPattern(mlir::NPCOMP::aten::MulOp::getOperationName(), 1,
|
||||
context) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
|
@ -728,8 +725,9 @@ public:
|
|||
class NativeBatchNormOpConversion : public ConversionPattern {
|
||||
public:
|
||||
explicit NativeBatchNormOpConversion(MLIRContext *context)
|
||||
: ConversionPattern(mlir::NPCOMP::aten::NativeBatchNormOp::getOperationName(),
|
||||
1, context) {}
|
||||
: ConversionPattern(
|
||||
mlir::NPCOMP::aten::NativeBatchNormOp::getOperationName(), 1,
|
||||
context) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
|
@ -742,13 +740,15 @@ public:
|
|||
class NllLoss2dBackwardOpConversion : public ConversionPattern {
|
||||
public:
|
||||
explicit NllLoss2dBackwardOpConversion(MLIRContext *context)
|
||||
: ConversionPattern(mlir::NPCOMP::aten::NllLoss2dBackwardOp::getOperationName(),
|
||||
1, context) {}
|
||||
: ConversionPattern(
|
||||
mlir::NPCOMP::aten::NllLoss2dBackwardOp::getOperationName(), 1,
|
||||
context) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
return rewriteWithFunctionCall(op, operands, rewriter, "nll_loss2d_backward");
|
||||
return rewriteWithFunctionCall(op, operands, rewriter,
|
||||
"nll_loss2d_backward");
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -756,13 +756,15 @@ public:
|
|||
class NllLoss2dForwardOpConversion : public ConversionPattern {
|
||||
public:
|
||||
explicit NllLoss2dForwardOpConversion(MLIRContext *context)
|
||||
: ConversionPattern(mlir::NPCOMP::aten::NllLoss2dForwardOp::getOperationName(),
|
||||
1, context) {}
|
||||
: ConversionPattern(
|
||||
mlir::NPCOMP::aten::NllLoss2dForwardOp::getOperationName(), 1,
|
||||
context) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
return rewriteWithFunctionCall(op, operands, rewriter, "nll_loss2d_forward");
|
||||
return rewriteWithFunctionCall(op, operands, rewriter,
|
||||
"nll_loss2d_forward");
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -770,8 +772,9 @@ public:
|
|||
class NllLossBackwardOpConversion : public ConversionPattern {
|
||||
public:
|
||||
explicit NllLossBackwardOpConversion(MLIRContext *context)
|
||||
: ConversionPattern(mlir::NPCOMP::aten::NllLossBackwardOp::getOperationName(),
|
||||
1, context) {}
|
||||
: ConversionPattern(
|
||||
mlir::NPCOMP::aten::NllLossBackwardOp::getOperationName(), 1,
|
||||
context) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
|
@ -784,13 +787,15 @@ public:
|
|||
class NllLossForwardOpConversion : public ConversionPattern {
|
||||
public:
|
||||
explicit NllLossForwardOpConversion(MLIRContext *context)
|
||||
: ConversionPattern(mlir::NPCOMP::aten::NllLossForwardOp::getOperationName(), 1,
|
||||
context) {}
|
||||
: ConversionPattern(
|
||||
mlir::NPCOMP::aten::NllLossForwardOp::getOperationName(), 1,
|
||||
context) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
return rewriteWithFunctionCall(op, operands, rewriter, "nll_loss_forward"); }
|
||||
return rewriteWithFunctionCall(op, operands, rewriter, "nll_loss_forward");
|
||||
}
|
||||
};
|
||||
|
||||
/// Lower ReLU
|
||||
|
@ -811,13 +816,15 @@ public:
|
|||
class ThresholdBackwardOpConversion : public ConversionPattern {
|
||||
public:
|
||||
explicit ThresholdBackwardOpConversion(MLIRContext *context)
|
||||
: ConversionPattern(mlir::NPCOMP::aten::ThresholdBackwardOp::getOperationName(),
|
||||
1, context) {}
|
||||
: ConversionPattern(
|
||||
mlir::NPCOMP::aten::ThresholdBackwardOp::getOperationName(), 1,
|
||||
context) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
return rewriteWithFunctionCall(op, operands, rewriter, "threshold_backward");
|
||||
return rewriteWithFunctionCall(op, operands, rewriter,
|
||||
"threshold_backward");
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -825,7 +832,8 @@ public:
|
|||
class TransposeOpConversion : public ConversionPattern {
|
||||
public:
|
||||
explicit TransposeOpConversion(MLIRContext *context)
|
||||
: ConversionPattern(mlir::NPCOMP::aten::TOp::getOperationName(), 1, context) {}
|
||||
: ConversionPattern(mlir::NPCOMP::aten::TOp::getOperationName(), 1,
|
||||
context) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
|
@ -849,21 +857,22 @@ public:
|
|||
|
||||
Value xVal = memRefTypeCast(rewriter, operands[0]);
|
||||
|
||||
// construct the shape argument
|
||||
// construct the shape argument
|
||||
SmallVector<Value, 8> shape;
|
||||
auto co = dyn_cast<mlir::NPCOMP::aten::ConstantOp>(operands[1].getDefiningOp());
|
||||
auto co =
|
||||
dyn_cast<mlir::NPCOMP::aten::ConstantOp>(operands[1].getDefiningOp());
|
||||
DenseElementsAttr a = co.template getAttrOfType<DenseElementsAttr>("value");
|
||||
for (auto i : a.getAttributeValues())
|
||||
shape.push_back(rewriter.create<mlir::ConstantOp>(co.getLoc(), i));
|
||||
|
||||
|
||||
// pad out the shape with -1 to make it 4d
|
||||
while (shape.size() < 4)
|
||||
shape.push_back(constInt(-1, 32));
|
||||
|
||||
SmallVector<Value, 8> callops{xVal, shape[0], shape[1], shape[2], shape[3]};
|
||||
|
||||
return rewriteWithVoidFunctionCallExplicit(op, callops, operands, rewriter, "view");
|
||||
return rewriteWithVoidFunctionCallExplicit(op, callops, operands, rewriter,
|
||||
"view");
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -896,9 +905,8 @@ struct ATenLoweringPass
|
|||
|
||||
// c++ patterns
|
||||
acapPatterns.insert<
|
||||
ConstantOpConversion,
|
||||
AddOpConversion, ConvolutionOpConversion, ReLUOpConversion,
|
||||
TransposeOpConversion, BatchNormOpConversion,
|
||||
ConstantOpConversion, AddOpConversion, ConvolutionOpConversion,
|
||||
ReLUOpConversion, TransposeOpConversion, BatchNormOpConversion,
|
||||
NativeBatchNormOpConversion, MaxPoolOpConversion,
|
||||
MaxPool2dWithIndicesOpConversion, AddmmOpConversion, ViewOpConversion,
|
||||
MulOpConversion, MMOpConversion, AsStridedOpConversion,
|
||||
|
|
|
@ -7,8 +7,8 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/Dialect/ATen/ATenToStd.h"
|
||||
#include "npcomp/Dialect/ATen/ATenDialect.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "npcomp/Dialect/ATen/ATenDialect.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
|
|
|
@ -115,8 +115,8 @@ std::string LivenessReport::emitJSONReport() {
|
|||
for (auto v : vlist) {
|
||||
int64_t vol = getTensorVolume(v.getType());
|
||||
if (v.getDefiningOp()) {
|
||||
if (auto a = v.getDefiningOp()->getAttrOfType<StringAttr>(
|
||||
"layer_name")) {
|
||||
if (auto a =
|
||||
v.getDefiningOp()->getAttrOfType<StringAttr>("layer_name")) {
|
||||
auto definingOp = v.getDefiningOp();
|
||||
auto ld = layerDetail.getInteger(a.getValue().str());
|
||||
if (ld)
|
||||
|
|
|
@ -32,28 +32,29 @@ public:
|
|||
auto op = opBuilder.create<scf::YieldOp>(loc, yields);
|
||||
return op.getOperation();
|
||||
})
|
||||
.def("scf_if_op",
|
||||
[](ScfDialectHelper &self, std::vector<PyType> pyResultTypes,
|
||||
PyValue cond, bool withElseRegion) {
|
||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||
Location loc = self.pyOpBuilder.getCurrentLoc();
|
||||
llvm::SmallVector<Type, 4> resultTypes(pyResultTypes.begin(),
|
||||
pyResultTypes.end());
|
||||
auto op = opBuilder.create<scf::IfOp>(loc, resultTypes, cond,
|
||||
withElseRegion);
|
||||
if (withElseRegion) {
|
||||
return py::make_tuple(
|
||||
PyOperationRef(op),
|
||||
op.getThenBodyBuilder().saveInsertionPoint(),
|
||||
op.getElseBodyBuilder().saveInsertionPoint());
|
||||
} else {
|
||||
return py::make_tuple(
|
||||
PyOperationRef(op),
|
||||
op.getThenBodyBuilder().saveInsertionPoint());
|
||||
}
|
||||
},
|
||||
py::arg("result_types"), py::arg("cond"),
|
||||
py::arg("with_else_region") = false);
|
||||
.def(
|
||||
"scf_if_op",
|
||||
[](ScfDialectHelper &self, std::vector<PyType> pyResultTypes,
|
||||
PyValue cond, bool withElseRegion) {
|
||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||
Location loc = self.pyOpBuilder.getCurrentLoc();
|
||||
llvm::SmallVector<Type, 4> resultTypes(pyResultTypes.begin(),
|
||||
pyResultTypes.end());
|
||||
auto op = opBuilder.create<scf::IfOp>(loc, resultTypes, cond,
|
||||
withElseRegion);
|
||||
if (withElseRegion) {
|
||||
return py::make_tuple(
|
||||
PyOperationRef(op),
|
||||
op.getThenBodyBuilder().saveInsertionPoint(),
|
||||
op.getElseBodyBuilder().saveInsertionPoint());
|
||||
} else {
|
||||
return py::make_tuple(
|
||||
PyOperationRef(op),
|
||||
op.getThenBodyBuilder().saveInsertionPoint());
|
||||
}
|
||||
},
|
||||
py::arg("result_types"), py::arg("cond"),
|
||||
py::arg("with_else_region") = false);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -65,13 +65,14 @@ void PyIpListWrapper<ListTy, ItemWrapperTy>::bind(py::module m,
|
|||
"front",
|
||||
[](ThisTy &self) { return ItemWrapperTy(self.list.front()); })
|
||||
.def("__len__", [](ThisTy &self) { return self.list.size(); })
|
||||
.def("__iter__",
|
||||
[](ThisTy &self) {
|
||||
PyItemIterator begin(self.list.begin());
|
||||
PyItemIterator end(self.list.end());
|
||||
return py::make_iterator(begin, end);
|
||||
},
|
||||
py::keep_alive<0, 1>());
|
||||
.def(
|
||||
"__iter__",
|
||||
[](ThisTy &self) {
|
||||
PyItemIterator begin(self.list.begin());
|
||||
PyItemIterator end(self.list.end());
|
||||
return py::make_iterator(begin, end);
|
||||
},
|
||||
py::keep_alive<0, 1>());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -194,79 +195,81 @@ void PyDialectHelper::bind(py::module m) {
|
|||
[](PyDialectHelper &self) -> std::shared_ptr<PyContext> {
|
||||
return self.context.shared_from_this();
|
||||
})
|
||||
.def("op",
|
||||
[](PyDialectHelper &self, const std::string &opNameStr,
|
||||
std::vector<PyType> pyResultTypes,
|
||||
std::vector<PyValue> pyOperands,
|
||||
llvm::Optional<PyAttribute> attrs) -> PyOperationRef {
|
||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(false);
|
||||
Location loc = self.pyOpBuilder.getCurrentLoc();
|
||||
OperationName opName(opNameStr, opBuilder.getContext());
|
||||
SmallVector<Type, 4> types(pyResultTypes.begin(),
|
||||
pyResultTypes.end());
|
||||
SmallVector<Value, 4> operands(pyOperands.begin(),
|
||||
pyOperands.end());
|
||||
MutableDictionaryAttr attrList;
|
||||
if (attrs) {
|
||||
auto dictAttrs = attrs->attr.dyn_cast<DictionaryAttr>();
|
||||
if (!dictAttrs) {
|
||||
throw py::raiseValueError(
|
||||
"Expected `attrs` to be a DictionaryAttr");
|
||||
}
|
||||
attrList = MutableDictionaryAttr(dictAttrs);
|
||||
}
|
||||
Operation *op =
|
||||
Operation::create(loc, opName, types, operands, attrList);
|
||||
opBuilder.insert(op);
|
||||
return op;
|
||||
},
|
||||
py::arg("op_name"), py::arg("result_types"), py::arg("operands"),
|
||||
py::arg("attrs") = llvm::Optional<PyAttribute>())
|
||||
.def("func_op",
|
||||
[](PyDialectHelper &self, const std::string &name, PyType type,
|
||||
bool createEntryBlock, llvm::Optional<PyAttribute> attrs) {
|
||||
auto functionType = type.type.dyn_cast_or_null<FunctionType>();
|
||||
if (!functionType) {
|
||||
throw py::raiseValueError("Illegal function type");
|
||||
}
|
||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||
Location loc = self.pyOpBuilder.getCurrentLoc();
|
||||
// TODO: Dedup attr creation from op().
|
||||
MutableDictionaryAttr attrList;
|
||||
if (attrs) {
|
||||
auto dictAttrs = attrs->attr.dyn_cast<DictionaryAttr>();
|
||||
if (!dictAttrs) {
|
||||
throw py::raiseValueError(
|
||||
"Expected `attrs` to be a DictionaryAttr");
|
||||
}
|
||||
attrList = MutableDictionaryAttr(dictAttrs);
|
||||
}
|
||||
FuncOp op =
|
||||
opBuilder.create<FuncOp>(loc, StringRef(name), functionType,
|
||||
/*attrs=*/attrList.getAttrs());
|
||||
if (createEntryBlock) {
|
||||
Block *entryBlock = new Block();
|
||||
entryBlock->addArguments(functionType.getInputs());
|
||||
op.getBody().push_back(entryBlock);
|
||||
opBuilder.setInsertionPointToStart(entryBlock);
|
||||
}
|
||||
return PyOperationRef(op);
|
||||
},
|
||||
py::arg("name"), py::arg("type"),
|
||||
py::arg("create_entry_block") = false,
|
||||
py::arg("attrs") = llvm::Optional<PyAttribute>(),
|
||||
R"(Creates a new `func` op, optionally creating an entry block.
|
||||
.def(
|
||||
"op",
|
||||
[](PyDialectHelper &self, const std::string &opNameStr,
|
||||
std::vector<PyType> pyResultTypes, std::vector<PyValue> pyOperands,
|
||||
llvm::Optional<PyAttribute> attrs) -> PyOperationRef {
|
||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(false);
|
||||
Location loc = self.pyOpBuilder.getCurrentLoc();
|
||||
OperationName opName(opNameStr, opBuilder.getContext());
|
||||
SmallVector<Type, 4> types(pyResultTypes.begin(),
|
||||
pyResultTypes.end());
|
||||
SmallVector<Value, 4> operands(pyOperands.begin(),
|
||||
pyOperands.end());
|
||||
MutableDictionaryAttr attrList;
|
||||
if (attrs) {
|
||||
auto dictAttrs = attrs->attr.dyn_cast<DictionaryAttr>();
|
||||
if (!dictAttrs) {
|
||||
throw py::raiseValueError(
|
||||
"Expected `attrs` to be a DictionaryAttr");
|
||||
}
|
||||
attrList = MutableDictionaryAttr(dictAttrs);
|
||||
}
|
||||
Operation *op =
|
||||
Operation::create(loc, opName, types, operands, attrList);
|
||||
opBuilder.insert(op);
|
||||
return op;
|
||||
},
|
||||
py::arg("op_name"), py::arg("result_types"), py::arg("operands"),
|
||||
py::arg("attrs") = llvm::Optional<PyAttribute>())
|
||||
.def(
|
||||
"func_op",
|
||||
[](PyDialectHelper &self, const std::string &name, PyType type,
|
||||
bool createEntryBlock, llvm::Optional<PyAttribute> attrs) {
|
||||
auto functionType = type.type.dyn_cast_or_null<FunctionType>();
|
||||
if (!functionType) {
|
||||
throw py::raiseValueError("Illegal function type");
|
||||
}
|
||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||
Location loc = self.pyOpBuilder.getCurrentLoc();
|
||||
// TODO: Dedup attr creation from op().
|
||||
MutableDictionaryAttr attrList;
|
||||
if (attrs) {
|
||||
auto dictAttrs = attrs->attr.dyn_cast<DictionaryAttr>();
|
||||
if (!dictAttrs) {
|
||||
throw py::raiseValueError(
|
||||
"Expected `attrs` to be a DictionaryAttr");
|
||||
}
|
||||
attrList = MutableDictionaryAttr(dictAttrs);
|
||||
}
|
||||
FuncOp op =
|
||||
opBuilder.create<FuncOp>(loc, StringRef(name), functionType,
|
||||
/*attrs=*/attrList.getAttrs());
|
||||
if (createEntryBlock) {
|
||||
Block *entryBlock = new Block();
|
||||
entryBlock->addArguments(functionType.getInputs());
|
||||
op.getBody().push_back(entryBlock);
|
||||
opBuilder.setInsertionPointToStart(entryBlock);
|
||||
}
|
||||
return PyOperationRef(op);
|
||||
},
|
||||
py::arg("name"), py::arg("type"),
|
||||
py::arg("create_entry_block") = false,
|
||||
py::arg("attrs") = llvm::Optional<PyAttribute>(),
|
||||
R"(Creates a new `func` op, optionally creating an entry block.
|
||||
If an entry block is created, the builder will be positioned
|
||||
to its start.)")
|
||||
.def("select_op",
|
||||
[](PyDialectHelper &self, PyValue conditionValue, PyValue trueValue,
|
||||
PyValue falseValue) -> PyOperationRef {
|
||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||
Location loc = self.pyOpBuilder.getCurrentLoc();
|
||||
return PyOperationRef(opBuilder.create<SelectOp>(
|
||||
loc, conditionValue, trueValue, falseValue));
|
||||
},
|
||||
py::arg("condition"), py::arg("true_value"), py::arg("false_value"))
|
||||
.def(
|
||||
"select_op",
|
||||
[](PyDialectHelper &self, PyValue conditionValue, PyValue trueValue,
|
||||
PyValue falseValue) -> PyOperationRef {
|
||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||
Location loc = self.pyOpBuilder.getCurrentLoc();
|
||||
return PyOperationRef(opBuilder.create<SelectOp>(
|
||||
loc, conditionValue, trueValue, falseValue));
|
||||
},
|
||||
py::arg("condition"), py::arg("true_value"), py::arg("false_value"))
|
||||
.def("return_op",
|
||||
[](PyDialectHelper &self, std::vector<PyValue> pyOperands) {
|
||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||
|
@ -288,11 +291,12 @@ void PyDialectHelper::bind(py::module m) {
|
|||
[](PyDialectHelper &self) -> PyType {
|
||||
return IndexType::get(self.getContext());
|
||||
})
|
||||
.def("integer_type",
|
||||
[](PyDialectHelper &self, unsigned width) -> PyType {
|
||||
return IntegerType::get(width, self.getContext());
|
||||
},
|
||||
py::arg("width") = 32)
|
||||
.def(
|
||||
"integer_type",
|
||||
[](PyDialectHelper &self, unsigned width) -> PyType {
|
||||
return IntegerType::get(width, self.getContext());
|
||||
},
|
||||
py::arg("width") = 32)
|
||||
.def_property_readonly("i1_type",
|
||||
[](PyDialectHelper &self) -> PyType {
|
||||
return IntegerType::get(1, self.getContext());
|
||||
|
@ -319,20 +323,21 @@ void PyDialectHelper::bind(py::module m) {
|
|||
return FloatType::get(StandardTypes::F64,
|
||||
self.getContext());
|
||||
})
|
||||
.def("tensor_type",
|
||||
[](PyDialectHelper &self, PyType elementType,
|
||||
llvm::Optional<std::vector<int64_t>> shape) -> PyType {
|
||||
if (!elementType.type) {
|
||||
throw py::raiseValueError("Null element type");
|
||||
}
|
||||
if (shape) {
|
||||
return RankedTensorType::get(*shape, elementType.type);
|
||||
} else {
|
||||
return UnrankedTensorType::get(elementType.type);
|
||||
}
|
||||
},
|
||||
py::arg("element_type"),
|
||||
py::arg("shape") = llvm::Optional<std::vector<int64_t>>())
|
||||
.def(
|
||||
"tensor_type",
|
||||
[](PyDialectHelper &self, PyType elementType,
|
||||
llvm::Optional<std::vector<int64_t>> shape) -> PyType {
|
||||
if (!elementType.type) {
|
||||
throw py::raiseValueError("Null element type");
|
||||
}
|
||||
if (shape) {
|
||||
return RankedTensorType::get(*shape, elementType.type);
|
||||
} else {
|
||||
return UnrankedTensorType::get(elementType.type);
|
||||
}
|
||||
},
|
||||
py::arg("element_type"),
|
||||
py::arg("shape") = llvm::Optional<std::vector<int64_t>>())
|
||||
.def("function_type",
|
||||
[](PyDialectHelper &self, std::vector<PyType> inputs,
|
||||
std::vector<PyType> results) -> PyType {
|
||||
|
@ -367,21 +372,24 @@ void defineMlirIrModule(py::module m) {
|
|||
m.doc() = "Python bindings for constructs in the mlir/IR library";
|
||||
|
||||
// Globals.
|
||||
m.def("emit_error",
|
||||
[](PyAttribute loc, std::string message) {
|
||||
emitDiagnostic(DiagnosticSeverity::Error, loc, message);
|
||||
},
|
||||
py::arg("loc"), py::arg("message"));
|
||||
m.def("emit_warning",
|
||||
[](PyAttribute loc, std::string message) {
|
||||
emitDiagnostic(DiagnosticSeverity::Warning, loc, message);
|
||||
},
|
||||
py::arg("loc"), py::arg("message"));
|
||||
m.def("emit_remark",
|
||||
[](PyAttribute loc, std::string message) {
|
||||
emitDiagnostic(DiagnosticSeverity::Remark, loc, message);
|
||||
},
|
||||
py::arg("loc"), py::arg("message"));
|
||||
m.def(
|
||||
"emit_error",
|
||||
[](PyAttribute loc, std::string message) {
|
||||
emitDiagnostic(DiagnosticSeverity::Error, loc, message);
|
||||
},
|
||||
py::arg("loc"), py::arg("message"));
|
||||
m.def(
|
||||
"emit_warning",
|
||||
[](PyAttribute loc, std::string message) {
|
||||
emitDiagnostic(DiagnosticSeverity::Warning, loc, message);
|
||||
},
|
||||
py::arg("loc"), py::arg("message"));
|
||||
m.def(
|
||||
"emit_remark",
|
||||
[](PyAttribute loc, std::string message) {
|
||||
emitDiagnostic(DiagnosticSeverity::Remark, loc, message);
|
||||
},
|
||||
py::arg("loc"), py::arg("message"));
|
||||
|
||||
// Python only types.
|
||||
PyDialectHelper::bind(m);
|
||||
|
@ -426,25 +434,27 @@ void PyContext::bind(py::module m) {
|
|||
return PyModuleOp(self.shared_from_this(), m);
|
||||
})
|
||||
.def("parse_asm", &PyContext::parseAsm)
|
||||
.def("new_builder",
|
||||
[](PyContext &self) {
|
||||
// Note: we collapse the Builder and OpBuilder into one because
|
||||
// there is little reason to expose the inheritance hierarchy to
|
||||
// Python.
|
||||
return PyOpBuilder(self);
|
||||
},
|
||||
py::keep_alive<0, 1>())
|
||||
.def(
|
||||
"new_builder",
|
||||
[](PyContext &self) {
|
||||
// Note: we collapse the Builder and OpBuilder into one because
|
||||
// there is little reason to expose the inheritance hierarchy to
|
||||
// Python.
|
||||
return PyOpBuilder(self);
|
||||
},
|
||||
py::keep_alive<0, 1>())
|
||||
.def("identifier",
|
||||
[](PyContext &self, std::string s) -> PyIdentifier {
|
||||
return Identifier::get(s, &self.context);
|
||||
})
|
||||
.def("file_line_col_loc_attr",
|
||||
[](PyContext &self, PyIdentifier filename, unsigned line,
|
||||
unsigned column) -> PyAttribute {
|
||||
return static_cast<LocationAttr>(FileLineColLoc::get(
|
||||
filename.identifier, line, column, &self.context));
|
||||
},
|
||||
py::arg("filename"), py::arg("line"), py::arg("column"))
|
||||
.def(
|
||||
"file_line_col_loc_attr",
|
||||
[](PyContext &self, PyIdentifier filename, unsigned line,
|
||||
unsigned column) -> PyAttribute {
|
||||
return static_cast<LocationAttr>(FileLineColLoc::get(
|
||||
filename.identifier, line, column, &self.context));
|
||||
},
|
||||
py::arg("filename"), py::arg("line"), py::arg("column"))
|
||||
// Salient functions from Builder.
|
||||
.def("parse_type",
|
||||
[](PyContext &self, const std::string &asmText) {
|
||||
|
@ -456,14 +466,15 @@ void PyContext::bind(py::module m) {
|
|||
}
|
||||
return PyType(t);
|
||||
})
|
||||
.def("integer_attr",
|
||||
[](PyContext &self, PyType type, int64_t value) -> PyAttribute {
|
||||
if (!type.type.isa<IntegerType>()) {
|
||||
throw py::raiseValueError("Expected IntegerType");
|
||||
}
|
||||
return IntegerAttr::get(type.type, value);
|
||||
},
|
||||
py::arg("type"), py::arg("value"))
|
||||
.def(
|
||||
"integer_attr",
|
||||
[](PyContext &self, PyType type, int64_t value) -> PyAttribute {
|
||||
if (!type.type.isa<IntegerType>()) {
|
||||
throw py::raiseValueError("Expected IntegerType");
|
||||
}
|
||||
return IntegerAttr::get(type.type, value);
|
||||
},
|
||||
py::arg("type"), py::arg("value"))
|
||||
.def("float_attr",
|
||||
[](PyContext &self, PyType type, double value) -> PyAttribute {
|
||||
if (!type.type.isa<FloatType>()) {
|
||||
|
@ -512,20 +523,20 @@ void PyContext::bind(py::module m) {
|
|||
}
|
||||
return ArrayAttr::get(attrs, &self.context);
|
||||
})
|
||||
.def("dense_elements_attr",
|
||||
[](PyContext &self, py::buffer array) -> PyAttribute {
|
||||
// Request a contiguous view.
|
||||
int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
|
||||
Py_buffer *view = new Py_buffer();
|
||||
if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
|
||||
delete view;
|
||||
throw py::error_already_set();
|
||||
}
|
||||
py::buffer_info array_info(view);
|
||||
return createDenseElementsAttrFromBuffer(&self.context,
|
||||
array_info);
|
||||
},
|
||||
py::arg("array"))
|
||||
.def(
|
||||
"dense_elements_attr",
|
||||
[](PyContext &self, py::buffer array) -> PyAttribute {
|
||||
// Request a contiguous view.
|
||||
int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
|
||||
Py_buffer *view = new Py_buffer();
|
||||
if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
|
||||
delete view;
|
||||
throw py::error_already_set();
|
||||
}
|
||||
py::buffer_info array_info(view);
|
||||
return createDenseElementsAttrFromBuffer(&self.context, array_info);
|
||||
},
|
||||
py::arg("array"))
|
||||
.def_property_readonly("unit_attr", [](PyContext &self) -> PyAttribute {
|
||||
return UnitAttr::get(&self.context);
|
||||
});
|
||||
|
@ -905,67 +916,74 @@ void PyBaseOpBuilder::bind(py::module m) {
|
|||
void PyOpBuilder::bind(py::module m) {
|
||||
py::class_<PyOpBuilder, PyBaseOpBuilder>(m, "OpBuilder")
|
||||
.def(py::init<PyContext &>(), py::keep_alive<1, 2>())
|
||||
.def_property("current_loc",
|
||||
[](PyOpBuilder &self) -> PyAttribute {
|
||||
return static_cast<Attribute>(self.getCurrentLoc());
|
||||
},
|
||||
[](PyOpBuilder &self, PyAttribute attr) {
|
||||
auto loc_attr =
|
||||
attr.attr.dyn_cast_or_null<LocationAttr>();
|
||||
if (!loc_attr) {
|
||||
throw py::raiseValueError("Expected a LocationAttr");
|
||||
}
|
||||
self.setCurrentLoc(Location(loc_attr));
|
||||
})
|
||||
.def_property("insertion_point",
|
||||
[](PyOpBuilder &self) {
|
||||
return self.getBuilder(true).saveInsertionPoint();
|
||||
},
|
||||
[](PyOpBuilder &self, OpBuilder::InsertPoint ip) {
|
||||
self.getBuilder(false).restoreInsertionPoint(ip);
|
||||
})
|
||||
.def("set_file_line_col",
|
||||
[](PyOpBuilder &self, PyIdentifier filename, unsigned line,
|
||||
unsigned column) {
|
||||
Location loc = FileLineColLoc::get(filename.identifier, line,
|
||||
column, self.getContext());
|
||||
self.setCurrentLoc(loc);
|
||||
},
|
||||
py::arg("filename"), py::arg("line"), py::arg("column"),
|
||||
"Shortcut to set a FileLineCol current location")
|
||||
.def_property(
|
||||
"current_loc",
|
||||
[](PyOpBuilder &self) -> PyAttribute {
|
||||
return static_cast<Attribute>(self.getCurrentLoc());
|
||||
},
|
||||
[](PyOpBuilder &self, PyAttribute attr) {
|
||||
auto loc_attr = attr.attr.dyn_cast_or_null<LocationAttr>();
|
||||
if (!loc_attr) {
|
||||
throw py::raiseValueError("Expected a LocationAttr");
|
||||
}
|
||||
self.setCurrentLoc(Location(loc_attr));
|
||||
})
|
||||
.def_property(
|
||||
"insertion_point",
|
||||
[](PyOpBuilder &self) {
|
||||
return self.getBuilder(true).saveInsertionPoint();
|
||||
},
|
||||
[](PyOpBuilder &self, OpBuilder::InsertPoint ip) {
|
||||
self.getBuilder(false).restoreInsertionPoint(ip);
|
||||
})
|
||||
.def(
|
||||
"set_file_line_col",
|
||||
[](PyOpBuilder &self, PyIdentifier filename, unsigned line,
|
||||
unsigned column) {
|
||||
Location loc = FileLineColLoc::get(filename.identifier, line,
|
||||
column, self.getContext());
|
||||
self.setCurrentLoc(loc);
|
||||
},
|
||||
py::arg("filename"), py::arg("line"), py::arg("column"),
|
||||
"Shortcut to set a FileLineCol current location")
|
||||
.def("clear_insertion_point",
|
||||
[](PyOpBuilder &self) { self.builder.clearInsertionPoint(); })
|
||||
.def("insert_op_before",
|
||||
[](PyOpBuilder &self, PyBaseOperation &pyOp) {
|
||||
Operation *op = pyOp.getOperation();
|
||||
self.builder.setInsertionPoint(op);
|
||||
},
|
||||
"Sets the insertion point to just before the specified op.")
|
||||
.def("insert_op_after",
|
||||
[](PyOpBuilder &self, PyBaseOperation &pyOp) {
|
||||
Operation *op = pyOp.getOperation();
|
||||
self.builder.setInsertionPointAfter(op);
|
||||
},
|
||||
"Sets the insertion point to just after the specified op.")
|
||||
.def("insert_block_start",
|
||||
[](PyOpBuilder &self, PyBlockRef block) {
|
||||
self.builder.setInsertionPointToStart(&block.block);
|
||||
},
|
||||
"Sets the insertion point to the start of the block.")
|
||||
.def("insert_block_end",
|
||||
[](PyOpBuilder &self, PyBlockRef block) {
|
||||
self.builder.setInsertionPointToEnd(&block.block);
|
||||
},
|
||||
"Sets the insertion point to the end of the block.")
|
||||
.def("insert_before_terminator",
|
||||
[](PyOpBuilder &self, PyBlockRef block) {
|
||||
auto *terminator = block.block.getTerminator();
|
||||
if (!terminator) {
|
||||
throw py::raiseValueError("Block has no terminator");
|
||||
}
|
||||
self.builder.setInsertionPoint(terminator);
|
||||
},
|
||||
"Sets the insertion point to just before the block terminator.");
|
||||
.def(
|
||||
"insert_op_before",
|
||||
[](PyOpBuilder &self, PyBaseOperation &pyOp) {
|
||||
Operation *op = pyOp.getOperation();
|
||||
self.builder.setInsertionPoint(op);
|
||||
},
|
||||
"Sets the insertion point to just before the specified op.")
|
||||
.def(
|
||||
"insert_op_after",
|
||||
[](PyOpBuilder &self, PyBaseOperation &pyOp) {
|
||||
Operation *op = pyOp.getOperation();
|
||||
self.builder.setInsertionPointAfter(op);
|
||||
},
|
||||
"Sets the insertion point to just after the specified op.")
|
||||
.def(
|
||||
"insert_block_start",
|
||||
[](PyOpBuilder &self, PyBlockRef block) {
|
||||
self.builder.setInsertionPointToStart(&block.block);
|
||||
},
|
||||
"Sets the insertion point to the start of the block.")
|
||||
.def(
|
||||
"insert_block_end",
|
||||
[](PyOpBuilder &self, PyBlockRef block) {
|
||||
self.builder.setInsertionPointToEnd(&block.block);
|
||||
},
|
||||
"Sets the insertion point to the end of the block.")
|
||||
.def(
|
||||
"insert_before_terminator",
|
||||
[](PyOpBuilder &self, PyBlockRef block) {
|
||||
auto *terminator = block.block.getTerminator();
|
||||
if (!terminator) {
|
||||
throw py::raiseValueError("Block has no terminator");
|
||||
}
|
||||
self.builder.setInsertionPoint(terminator);
|
||||
},
|
||||
"Sets the insertion point to just before the block terminator.");
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
|
|
@ -32,13 +32,14 @@ void PyPassManager::bind(py::module m) {
|
|||
py::class_<PyPassManager>(m, "PassManager")
|
||||
.def(py::init<std::shared_ptr<PyContext>, bool>(), py::arg("context"),
|
||||
py::arg("verifyModules") = true)
|
||||
.def("enableCrashReproducerGeneration",
|
||||
[](PyPassManager &self, std::string outputFile,
|
||||
bool genLocalReproducer) {
|
||||
self.passManager.enableCrashReproducerGeneration(
|
||||
outputFile, genLocalReproducer);
|
||||
},
|
||||
py::arg("outputFile"), py::arg("genLocalReproducer") = false)
|
||||
.def(
|
||||
"enableCrashReproducerGeneration",
|
||||
[](PyPassManager &self, std::string outputFile,
|
||||
bool genLocalReproducer) {
|
||||
self.passManager.enableCrashReproducerGeneration(
|
||||
outputFile, genLocalReproducer);
|
||||
},
|
||||
py::arg("outputFile"), py::arg("genLocalReproducer") = false)
|
||||
.def("__len__",
|
||||
[](PyPassManager &self) { return self.passManager.size(); })
|
||||
.def("__str__",
|
||||
|
|
|
@ -48,18 +48,19 @@ public:
|
|||
return Basicpy::NoneType::get(
|
||||
self.getContext());
|
||||
})
|
||||
.def("basicpy_SlotObject_type",
|
||||
[](BasicpyDialectHelper &self, std::string className,
|
||||
py::args pySlotTypes) -> PyType {
|
||||
SmallVector<Type, 4> slotTypes;
|
||||
for (auto pySlotType : pySlotTypes) {
|
||||
slotTypes.push_back(pySlotType.cast<PyType>());
|
||||
}
|
||||
auto classNameAttr =
|
||||
StringAttr::get(className, self.getContext());
|
||||
return Basicpy::SlotObjectType::get(classNameAttr, slotTypes);
|
||||
},
|
||||
py::arg("className"))
|
||||
.def(
|
||||
"basicpy_SlotObject_type",
|
||||
[](BasicpyDialectHelper &self, std::string className,
|
||||
py::args pySlotTypes) -> PyType {
|
||||
SmallVector<Type, 4> slotTypes;
|
||||
for (auto pySlotType : pySlotTypes) {
|
||||
slotTypes.push_back(pySlotType.cast<PyType>());
|
||||
}
|
||||
auto classNameAttr =
|
||||
StringAttr::get(className, self.getContext());
|
||||
return Basicpy::SlotObjectType::get(classNameAttr, slotTypes);
|
||||
},
|
||||
py::arg("className"))
|
||||
.def_property_readonly("basicpy_StrType",
|
||||
[](BasicpyDialectHelper &self) -> PyType {
|
||||
return Basicpy::StrType::get(
|
||||
|
|
Loading…
Reference in New Issue