Format sources.

pull/35/head
Stella Laurenzo 2020-08-27 14:47:49 -07:00
parent de38caa547
commit fc4f374345
13 changed files with 537 additions and 502 deletions

View File

@ -14,7 +14,7 @@
namespace mlir { namespace mlir {
namespace NPCOMP { namespace NPCOMP {
#include "npcomp/Dialect/ATen/ATenOpInterfaces.h.inc" #include "npcomp/Dialect/ATen/ATenOpInterfaces.h.inc"
} // namespace aten
} // namespace NPCOMP } // namespace NPCOMP
} // namespace mlir
#endif #endif

View File

@ -11,9 +11,9 @@
#include "npcomp/Dialect/ATen/ATenDialect.h" #include "npcomp/Dialect/ATen/ATenDialect.h"
#include "llvm/Support/Debug.h"
#include "mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h" #include "mlir/IR/Types.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "aten-op-stats" #define DEBUG_TYPE "aten-op-stats"
@ -25,7 +25,8 @@ namespace NPCOMP {
namespace aten { namespace aten {
// Return the op statistics for conv2d-like operations. // Return the op statistics for conv2d-like operations.
template <class T> std::map<std::string, uint64_t> getConv2dStatistics(T *o, uint64_t groups) { template <class T>
std::map<std::string, uint64_t> getConv2dStatistics(T *o, uint64_t groups) {
std::map<std::string, uint64_t> toReturn; std::map<std::string, uint64_t> toReturn;
@ -66,11 +67,13 @@ template <class T> std::map<std::string, uint64_t> getConv2dStatistics(T *o, uin
} }
// Return the op statistics for conv2dBackward-like operations. // Return the op statistics for conv2dBackward-like operations.
template<typename T> template <typename T>
std::map<std::string, uint64_t> getConv2dBackwardStatistics(T op, uint64_t groups) { std::map<std::string, uint64_t> getConv2dBackwardStatistics(T op,
uint64_t groups) {
std::map<std::string, uint64_t> toReturn; std::map<std::string, uint64_t> toReturn;
TensorType dx_out_resultTy = op.getResult(0).getType().template cast<TensorType>(); TensorType dx_out_resultTy =
op.getResult(0).getType().template cast<TensorType>();
uint64_t dx_out_volume = getTensorVolume(dx_out_resultTy); uint64_t dx_out_volume = getTensorVolume(dx_out_resultTy);
TensorType weightTy = op.getOperand(2).getType().template cast<TensorType>(); TensorType weightTy = op.getOperand(2).getType().template cast<TensorType>();
@ -79,7 +82,7 @@ std::map<std::string, uint64_t> getConv2dBackwardStatistics(T op, uint64_t group
uint64_t kernel_width = weightTy.getShape()[2]; uint64_t kernel_width = weightTy.getShape()[2];
uint64_t kernel_height = weightTy.getShape()[3]; uint64_t kernel_height = weightTy.getShape()[3];
uint64_t MACs_per_loss = uint64_t MACs_per_loss =
(loss_in_depth / groups) * kernel_height * kernel_width; (loss_in_depth / groups) * kernel_height * kernel_width;
uint64_t total_MACs = dx_out_volume * MACs_per_loss; uint64_t total_MACs = dx_out_volume * MACs_per_loss;
@ -119,7 +122,6 @@ std::map<std::string, uint64_t> getConv2dBackwardStatistics(T op, uint64_t group
return toReturn; return toReturn;
} }
// Return a model of the number of bytes needed to represent the operand of // Return a model of the number of bytes needed to represent the operand of
// the given convolution-like operation with the given index. The shape is // the given convolution-like operation with the given index. The shape is
// assumed to be in NCHW order with a simple tiled model of data reuse. TODO: // assumed to be in NCHW order with a simple tiled model of data reuse. TODO:
@ -222,8 +224,7 @@ uint64_t getConv2dResultTransferVolume(T *o, unsigned int idx, bool write) {
} }
// Return the op statistics for matrixmultiply-like operations. // Return the op statistics for matrixmultiply-like operations.
template<typename T> template <typename T> std::map<std::string, uint64_t> getMMOpStatistics(T op) {
std::map<std::string, uint64_t> getMMOpStatistics(T op) {
std::map<std::string, uint64_t> toReturn; std::map<std::string, uint64_t> toReturn;

View File

@ -20,7 +20,7 @@ namespace aten {
// #define GEN_PASS_CLASSES // #define GEN_PASS_CLASSES
// #include "npcomp/Dialect/ATen/ATenPasses.h.inc" // #include "npcomp/Dialect/ATen/ATenPasses.h.inc"
void registerATenPasses(); void registerATenPasses();
} // namespace aten } // namespace aten
} // namespace NPCOMP } // namespace NPCOMP
} // namespace mlir } // namespace mlir

View File

@ -44,33 +44,36 @@ void mlir::npcomp::python::defineBackendIREEModule(py::module m) {
); );
}); });
m.def("build_flow_transform_pass_pipeline", m.def(
[](PyPassManager &pm) { "build_flow_transform_pass_pipeline",
mlir::iree_compiler::IREE::Flow::buildFlowTransformPassPipeline( [](PyPassManager &pm) {
pm.passManager); mlir::iree_compiler::IREE::Flow::buildFlowTransformPassPipeline(
}, pm.passManager);
py::arg("pm"), },
py::doc("Builds a pass pipeline for top-level Flow import")); py::arg("pm"),
m.def("build_hal_transform_pass_pipeline", py::doc("Builds a pass pipeline for top-level Flow import"));
[](PyPassManager &pm, std::vector<std::string> targetBackends) { m.def(
mlir::iree_compiler::IREE::HAL::TargetOptions options; "build_hal_transform_pass_pipeline",
if (targetBackends.empty()) { [](PyPassManager &pm, std::vector<std::string> targetBackends) {
options.targets = mlir::iree_compiler::IREE::HAL::TargetOptions options;
mlir::iree_compiler::IREE::HAL::getRegisteredTargetBackends(); if (targetBackends.empty()) {
} else { options.targets =
options.targets = std::move(targetBackends); mlir::iree_compiler::IREE::HAL::getRegisteredTargetBackends();
} } else {
iree_compiler::IREE::HAL::buildHALTransformPassPipeline( options.targets = std::move(targetBackends);
pm.passManager, options); }
}, iree_compiler::IREE::HAL::buildHALTransformPassPipeline(pm.passManager,
py::arg("pm"), py::arg("target_backends") = std::vector<std::string>(), options);
py::doc("Builds a pass pipeline for top-level Flow import")); },
m.def("build_vm_transform_pass_pipeline", py::arg("pm"), py::arg("target_backends") = std::vector<std::string>(),
[](PyPassManager &pm) { py::doc("Builds a pass pipeline for top-level Flow import"));
mlir::iree_compiler::IREE::VM::buildVMTransformPassPipeline( m.def(
pm.passManager); "build_vm_transform_pass_pipeline",
}, [](PyPassManager &pm) {
py::arg("pm"), py::doc("Builds the VM transformation pipeline")); mlir::iree_compiler::IREE::VM::buildVMTransformPassPipeline(
pm.passManager);
},
py::arg("pm"), py::doc("Builds the VM transformation pipeline"));
m.def("translate_to_vm_bytecode", [](PyModuleOp &module) { m.def("translate_to_vm_bytecode", [](PyModuleOp &module) {
// TODO: Make the options parameterizable. // TODO: Make the options parameterizable.
mlir::iree_compiler::IREE::VM::BytecodeTargetOptions options; mlir::iree_compiler::IREE::VM::BytecodeTargetOptions options;

View File

@ -89,38 +89,39 @@ void npcomp::python::defineBackendRefJitModule(py::module m) {
JITModule::buildBackendCompilationPipeline(pm.passManager); JITModule::buildBackendCompilationPipeline(pm.passManager);
}); });
py::class_<JITModule>(m, "JITModule") py::class_<JITModule>(m, "JITModule")
.def_static("from_compiled_module", .def_static(
[](PyModuleOp module, std::vector<std::string> pySharedLibs) "from_compiled_module",
-> std::unique_ptr<JITModule> { [](PyModuleOp module, std::vector<std::string> pySharedLibs)
SmallVector<StringRef, 4> sharedLibs(pySharedLibs.begin(), -> std::unique_ptr<JITModule> {
pySharedLibs.end()); SmallVector<StringRef, 4> sharedLibs(pySharedLibs.begin(),
auto jitModule = pySharedLibs.end());
checkError(JITModule::fromCompiledModule( auto jitModule = checkError(
module.moduleOp, sharedLibs), JITModule::fromCompiledModule(module.moduleOp, sharedLibs),
"error creating JITModule: "); "error creating JITModule: ");
return jitModule; return jitModule;
}, },
py::arg("module"), py::arg("shared_libs")) py::arg("module"), py::arg("shared_libs"))
.def("invoke", .def(
[](JITModule &self, std::string functionName, "invoke",
std::vector<py::buffer> inputs) { [](JITModule &self, std::string functionName,
// Prepare inputs. std::vector<py::buffer> inputs) {
llvm::SmallVector<Ref<Tensor>, 4> inputTensors; // Prepare inputs.
inputTensors.reserve(inputs.size()); llvm::SmallVector<Ref<Tensor>, 4> inputTensors;
for (py::buffer &inputBuffer : inputs) { inputTensors.reserve(inputs.size());
inputTensors.push_back(copyBufferToTensor(inputBuffer)); for (py::buffer &inputBuffer : inputs) {
} inputTensors.push_back(copyBufferToTensor(inputBuffer));
}
auto outputs = checkError(self.invoke(functionName, inputTensors), auto outputs = checkError(self.invoke(functionName, inputTensors),
"error invoking JIT function: "); "error invoking JIT function: ");
std::vector<py::array> outputArrays; std::vector<py::array> outputArrays;
outputArrays.reserve(outputs.size()); outputArrays.reserve(outputs.size());
for (Ref<Tensor> &outputTensor : outputs) { for (Ref<Tensor> &outputTensor : outputs) {
outputArrays.push_back(wrapTensorAsArray(outputTensor)); outputArrays.push_back(wrapTensorAsArray(outputTensor));
} }
return outputArrays; return outputArrays;
}, },
py::arg("function_name"), py::arg("inputs")); py::arg("function_name"), py::arg("inputs"));
// A Ref<Tensor> needs to be bound because we use it as a base for the // A Ref<Tensor> needs to be bound because we use it as a base for the
// ndarray (the array retains a reference to it). Users should not encounter // ndarray (the array retains a reference to it). Users should not encounter

View File

@ -55,7 +55,7 @@ std::map<std::string, uint64_t> AdaptiveAvgPool2dBackwardOp::getStatistics() {
return toReturn; return toReturn;
} }
// add // add
std::map<std::string, uint64_t> AddOp::getStatistics() { std::map<std::string, uint64_t> AddOp::getStatistics() {
std::map<std::string, uint64_t> toReturn; std::map<std::string, uint64_t> toReturn;
@ -222,7 +222,8 @@ uint64_t ConvolutionOp::getResultTransferVolume(unsigned int idx, bool write) {
std::map<std::string, uint64_t> ConvolutionBackwardOp::getStatistics() { std::map<std::string, uint64_t> ConvolutionBackwardOp::getStatistics() {
return getConv2dBackwardStatistics(*this, 1); return getConv2dBackwardStatistics(*this, 1);
} }
std::map<std::string, uint64_t> ConvolutionBackwardOverrideableOp::getStatistics() { std::map<std::string, uint64_t>
ConvolutionBackwardOverrideableOp::getStatistics() {
auto co = cast<mlir::NPCOMP::aten::ConstantOp>(groups().getDefiningOp()); auto co = cast<mlir::NPCOMP::aten::ConstantOp>(groups().getDefiningOp());
auto ia = co.template getAttrOfType<IntegerAttr>("value"); auto ia = co.template getAttrOfType<IntegerAttr>("value");
uint64_t groups = ia.getValue().getZExtValue(); uint64_t groups = ia.getValue().getZExtValue();
@ -463,7 +464,7 @@ std::map<std::string, uint64_t> MeanOp::getStatistics() {
// getMMOpStatistics(*this); // getMMOpStatistics(*this);
// } // }
std::map<std::string, uint64_t> MmOp::getStatistics() { std::map<std::string, uint64_t> MmOp::getStatistics() {
return getMMOpStatistics(*this ); return getMMOpStatistics(*this);
} }
// mul // mul

View File

@ -61,7 +61,8 @@ namespace {
static Value typeCast(PatternRewriter &builder, Value val, Type destTy) { static Value typeCast(PatternRewriter &builder, Value val, Type destTy) {
if (val.getType() == destTy) if (val.getType() == destTy)
return val; return val;
return builder.create<mlir::NPCOMP::aten::TypeCastOp>(val.getLoc(), destTy, val) return builder
.create<mlir::NPCOMP::aten::TypeCastOp>(val.getLoc(), destTy, val)
.getResult(); .getResult();
} }
@ -69,11 +70,11 @@ static Value typeCast(PatternRewriter &builder, Value val, Type destTy) {
/// unknown shape. /// unknown shape.
static MemRefType getShapeErasedMemRefType(MemRefType type) { static MemRefType getShapeErasedMemRefType(MemRefType type) {
std::vector<int64_t> shape = type.getShape(); std::vector<int64_t> shape = type.getShape();
for(int i = 0; i < shape.size(); i++) { for (int i = 0; i < shape.size(); i++) {
shape[i] = -1; shape[i] = -1;
} }
return MemRefType::get(shape, type.getElementType(), return MemRefType::get(shape, type.getElementType(), type.getAffineMaps(),
type.getAffineMaps(), type.getMemorySpace()); type.getMemorySpace());
} }
/// Create a type cast to memref /// Create a type cast to memref
@ -82,14 +83,12 @@ static Value memRefTypeCast(PatternRewriter &builder, Value val) {
if (auto memrefTy = type.dyn_cast<MemRefType>()) { if (auto memrefTy = type.dyn_cast<MemRefType>()) {
MemRefType newType = getShapeErasedMemRefType(memrefTy); MemRefType newType = getShapeErasedMemRefType(memrefTy);
return builder.create<MemRefCastOp>(val.getLoc(), return builder.create<MemRefCastOp>(val.getLoc(), val, newType).getResult();
val, newType)
.getResult();
} }
if (auto tensorTy = type.dyn_cast<TensorType>()) { if (auto tensorTy = type.dyn_cast<TensorType>()) {
auto memRefType = mlir::MemRefType::get(tensorTy.getShape(), auto memRefType = mlir::MemRefType::get(tensorTy.getShape(),
tensorTy.getElementType(), {}, 0); tensorTy.getElementType(), {}, 0);
return typeCast(builder, val, memRefType); return typeCast(builder, val, memRefType);
} }
return val; return val;
} }
@ -186,7 +185,7 @@ static std::string getSimplyMangledFuncName(std::string prefix,
ret = ret + sep + getSimplyMangledType(t); ret = ret + sep + getSimplyMangledType(t);
for (const Type t : operTy) { for (const Type t : operTy) {
std::string s = getSimplyMangledType(t); std::string s = getSimplyMangledType(t);
if(s.size() > 0) if (s.size() > 0)
ret = ret + sep + getSimplyMangledType(t); ret = ret + sep + getSimplyMangledType(t);
} }
ret += "_out"; ret += "_out";
@ -194,25 +193,22 @@ static std::string getSimplyMangledFuncName(std::string prefix,
return ret; return ret;
} }
static std::string getSimplyMangledFuncName(std::string prefix, static std::string getSimplyMangledFuncName(std::string prefix,
FunctionType fnTy) { FunctionType fnTy) {
return getSimplyMangledFuncName(prefix, fnTy.getInputs(), fnTy.getResults()); return getSimplyMangledFuncName(prefix, fnTy.getInputs(), fnTy.getResults());
} }
std::string getMangledFuncName(std::string prefix, std::string getMangledFuncName(std::string prefix, FunctionType fnTy) {
FunctionType fnTy) {
return getSimplyMangledFuncName(prefix, fnTy); return getSimplyMangledFuncName(prefix, fnTy);
} }
std::string getMangledFuncName(std::string prefix, std::string getMangledFuncName(std::string prefix, ArrayRef<Type> opTys,
ArrayRef<Type> opTys,
ArrayRef<Type> retTys) { ArrayRef<Type> retTys) {
return getSimplyMangledFuncName(prefix, opTys, retTys); return getSimplyMangledFuncName(prefix, opTys, retTys);
} }
static FuncOp getATenFn(ModuleOp module, std::string mangledFunctionName, static FuncOp getATenFn(ModuleOp module, std::string mangledFunctionName,
ArrayRef<Value> operands, ArrayRef<Value> operands, ArrayRef<Type> retTys) {
ArrayRef<Type> retTys) {
Builder builder(module); Builder builder(module);
SmallVector<Type, 8> tys; SmallVector<Type, 8> tys;
@ -242,8 +238,8 @@ static FuncOp getATenFn(ModuleOp module, std::string mangledFunctionName,
class AddOpConversion_affine : public ConversionPattern { class AddOpConversion_affine : public ConversionPattern {
public: public:
explicit AddOpConversion_affine(MLIRContext *context) explicit AddOpConversion_affine(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::AddOp::getOperationName(), 1, context) { : ConversionPattern(mlir::NPCOMP::aten::AddOp::getOperationName(), 1,
} context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -310,78 +306,72 @@ public:
} }
}; };
// Replace the given operation with a call to the given function. // Replace the given operation with a call to the given function.
// The function is assumed to accept memrefs and scalar types and return // The function is assumed to accept memrefs and scalar types and return
// Memrefs. Here the result types are converted back to the result types of op, // Memrefs. Here the result types are converted back to the result types of op,
// but operands are NOT converted. This allows non-standard mappings from // but operands are NOT converted. This allows non-standard mappings from
// operand types to function types. // operand types to function types.
LogicalResult LogicalResult rewriteWithVoidFunctionCallExplicit(
rewriteWithVoidFunctionCallExplicit(Operation *op, Operation *op, ArrayRef<Value> callops, ArrayRef<Value> operands,
ArrayRef<Value> callops, ConversionPatternRewriter &rewriter, std::string functionName) {
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter,
std::string functionName) {
auto loc = op->getLoc(); auto loc = op->getLoc();
edsc::ScopedContext scope(rewriter, loc); edsc::ScopedContext scope(rewriter, loc);
// The original operation types. // The original operation types.
SmallVector<Type, 8> opTys; SmallVector<Type, 8> opTys;
// Shape erased versions of the original operation types. // Shape erased versions of the original operation types.
SmallVector<Type, 8> erasedOpTys; SmallVector<Type, 8> erasedOpTys;
for (const Value &o: callops) { for (const Value &o : callops) {
Type t = o.getType(); Type t = o.getType();
opTys.push_back(t); opTys.push_back(t);
if (t.isa<MemRefType>()) if (t.isa<MemRefType>())
erasedOpTys.push_back(getShapeErasedMemRefType(t.cast<MemRefType>())); erasedOpTys.push_back(getShapeErasedMemRefType(t.cast<MemRefType>()));
else else
erasedOpTys.push_back(t); erasedOpTys.push_back(t);
}
std::vector<Value> newOps = callops;
SmallVector<Value, 8> newResults;
// Result types of the original operation, converted to memrefs.
SmallVector<Type, 8> retTys;
// Erased version of the return type. This is the return types of the
// generated function call.
SmallVector<Type, 8> erasedRetTys;
for (const auto &o : op->getResults()) {
Type t = o.getType();
if (t.isa<TensorType>()) {
TensorType tensorResultTy = t.cast<TensorType>();
MemRefType memRefResultTy = mlir::MemRefType::get(
tensorResultTy.getShape(), tensorResultTy.getElementType(), {}, 0);
MemRefType erasedMemRefResultTy =
getShapeErasedMemRefType(memRefResultTy);
retTys.push_back(memRefResultTy);
// assume memRefResultTy has known shape, so we don't need any
// dynamic dimensions for the alloc.
assert(memRefResultTy.hasStaticShape());
Value allocVal = rewriter.create<AllocOp>(op->getLoc(), memRefResultTy);
Value castVal = memRefTypeCast(rewriter, allocVal);
newOps.push_back(castVal);
newResults.push_back(allocVal);
} else {
return failure();
} }
}
std::vector<Value> newOps = callops; SmallVector<Type, 8> empty;
SmallVector<Value, 8> newResults; std::string mangledFunctionName =
getMangledFuncName(functionName, opTys, retTys);
FuncOp funcOp = getATenFn(op->getParentOfType<ModuleOp>(),
mangledFunctionName, newOps, empty);
// Result types of the original operation, converted to memrefs. auto new_call =
SmallVector<Type, 8> retTys; callOperation(empty, rewriter.getSymbolRefAttr(funcOp), newOps);
// Erased version of the return type. This is the return types of the
// generated function call.
SmallVector<Type, 8> erasedRetTys;
for (const auto &o: op->getResults()) {
Type t = o.getType();
if (t.isa<TensorType>()) {
TensorType tensorResultTy = t.cast<TensorType>();
MemRefType memRefResultTy =
mlir::MemRefType::get(tensorResultTy.getShape(),
tensorResultTy.getElementType(), {}, 0);
MemRefType erasedMemRefResultTy = getShapeErasedMemRefType(memRefResultTy);
retTys.push_back(memRefResultTy);
// assume memRefResultTy has known shape, so we don't need any rewriter.replaceOp(op, newResults);
// dynamic dimensions for the alloc. return success();
assert(memRefResultTy.hasStaticShape());
Value allocVal = rewriter.create<AllocOp>(op->getLoc(),
memRefResultTy);
Value castVal = memRefTypeCast(rewriter, allocVal);
newOps.push_back(castVal);
newResults.push_back(allocVal);
} else {
return failure();
}
}
SmallVector<Type, 8> empty;
std::string mangledFunctionName = getMangledFuncName(functionName, opTys, retTys);
FuncOp funcOp = getATenFn(op->getParentOfType<ModuleOp>(),
mangledFunctionName,
newOps,
empty);
auto new_call = callOperation(empty,
rewriter.getSymbolRefAttr(funcOp), newOps);
rewriter.replaceOp(op, newResults);
return success();
} }
// Replace the given operation with a call to the given function. // Replace the given operation with a call to the given function.
@ -389,54 +379,53 @@ rewriteWithVoidFunctionCallExplicit(Operation *op,
// Memrefs. Other operand types (e.g. aten.list and tensor<> are converted // Memrefs. Other operand types (e.g. aten.list and tensor<> are converted
// appropriately. The called function passes results of the original function // appropriately. The called function passes results of the original function
// as memref arguments at the end of the original set of operands. // as memref arguments at the end of the original set of operands.
LogicalResult LogicalResult rewriteWithFunctionCall(Operation *op, ArrayRef<Value> operands,
rewriteWithFunctionCall(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter,
ConversionPatternRewriter &rewriter, std::string functionName) {
std::string functionName) { auto loc = op->getLoc();
auto loc = op->getLoc(); edsc::ScopedContext scope(rewriter, loc);
edsc::ScopedContext scope(rewriter, loc);
// Convert the arguments to the original call. // Convert the arguments to the original call.
SmallVector<Value, 8> callops; SmallVector<Value, 8> callops;
for (auto &o: operands) { for (auto &o : operands) {
Type t = o.getType(); Type t = o.getType();
if (t.isa<MemRefType>()) { if (t.isa<MemRefType>()) {
// Cast it to some memref type that we accept // Cast it to some memref type that we accept
callops.push_back(memRefTypeCast(rewriter, o)); callops.push_back(memRefTypeCast(rewriter, o));
} else if (t.isa<IntegerType>() || t.isa<FloatType>()) { } else if (t.isa<IntegerType>() || t.isa<FloatType>()) {
callops.push_back(o); callops.push_back(o);
} else if (t.isa<ATenListType>()) { } else if (t.isa<ATenListType>()) {
// FIXME: lots of assumptions here. // FIXME: lots of assumptions here.
auto unpack = [](auto &op, auto &v) -> void { auto unpack = [](auto &op, auto &v) -> void {
auto co = cast<mlir::NPCOMP::aten::ConstantOp>(op.getDefiningOp()); auto co = cast<mlir::NPCOMP::aten::ConstantOp>(op.getDefiningOp());
DenseElementsAttr a = DenseElementsAttr a =
co.template getAttrOfType<DenseElementsAttr>("value"); co.template getAttrOfType<DenseElementsAttr>("value");
for (auto i : a.getIntValues()) for (auto i : a.getIntValues())
v.push_back(i.getSExtValue()); v.push_back(i.getSExtValue());
}; };
std::vector<uint64_t> values; std::vector<uint64_t> values;
unpack(o, values); unpack(o, values);
callops.push_back(constInt(values[0], 32)); callops.push_back(constInt(values[0], 32));
} else { } else {
return failure(); return failure();
}
} }
return rewriteWithVoidFunctionCallExplicit(op, callops, operands, rewriter, functionName); }
return rewriteWithVoidFunctionCallExplicit(op, callops, operands, rewriter,
functionName);
} }
/// Lower Add /// Lower Add
template<typename Op> template <typename Op>
class ATenFunctionCallConversion : public ConversionPattern { class ATenFunctionCallConversion : public ConversionPattern {
public: public:
explicit ATenFunctionCallConversion(MLIRContext *context) explicit ATenFunctionCallConversion(MLIRContext *context)
: ConversionPattern(Op::getOperationName(), 1, context) { : ConversionPattern(Op::getOperationName(), 1, context) {}
}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
return rewriteWithFunctionCall(op, operands, rewriter, Op::getFunctionConversionName()); return rewriteWithFunctionCall(op, operands, rewriter,
Op::getFunctionConversionName());
} }
}; };
@ -444,8 +433,8 @@ public:
class ConstantOpConversion : public ConversionPattern { class ConstantOpConversion : public ConversionPattern {
public: public:
explicit ConstantOpConversion(MLIRContext *context) explicit ConstantOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::ConstantOp::getOperationName(), 1, context) { : ConversionPattern(mlir::NPCOMP::aten::ConstantOp::getOperationName(), 1,
} context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -459,14 +448,15 @@ public:
Type t = result.getType(); Type t = result.getType();
if (t.isa<IntegerType>()) { if (t.isa<IntegerType>()) {
auto it = t.cast<IntegerType>(); auto it = t.cast<IntegerType>();
if(it.getWidth() > 1) { if (it.getWidth() > 1) {
auto a = op->getAttrOfType<IntegerAttr>("value"); auto a = op->getAttrOfType<IntegerAttr>("value");
SmallVector<Value, 8> newValues {rewriter.create<mlir::ConstantOp>(loc, a)}; SmallVector<Value, 8> newValues{
rewriter.create<mlir::ConstantOp>(loc, a)};
rewriter.replaceOp(op, newValues); rewriter.replaceOp(op, newValues);
return success(); return success();
} else { } else {
auto a = op->getAttrOfType<BoolAttr>("value"); auto a = op->getAttrOfType<BoolAttr>("value");
SmallVector<Value, 8> newValues {constInt(a.getValue(), it.getWidth())}; SmallVector<Value, 8> newValues{constInt(a.getValue(), it.getWidth())};
rewriter.replaceOp(op, newValues); rewriter.replaceOp(op, newValues);
return success(); return success();
} }
@ -485,8 +475,8 @@ public:
class AddOpConversion : public ConversionPattern { class AddOpConversion : public ConversionPattern {
public: public:
explicit AddOpConversion(MLIRContext *context) explicit AddOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::AddOp::getOperationName(), 1, context) { : ConversionPattern(mlir::NPCOMP::aten::AddOp::getOperationName(), 1,
} context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -513,8 +503,8 @@ public:
class AsStridedOpConversion : public ConversionPattern { class AsStridedOpConversion : public ConversionPattern {
public: public:
explicit AsStridedOpConversion(MLIRContext *context) explicit AsStridedOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::AsStridedOp::getOperationName(), 1, : ConversionPattern(mlir::NPCOMP::aten::AsStridedOp::getOperationName(),
context) {} 1, context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -527,7 +517,8 @@ public:
// construct the shape argument // construct the shape argument
std::vector<Value> shape; std::vector<Value> shape;
std::vector<int64_t> result_shape; std::vector<int64_t> result_shape;
auto co0 = cast<mlir::NPCOMP::aten::ConstantOp>(operands[1].getDefiningOp()); auto co0 =
cast<mlir::NPCOMP::aten::ConstantOp>(operands[1].getDefiningOp());
DenseElementsAttr a0 = DenseElementsAttr a0 =
co0.template getAttrOfType<DenseElementsAttr>("value"); co0.template getAttrOfType<DenseElementsAttr>("value");
for (auto i : a0.getAttributeValues()) for (auto i : a0.getAttributeValues())
@ -539,7 +530,8 @@ public:
// construct the stride argument // construct the stride argument
std::vector<Value> stride; std::vector<Value> stride;
auto co1 = cast<mlir::NPCOMP::aten::ConstantOp>(operands[2].getDefiningOp()); auto co1 =
cast<mlir::NPCOMP::aten::ConstantOp>(operands[2].getDefiningOp());
DenseElementsAttr a1 = DenseElementsAttr a1 =
co1.template getAttrOfType<DenseElementsAttr>("value"); co1.template getAttrOfType<DenseElementsAttr>("value");
for (auto i : a1.getAttributeValues()) for (auto i : a1.getAttributeValues())
@ -551,19 +543,21 @@ public:
APInt offset(32, 0); APInt offset(32, 0);
if (operands.size() > 3) { if (operands.size() > 3) {
auto co2 = cast<mlir::NPCOMP::aten::ConstantOp>(operands[3].getDefiningOp()); auto co2 =
cast<mlir::NPCOMP::aten::ConstantOp>(operands[3].getDefiningOp());
auto ia2 = co2.getAttrOfType<IntegerAttr>("value"); auto ia2 = co2.getAttrOfType<IntegerAttr>("value");
offset = ia2.getValue(); offset = ia2.getValue();
} }
SmallVector<Value, 8> callops{xVal, shape[0], SmallVector<Value, 8> callops{
shape[1], shape[2], xVal, shape[0],
shape[3], stride[0], shape[1], shape[2],
stride[1], stride[2], shape[3], stride[0],
stride[3], constInt(offset.getSExtValue(), 32)}; stride[1], stride[2],
stride[3], constInt(offset.getSExtValue(), 32)};
return rewriteWithVoidFunctionCallExplicit(op, callops, operands, rewriter,
return rewriteWithVoidFunctionCallExplicit(op, callops, operands, rewriter, "as_strided"); "as_strided");
} }
}; };
@ -571,8 +565,8 @@ public:
class BatchNormOpConversion : public ConversionPattern { class BatchNormOpConversion : public ConversionPattern {
public: public:
explicit BatchNormOpConversion(MLIRContext *context) explicit BatchNormOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::BatchNormOp::getOperationName(), 1, : ConversionPattern(mlir::NPCOMP::aten::BatchNormOp::getOperationName(),
context) {} 1, context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -585,8 +579,8 @@ public:
class ConvolutionOpConversion : public ConversionPattern { class ConvolutionOpConversion : public ConversionPattern {
public: public:
explicit ConvolutionOpConversion(MLIRContext *context) explicit ConvolutionOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::ConvolutionOp::getOperationName(), 1, : ConversionPattern(mlir::NPCOMP::aten::ConvolutionOp::getOperationName(),
context) {} 1, context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -614,8 +608,8 @@ public:
class DivOpConversion : public ConversionPattern { class DivOpConversion : public ConversionPattern {
public: public:
explicit DivOpConversion(MLIRContext *context) explicit DivOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::DivOp::getOperationName(), 1, context) { : ConversionPattern(mlir::NPCOMP::aten::DivOp::getOperationName(), 1,
} context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -627,8 +621,8 @@ public:
class LogSoftmaxOpConversion : public ConversionPattern { class LogSoftmaxOpConversion : public ConversionPattern {
public: public:
explicit LogSoftmaxOpConversion(MLIRContext *context) explicit LogSoftmaxOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::LogSoftmaxOp::getOperationName(), 1, : ConversionPattern(mlir::NPCOMP::aten::LogSoftmaxOp::getOperationName(),
context) {} 1, context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -657,8 +651,8 @@ public:
class MaxPoolOpConversion : public ConversionPattern { class MaxPoolOpConversion : public ConversionPattern {
public: public:
explicit MaxPoolOpConversion(MLIRContext *context) explicit MaxPoolOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::MaxPool2dOp::getOperationName(), 1, : ConversionPattern(mlir::NPCOMP::aten::MaxPool2dOp::getOperationName(),
context) {} 1, context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -678,7 +672,8 @@ public:
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
return rewriteWithFunctionCall(op, operands, rewriter, "max_pool2d_with_indices"); return rewriteWithFunctionCall(op, operands, rewriter,
"max_pool2d_with_indices");
} }
}; };
@ -686,14 +681,15 @@ public:
class MaxPool2dWithIndicesBackwardOpConversion : public ConversionPattern { class MaxPool2dWithIndicesBackwardOpConversion : public ConversionPattern {
public: public:
explicit MaxPool2dWithIndicesBackwardOpConversion(MLIRContext *context) explicit MaxPool2dWithIndicesBackwardOpConversion(MLIRContext *context)
: ConversionPattern( : ConversionPattern(mlir::NPCOMP::aten::MaxPool2dWithIndicesBackwardOp::
mlir::NPCOMP::aten::MaxPool2dWithIndicesBackwardOp::getOperationName(), 1, getOperationName(),
context) {} 1, context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
return rewriteWithFunctionCall(op, operands, rewriter, "max_pool2d_with_indices_backward"); return rewriteWithFunctionCall(op, operands, rewriter,
"max_pool2d_with_indices_backward");
} }
}; };
@ -701,7 +697,8 @@ public:
class MMOpConversion : public ConversionPattern { class MMOpConversion : public ConversionPattern {
public: public:
explicit MMOpConversion(MLIRContext *context) explicit MMOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::MmOp::getOperationName(), 1, context) {} : ConversionPattern(mlir::NPCOMP::aten::MmOp::getOperationName(), 1,
context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -714,8 +711,8 @@ public:
class MulOpConversion : public ConversionPattern { class MulOpConversion : public ConversionPattern {
public: public:
explicit MulOpConversion(MLIRContext *context) explicit MulOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::MulOp::getOperationName(), 1, context) { : ConversionPattern(mlir::NPCOMP::aten::MulOp::getOperationName(), 1,
} context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -728,8 +725,9 @@ public:
class NativeBatchNormOpConversion : public ConversionPattern { class NativeBatchNormOpConversion : public ConversionPattern {
public: public:
explicit NativeBatchNormOpConversion(MLIRContext *context) explicit NativeBatchNormOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::NativeBatchNormOp::getOperationName(), : ConversionPattern(
1, context) {} mlir::NPCOMP::aten::NativeBatchNormOp::getOperationName(), 1,
context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -742,13 +740,15 @@ public:
class NllLoss2dBackwardOpConversion : public ConversionPattern { class NllLoss2dBackwardOpConversion : public ConversionPattern {
public: public:
explicit NllLoss2dBackwardOpConversion(MLIRContext *context) explicit NllLoss2dBackwardOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::NllLoss2dBackwardOp::getOperationName(), : ConversionPattern(
1, context) {} mlir::NPCOMP::aten::NllLoss2dBackwardOp::getOperationName(), 1,
context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
return rewriteWithFunctionCall(op, operands, rewriter, "nll_loss2d_backward"); return rewriteWithFunctionCall(op, operands, rewriter,
"nll_loss2d_backward");
} }
}; };
@ -756,13 +756,15 @@ public:
class NllLoss2dForwardOpConversion : public ConversionPattern { class NllLoss2dForwardOpConversion : public ConversionPattern {
public: public:
explicit NllLoss2dForwardOpConversion(MLIRContext *context) explicit NllLoss2dForwardOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::NllLoss2dForwardOp::getOperationName(), : ConversionPattern(
1, context) {} mlir::NPCOMP::aten::NllLoss2dForwardOp::getOperationName(), 1,
context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
return rewriteWithFunctionCall(op, operands, rewriter, "nll_loss2d_forward"); return rewriteWithFunctionCall(op, operands, rewriter,
"nll_loss2d_forward");
} }
}; };
@ -770,8 +772,9 @@ public:
class NllLossBackwardOpConversion : public ConversionPattern { class NllLossBackwardOpConversion : public ConversionPattern {
public: public:
explicit NllLossBackwardOpConversion(MLIRContext *context) explicit NllLossBackwardOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::NllLossBackwardOp::getOperationName(), : ConversionPattern(
1, context) {} mlir::NPCOMP::aten::NllLossBackwardOp::getOperationName(), 1,
context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -784,13 +787,15 @@ public:
class NllLossForwardOpConversion : public ConversionPattern { class NllLossForwardOpConversion : public ConversionPattern {
public: public:
explicit NllLossForwardOpConversion(MLIRContext *context) explicit NllLossForwardOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::NllLossForwardOp::getOperationName(), 1, : ConversionPattern(
context) {} mlir::NPCOMP::aten::NllLossForwardOp::getOperationName(), 1,
context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
return rewriteWithFunctionCall(op, operands, rewriter, "nll_loss_forward"); } return rewriteWithFunctionCall(op, operands, rewriter, "nll_loss_forward");
}
}; };
/// Lower ReLU /// Lower ReLU
@ -811,13 +816,15 @@ public:
class ThresholdBackwardOpConversion : public ConversionPattern { class ThresholdBackwardOpConversion : public ConversionPattern {
public: public:
explicit ThresholdBackwardOpConversion(MLIRContext *context) explicit ThresholdBackwardOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::ThresholdBackwardOp::getOperationName(), : ConversionPattern(
1, context) {} mlir::NPCOMP::aten::ThresholdBackwardOp::getOperationName(), 1,
context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
return rewriteWithFunctionCall(op, operands, rewriter, "threshold_backward"); return rewriteWithFunctionCall(op, operands, rewriter,
"threshold_backward");
} }
}; };
@ -825,7 +832,8 @@ public:
class TransposeOpConversion : public ConversionPattern { class TransposeOpConversion : public ConversionPattern {
public: public:
explicit TransposeOpConversion(MLIRContext *context) explicit TransposeOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::TOp::getOperationName(), 1, context) {} : ConversionPattern(mlir::NPCOMP::aten::TOp::getOperationName(), 1,
context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -849,21 +857,22 @@ public:
Value xVal = memRefTypeCast(rewriter, operands[0]); Value xVal = memRefTypeCast(rewriter, operands[0]);
// construct the shape argument // construct the shape argument
SmallVector<Value, 8> shape; SmallVector<Value, 8> shape;
auto co = dyn_cast<mlir::NPCOMP::aten::ConstantOp>(operands[1].getDefiningOp()); auto co =
dyn_cast<mlir::NPCOMP::aten::ConstantOp>(operands[1].getDefiningOp());
DenseElementsAttr a = co.template getAttrOfType<DenseElementsAttr>("value"); DenseElementsAttr a = co.template getAttrOfType<DenseElementsAttr>("value");
for (auto i : a.getAttributeValues()) for (auto i : a.getAttributeValues())
shape.push_back(rewriter.create<mlir::ConstantOp>(co.getLoc(), i)); shape.push_back(rewriter.create<mlir::ConstantOp>(co.getLoc(), i));
// pad out the shape with -1 to make it 4d // pad out the shape with -1 to make it 4d
while (shape.size() < 4) while (shape.size() < 4)
shape.push_back(constInt(-1, 32)); shape.push_back(constInt(-1, 32));
SmallVector<Value, 8> callops{xVal, shape[0], shape[1], shape[2], shape[3]}; SmallVector<Value, 8> callops{xVal, shape[0], shape[1], shape[2], shape[3]};
return rewriteWithVoidFunctionCallExplicit(op, callops, operands, rewriter, "view"); return rewriteWithVoidFunctionCallExplicit(op, callops, operands, rewriter,
"view");
} }
}; };
@ -896,9 +905,8 @@ struct ATenLoweringPass
// c++ patterns // c++ patterns
acapPatterns.insert< acapPatterns.insert<
ConstantOpConversion, ConstantOpConversion, AddOpConversion, ConvolutionOpConversion,
AddOpConversion, ConvolutionOpConversion, ReLUOpConversion, ReLUOpConversion, TransposeOpConversion, BatchNormOpConversion,
TransposeOpConversion, BatchNormOpConversion,
NativeBatchNormOpConversion, MaxPoolOpConversion, NativeBatchNormOpConversion, MaxPoolOpConversion,
MaxPool2dWithIndicesOpConversion, AddmmOpConversion, ViewOpConversion, MaxPool2dWithIndicesOpConversion, AddmmOpConversion, ViewOpConversion,
MulOpConversion, MMOpConversion, AsStridedOpConversion, MulOpConversion, MMOpConversion, AsStridedOpConversion,

View File

@ -7,8 +7,8 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "npcomp/Dialect/ATen/ATenToStd.h" #include "npcomp/Dialect/ATen/ATenToStd.h"
#include "npcomp/Dialect/ATen/ATenDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "npcomp/Dialect/ATen/ATenDialect.h"
using namespace mlir; using namespace mlir;
using namespace mlir::NPCOMP; using namespace mlir::NPCOMP;

View File

@ -115,8 +115,8 @@ std::string LivenessReport::emitJSONReport() {
for (auto v : vlist) { for (auto v : vlist) {
int64_t vol = getTensorVolume(v.getType()); int64_t vol = getTensorVolume(v.getType());
if (v.getDefiningOp()) { if (v.getDefiningOp()) {
if (auto a = v.getDefiningOp()->getAttrOfType<StringAttr>( if (auto a =
"layer_name")) { v.getDefiningOp()->getAttrOfType<StringAttr>("layer_name")) {
auto definingOp = v.getDefiningOp(); auto definingOp = v.getDefiningOp();
auto ld = layerDetail.getInteger(a.getValue().str()); auto ld = layerDetail.getInteger(a.getValue().str());
if (ld) if (ld)

View File

@ -32,28 +32,29 @@ public:
auto op = opBuilder.create<scf::YieldOp>(loc, yields); auto op = opBuilder.create<scf::YieldOp>(loc, yields);
return op.getOperation(); return op.getOperation();
}) })
.def("scf_if_op", .def(
[](ScfDialectHelper &self, std::vector<PyType> pyResultTypes, "scf_if_op",
PyValue cond, bool withElseRegion) { [](ScfDialectHelper &self, std::vector<PyType> pyResultTypes,
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true); PyValue cond, bool withElseRegion) {
Location loc = self.pyOpBuilder.getCurrentLoc(); OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
llvm::SmallVector<Type, 4> resultTypes(pyResultTypes.begin(), Location loc = self.pyOpBuilder.getCurrentLoc();
pyResultTypes.end()); llvm::SmallVector<Type, 4> resultTypes(pyResultTypes.begin(),
auto op = opBuilder.create<scf::IfOp>(loc, resultTypes, cond, pyResultTypes.end());
withElseRegion); auto op = opBuilder.create<scf::IfOp>(loc, resultTypes, cond,
if (withElseRegion) { withElseRegion);
return py::make_tuple( if (withElseRegion) {
PyOperationRef(op), return py::make_tuple(
op.getThenBodyBuilder().saveInsertionPoint(), PyOperationRef(op),
op.getElseBodyBuilder().saveInsertionPoint()); op.getThenBodyBuilder().saveInsertionPoint(),
} else { op.getElseBodyBuilder().saveInsertionPoint());
return py::make_tuple( } else {
PyOperationRef(op), return py::make_tuple(
op.getThenBodyBuilder().saveInsertionPoint()); PyOperationRef(op),
} op.getThenBodyBuilder().saveInsertionPoint());
}, }
py::arg("result_types"), py::arg("cond"), },
py::arg("with_else_region") = false); py::arg("result_types"), py::arg("cond"),
py::arg("with_else_region") = false);
} }
}; };

View File

@ -65,13 +65,14 @@ void PyIpListWrapper<ListTy, ItemWrapperTy>::bind(py::module m,
"front", "front",
[](ThisTy &self) { return ItemWrapperTy(self.list.front()); }) [](ThisTy &self) { return ItemWrapperTy(self.list.front()); })
.def("__len__", [](ThisTy &self) { return self.list.size(); }) .def("__len__", [](ThisTy &self) { return self.list.size(); })
.def("__iter__", .def(
[](ThisTy &self) { "__iter__",
PyItemIterator begin(self.list.begin()); [](ThisTy &self) {
PyItemIterator end(self.list.end()); PyItemIterator begin(self.list.begin());
return py::make_iterator(begin, end); PyItemIterator end(self.list.end());
}, return py::make_iterator(begin, end);
py::keep_alive<0, 1>()); },
py::keep_alive<0, 1>());
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -194,79 +195,81 @@ void PyDialectHelper::bind(py::module m) {
[](PyDialectHelper &self) -> std::shared_ptr<PyContext> { [](PyDialectHelper &self) -> std::shared_ptr<PyContext> {
return self.context.shared_from_this(); return self.context.shared_from_this();
}) })
.def("op", .def(
[](PyDialectHelper &self, const std::string &opNameStr, "op",
std::vector<PyType> pyResultTypes, [](PyDialectHelper &self, const std::string &opNameStr,
std::vector<PyValue> pyOperands, std::vector<PyType> pyResultTypes, std::vector<PyValue> pyOperands,
llvm::Optional<PyAttribute> attrs) -> PyOperationRef { llvm::Optional<PyAttribute> attrs) -> PyOperationRef {
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(false); OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(false);
Location loc = self.pyOpBuilder.getCurrentLoc(); Location loc = self.pyOpBuilder.getCurrentLoc();
OperationName opName(opNameStr, opBuilder.getContext()); OperationName opName(opNameStr, opBuilder.getContext());
SmallVector<Type, 4> types(pyResultTypes.begin(), SmallVector<Type, 4> types(pyResultTypes.begin(),
pyResultTypes.end()); pyResultTypes.end());
SmallVector<Value, 4> operands(pyOperands.begin(), SmallVector<Value, 4> operands(pyOperands.begin(),
pyOperands.end()); pyOperands.end());
MutableDictionaryAttr attrList; MutableDictionaryAttr attrList;
if (attrs) { if (attrs) {
auto dictAttrs = attrs->attr.dyn_cast<DictionaryAttr>(); auto dictAttrs = attrs->attr.dyn_cast<DictionaryAttr>();
if (!dictAttrs) { if (!dictAttrs) {
throw py::raiseValueError( throw py::raiseValueError(
"Expected `attrs` to be a DictionaryAttr"); "Expected `attrs` to be a DictionaryAttr");
} }
attrList = MutableDictionaryAttr(dictAttrs); attrList = MutableDictionaryAttr(dictAttrs);
} }
Operation *op = Operation *op =
Operation::create(loc, opName, types, operands, attrList); Operation::create(loc, opName, types, operands, attrList);
opBuilder.insert(op); opBuilder.insert(op);
return op; return op;
}, },
py::arg("op_name"), py::arg("result_types"), py::arg("operands"), py::arg("op_name"), py::arg("result_types"), py::arg("operands"),
py::arg("attrs") = llvm::Optional<PyAttribute>()) py::arg("attrs") = llvm::Optional<PyAttribute>())
.def("func_op", .def(
[](PyDialectHelper &self, const std::string &name, PyType type, "func_op",
bool createEntryBlock, llvm::Optional<PyAttribute> attrs) { [](PyDialectHelper &self, const std::string &name, PyType type,
auto functionType = type.type.dyn_cast_or_null<FunctionType>(); bool createEntryBlock, llvm::Optional<PyAttribute> attrs) {
if (!functionType) { auto functionType = type.type.dyn_cast_or_null<FunctionType>();
throw py::raiseValueError("Illegal function type"); if (!functionType) {
} throw py::raiseValueError("Illegal function type");
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true); }
Location loc = self.pyOpBuilder.getCurrentLoc(); OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
// TODO: Dedup attr creation from op(). Location loc = self.pyOpBuilder.getCurrentLoc();
MutableDictionaryAttr attrList; // TODO: Dedup attr creation from op().
if (attrs) { MutableDictionaryAttr attrList;
auto dictAttrs = attrs->attr.dyn_cast<DictionaryAttr>(); if (attrs) {
if (!dictAttrs) { auto dictAttrs = attrs->attr.dyn_cast<DictionaryAttr>();
throw py::raiseValueError( if (!dictAttrs) {
"Expected `attrs` to be a DictionaryAttr"); throw py::raiseValueError(
} "Expected `attrs` to be a DictionaryAttr");
attrList = MutableDictionaryAttr(dictAttrs); }
} attrList = MutableDictionaryAttr(dictAttrs);
FuncOp op = }
opBuilder.create<FuncOp>(loc, StringRef(name), functionType, FuncOp op =
/*attrs=*/attrList.getAttrs()); opBuilder.create<FuncOp>(loc, StringRef(name), functionType,
if (createEntryBlock) { /*attrs=*/attrList.getAttrs());
Block *entryBlock = new Block(); if (createEntryBlock) {
entryBlock->addArguments(functionType.getInputs()); Block *entryBlock = new Block();
op.getBody().push_back(entryBlock); entryBlock->addArguments(functionType.getInputs());
opBuilder.setInsertionPointToStart(entryBlock); op.getBody().push_back(entryBlock);
} opBuilder.setInsertionPointToStart(entryBlock);
return PyOperationRef(op); }
}, return PyOperationRef(op);
py::arg("name"), py::arg("type"), },
py::arg("create_entry_block") = false, py::arg("name"), py::arg("type"),
py::arg("attrs") = llvm::Optional<PyAttribute>(), py::arg("create_entry_block") = false,
R"(Creates a new `func` op, optionally creating an entry block. py::arg("attrs") = llvm::Optional<PyAttribute>(),
R"(Creates a new `func` op, optionally creating an entry block.
If an entry block is created, the builder will be positioned If an entry block is created, the builder will be positioned
to its start.)") to its start.)")
.def("select_op", .def(
[](PyDialectHelper &self, PyValue conditionValue, PyValue trueValue, "select_op",
PyValue falseValue) -> PyOperationRef { [](PyDialectHelper &self, PyValue conditionValue, PyValue trueValue,
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true); PyValue falseValue) -> PyOperationRef {
Location loc = self.pyOpBuilder.getCurrentLoc(); OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
return PyOperationRef(opBuilder.create<SelectOp>( Location loc = self.pyOpBuilder.getCurrentLoc();
loc, conditionValue, trueValue, falseValue)); return PyOperationRef(opBuilder.create<SelectOp>(
}, loc, conditionValue, trueValue, falseValue));
py::arg("condition"), py::arg("true_value"), py::arg("false_value")) },
py::arg("condition"), py::arg("true_value"), py::arg("false_value"))
.def("return_op", .def("return_op",
[](PyDialectHelper &self, std::vector<PyValue> pyOperands) { [](PyDialectHelper &self, std::vector<PyValue> pyOperands) {
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true); OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
@ -288,11 +291,12 @@ void PyDialectHelper::bind(py::module m) {
[](PyDialectHelper &self) -> PyType { [](PyDialectHelper &self) -> PyType {
return IndexType::get(self.getContext()); return IndexType::get(self.getContext());
}) })
.def("integer_type", .def(
[](PyDialectHelper &self, unsigned width) -> PyType { "integer_type",
return IntegerType::get(width, self.getContext()); [](PyDialectHelper &self, unsigned width) -> PyType {
}, return IntegerType::get(width, self.getContext());
py::arg("width") = 32) },
py::arg("width") = 32)
.def_property_readonly("i1_type", .def_property_readonly("i1_type",
[](PyDialectHelper &self) -> PyType { [](PyDialectHelper &self) -> PyType {
return IntegerType::get(1, self.getContext()); return IntegerType::get(1, self.getContext());
@ -319,20 +323,21 @@ void PyDialectHelper::bind(py::module m) {
return FloatType::get(StandardTypes::F64, return FloatType::get(StandardTypes::F64,
self.getContext()); self.getContext());
}) })
.def("tensor_type", .def(
[](PyDialectHelper &self, PyType elementType, "tensor_type",
llvm::Optional<std::vector<int64_t>> shape) -> PyType { [](PyDialectHelper &self, PyType elementType,
if (!elementType.type) { llvm::Optional<std::vector<int64_t>> shape) -> PyType {
throw py::raiseValueError("Null element type"); if (!elementType.type) {
} throw py::raiseValueError("Null element type");
if (shape) { }
return RankedTensorType::get(*shape, elementType.type); if (shape) {
} else { return RankedTensorType::get(*shape, elementType.type);
return UnrankedTensorType::get(elementType.type); } else {
} return UnrankedTensorType::get(elementType.type);
}, }
py::arg("element_type"), },
py::arg("shape") = llvm::Optional<std::vector<int64_t>>()) py::arg("element_type"),
py::arg("shape") = llvm::Optional<std::vector<int64_t>>())
.def("function_type", .def("function_type",
[](PyDialectHelper &self, std::vector<PyType> inputs, [](PyDialectHelper &self, std::vector<PyType> inputs,
std::vector<PyType> results) -> PyType { std::vector<PyType> results) -> PyType {
@ -367,21 +372,24 @@ void defineMlirIrModule(py::module m) {
m.doc() = "Python bindings for constructs in the mlir/IR library"; m.doc() = "Python bindings for constructs in the mlir/IR library";
// Globals. // Globals.
m.def("emit_error", m.def(
[](PyAttribute loc, std::string message) { "emit_error",
emitDiagnostic(DiagnosticSeverity::Error, loc, message); [](PyAttribute loc, std::string message) {
}, emitDiagnostic(DiagnosticSeverity::Error, loc, message);
py::arg("loc"), py::arg("message")); },
m.def("emit_warning", py::arg("loc"), py::arg("message"));
[](PyAttribute loc, std::string message) { m.def(
emitDiagnostic(DiagnosticSeverity::Warning, loc, message); "emit_warning",
}, [](PyAttribute loc, std::string message) {
py::arg("loc"), py::arg("message")); emitDiagnostic(DiagnosticSeverity::Warning, loc, message);
m.def("emit_remark", },
[](PyAttribute loc, std::string message) { py::arg("loc"), py::arg("message"));
emitDiagnostic(DiagnosticSeverity::Remark, loc, message); m.def(
}, "emit_remark",
py::arg("loc"), py::arg("message")); [](PyAttribute loc, std::string message) {
emitDiagnostic(DiagnosticSeverity::Remark, loc, message);
},
py::arg("loc"), py::arg("message"));
// Python only types. // Python only types.
PyDialectHelper::bind(m); PyDialectHelper::bind(m);
@ -426,25 +434,27 @@ void PyContext::bind(py::module m) {
return PyModuleOp(self.shared_from_this(), m); return PyModuleOp(self.shared_from_this(), m);
}) })
.def("parse_asm", &PyContext::parseAsm) .def("parse_asm", &PyContext::parseAsm)
.def("new_builder", .def(
[](PyContext &self) { "new_builder",
// Note: we collapse the Builder and OpBuilder into one because [](PyContext &self) {
// there is little reason to expose the inheritance hierarchy to // Note: we collapse the Builder and OpBuilder into one because
// Python. // there is little reason to expose the inheritance hierarchy to
return PyOpBuilder(self); // Python.
}, return PyOpBuilder(self);
py::keep_alive<0, 1>()) },
py::keep_alive<0, 1>())
.def("identifier", .def("identifier",
[](PyContext &self, std::string s) -> PyIdentifier { [](PyContext &self, std::string s) -> PyIdentifier {
return Identifier::get(s, &self.context); return Identifier::get(s, &self.context);
}) })
.def("file_line_col_loc_attr", .def(
[](PyContext &self, PyIdentifier filename, unsigned line, "file_line_col_loc_attr",
unsigned column) -> PyAttribute { [](PyContext &self, PyIdentifier filename, unsigned line,
return static_cast<LocationAttr>(FileLineColLoc::get( unsigned column) -> PyAttribute {
filename.identifier, line, column, &self.context)); return static_cast<LocationAttr>(FileLineColLoc::get(
}, filename.identifier, line, column, &self.context));
py::arg("filename"), py::arg("line"), py::arg("column")) },
py::arg("filename"), py::arg("line"), py::arg("column"))
// Salient functions from Builder. // Salient functions from Builder.
.def("parse_type", .def("parse_type",
[](PyContext &self, const std::string &asmText) { [](PyContext &self, const std::string &asmText) {
@ -456,14 +466,15 @@ void PyContext::bind(py::module m) {
} }
return PyType(t); return PyType(t);
}) })
.def("integer_attr", .def(
[](PyContext &self, PyType type, int64_t value) -> PyAttribute { "integer_attr",
if (!type.type.isa<IntegerType>()) { [](PyContext &self, PyType type, int64_t value) -> PyAttribute {
throw py::raiseValueError("Expected IntegerType"); if (!type.type.isa<IntegerType>()) {
} throw py::raiseValueError("Expected IntegerType");
return IntegerAttr::get(type.type, value); }
}, return IntegerAttr::get(type.type, value);
py::arg("type"), py::arg("value")) },
py::arg("type"), py::arg("value"))
.def("float_attr", .def("float_attr",
[](PyContext &self, PyType type, double value) -> PyAttribute { [](PyContext &self, PyType type, double value) -> PyAttribute {
if (!type.type.isa<FloatType>()) { if (!type.type.isa<FloatType>()) {
@ -512,20 +523,20 @@ void PyContext::bind(py::module m) {
} }
return ArrayAttr::get(attrs, &self.context); return ArrayAttr::get(attrs, &self.context);
}) })
.def("dense_elements_attr", .def(
[](PyContext &self, py::buffer array) -> PyAttribute { "dense_elements_attr",
// Request a contiguous view. [](PyContext &self, py::buffer array) -> PyAttribute {
int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT; // Request a contiguous view.
Py_buffer *view = new Py_buffer(); int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) { Py_buffer *view = new Py_buffer();
delete view; if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
throw py::error_already_set(); delete view;
} throw py::error_already_set();
py::buffer_info array_info(view); }
return createDenseElementsAttrFromBuffer(&self.context, py::buffer_info array_info(view);
array_info); return createDenseElementsAttrFromBuffer(&self.context, array_info);
}, },
py::arg("array")) py::arg("array"))
.def_property_readonly("unit_attr", [](PyContext &self) -> PyAttribute { .def_property_readonly("unit_attr", [](PyContext &self) -> PyAttribute {
return UnitAttr::get(&self.context); return UnitAttr::get(&self.context);
}); });
@ -905,67 +916,74 @@ void PyBaseOpBuilder::bind(py::module m) {
void PyOpBuilder::bind(py::module m) { void PyOpBuilder::bind(py::module m) {
py::class_<PyOpBuilder, PyBaseOpBuilder>(m, "OpBuilder") py::class_<PyOpBuilder, PyBaseOpBuilder>(m, "OpBuilder")
.def(py::init<PyContext &>(), py::keep_alive<1, 2>()) .def(py::init<PyContext &>(), py::keep_alive<1, 2>())
.def_property("current_loc", .def_property(
[](PyOpBuilder &self) -> PyAttribute { "current_loc",
return static_cast<Attribute>(self.getCurrentLoc()); [](PyOpBuilder &self) -> PyAttribute {
}, return static_cast<Attribute>(self.getCurrentLoc());
[](PyOpBuilder &self, PyAttribute attr) { },
auto loc_attr = [](PyOpBuilder &self, PyAttribute attr) {
attr.attr.dyn_cast_or_null<LocationAttr>(); auto loc_attr = attr.attr.dyn_cast_or_null<LocationAttr>();
if (!loc_attr) { if (!loc_attr) {
throw py::raiseValueError("Expected a LocationAttr"); throw py::raiseValueError("Expected a LocationAttr");
} }
self.setCurrentLoc(Location(loc_attr)); self.setCurrentLoc(Location(loc_attr));
}) })
.def_property("insertion_point", .def_property(
[](PyOpBuilder &self) { "insertion_point",
return self.getBuilder(true).saveInsertionPoint(); [](PyOpBuilder &self) {
}, return self.getBuilder(true).saveInsertionPoint();
[](PyOpBuilder &self, OpBuilder::InsertPoint ip) { },
self.getBuilder(false).restoreInsertionPoint(ip); [](PyOpBuilder &self, OpBuilder::InsertPoint ip) {
}) self.getBuilder(false).restoreInsertionPoint(ip);
.def("set_file_line_col", })
[](PyOpBuilder &self, PyIdentifier filename, unsigned line, .def(
unsigned column) { "set_file_line_col",
Location loc = FileLineColLoc::get(filename.identifier, line, [](PyOpBuilder &self, PyIdentifier filename, unsigned line,
column, self.getContext()); unsigned column) {
self.setCurrentLoc(loc); Location loc = FileLineColLoc::get(filename.identifier, line,
}, column, self.getContext());
py::arg("filename"), py::arg("line"), py::arg("column"), self.setCurrentLoc(loc);
"Shortcut to set a FileLineCol current location") },
py::arg("filename"), py::arg("line"), py::arg("column"),
"Shortcut to set a FileLineCol current location")
.def("clear_insertion_point", .def("clear_insertion_point",
[](PyOpBuilder &self) { self.builder.clearInsertionPoint(); }) [](PyOpBuilder &self) { self.builder.clearInsertionPoint(); })
.def("insert_op_before", .def(
[](PyOpBuilder &self, PyBaseOperation &pyOp) { "insert_op_before",
Operation *op = pyOp.getOperation(); [](PyOpBuilder &self, PyBaseOperation &pyOp) {
self.builder.setInsertionPoint(op); Operation *op = pyOp.getOperation();
}, self.builder.setInsertionPoint(op);
"Sets the insertion point to just before the specified op.") },
.def("insert_op_after", "Sets the insertion point to just before the specified op.")
[](PyOpBuilder &self, PyBaseOperation &pyOp) { .def(
Operation *op = pyOp.getOperation(); "insert_op_after",
self.builder.setInsertionPointAfter(op); [](PyOpBuilder &self, PyBaseOperation &pyOp) {
}, Operation *op = pyOp.getOperation();
"Sets the insertion point to just after the specified op.") self.builder.setInsertionPointAfter(op);
.def("insert_block_start", },
[](PyOpBuilder &self, PyBlockRef block) { "Sets the insertion point to just after the specified op.")
self.builder.setInsertionPointToStart(&block.block); .def(
}, "insert_block_start",
"Sets the insertion point to the start of the block.") [](PyOpBuilder &self, PyBlockRef block) {
.def("insert_block_end", self.builder.setInsertionPointToStart(&block.block);
[](PyOpBuilder &self, PyBlockRef block) { },
self.builder.setInsertionPointToEnd(&block.block); "Sets the insertion point to the start of the block.")
}, .def(
"Sets the insertion point to the end of the block.") "insert_block_end",
.def("insert_before_terminator", [](PyOpBuilder &self, PyBlockRef block) {
[](PyOpBuilder &self, PyBlockRef block) { self.builder.setInsertionPointToEnd(&block.block);
auto *terminator = block.block.getTerminator(); },
if (!terminator) { "Sets the insertion point to the end of the block.")
throw py::raiseValueError("Block has no terminator"); .def(
} "insert_before_terminator",
self.builder.setInsertionPoint(terminator); [](PyOpBuilder &self, PyBlockRef block) {
}, auto *terminator = block.block.getTerminator();
"Sets the insertion point to just before the block terminator."); if (!terminator) {
throw py::raiseValueError("Block has no terminator");
}
self.builder.setInsertionPoint(terminator);
},
"Sets the insertion point to just before the block terminator.");
} }
} // namespace mlir } // namespace mlir

View File

@ -32,13 +32,14 @@ void PyPassManager::bind(py::module m) {
py::class_<PyPassManager>(m, "PassManager") py::class_<PyPassManager>(m, "PassManager")
.def(py::init<std::shared_ptr<PyContext>, bool>(), py::arg("context"), .def(py::init<std::shared_ptr<PyContext>, bool>(), py::arg("context"),
py::arg("verifyModules") = true) py::arg("verifyModules") = true)
.def("enableCrashReproducerGeneration", .def(
[](PyPassManager &self, std::string outputFile, "enableCrashReproducerGeneration",
bool genLocalReproducer) { [](PyPassManager &self, std::string outputFile,
self.passManager.enableCrashReproducerGeneration( bool genLocalReproducer) {
outputFile, genLocalReproducer); self.passManager.enableCrashReproducerGeneration(
}, outputFile, genLocalReproducer);
py::arg("outputFile"), py::arg("genLocalReproducer") = false) },
py::arg("outputFile"), py::arg("genLocalReproducer") = false)
.def("__len__", .def("__len__",
[](PyPassManager &self) { return self.passManager.size(); }) [](PyPassManager &self) { return self.passManager.size(); })
.def("__str__", .def("__str__",

View File

@ -48,18 +48,19 @@ public:
return Basicpy::NoneType::get( return Basicpy::NoneType::get(
self.getContext()); self.getContext());
}) })
.def("basicpy_SlotObject_type", .def(
[](BasicpyDialectHelper &self, std::string className, "basicpy_SlotObject_type",
py::args pySlotTypes) -> PyType { [](BasicpyDialectHelper &self, std::string className,
SmallVector<Type, 4> slotTypes; py::args pySlotTypes) -> PyType {
for (auto pySlotType : pySlotTypes) { SmallVector<Type, 4> slotTypes;
slotTypes.push_back(pySlotType.cast<PyType>()); for (auto pySlotType : pySlotTypes) {
} slotTypes.push_back(pySlotType.cast<PyType>());
auto classNameAttr = }
StringAttr::get(className, self.getContext()); auto classNameAttr =
return Basicpy::SlotObjectType::get(classNameAttr, slotTypes); StringAttr::get(className, self.getContext());
}, return Basicpy::SlotObjectType::get(classNameAttr, slotTypes);
py::arg("className")) },
py::arg("className"))
.def_property_readonly("basicpy_StrType", .def_property_readonly("basicpy_StrType",
[](BasicpyDialectHelper &self) -> PyType { [](BasicpyDialectHelper &self) -> PyType {
return Basicpy::StrType::get( return Basicpy::StrType::get(