Format sources.

pull/35/head
Stella Laurenzo 2020-08-27 14:47:49 -07:00
parent de38caa547
commit fc4f374345
13 changed files with 537 additions and 502 deletions

View File

@ -14,7 +14,7 @@
namespace mlir { namespace mlir {
namespace NPCOMP { namespace NPCOMP {
#include "npcomp/Dialect/ATen/ATenOpInterfaces.h.inc" #include "npcomp/Dialect/ATen/ATenOpInterfaces.h.inc"
} // namespace aten
} // namespace NPCOMP } // namespace NPCOMP
} // namespace mlir
#endif #endif

View File

@ -11,9 +11,9 @@
#include "npcomp/Dialect/ATen/ATenDialect.h" #include "npcomp/Dialect/ATen/ATenDialect.h"
#include "llvm/Support/Debug.h"
#include "mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h" #include "mlir/IR/Types.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "aten-op-stats" #define DEBUG_TYPE "aten-op-stats"
@ -25,7 +25,8 @@ namespace NPCOMP {
namespace aten { namespace aten {
// Return the op statistics for conv2d-like operations. // Return the op statistics for conv2d-like operations.
template <class T> std::map<std::string, uint64_t> getConv2dStatistics(T *o, uint64_t groups) { template <class T>
std::map<std::string, uint64_t> getConv2dStatistics(T *o, uint64_t groups) {
std::map<std::string, uint64_t> toReturn; std::map<std::string, uint64_t> toReturn;
@ -66,11 +67,13 @@ template <class T> std::map<std::string, uint64_t> getConv2dStatistics(T *o, uin
} }
// Return the op statistics for conv2dBackward-like operations. // Return the op statistics for conv2dBackward-like operations.
template<typename T> template <typename T>
std::map<std::string, uint64_t> getConv2dBackwardStatistics(T op, uint64_t groups) { std::map<std::string, uint64_t> getConv2dBackwardStatistics(T op,
uint64_t groups) {
std::map<std::string, uint64_t> toReturn; std::map<std::string, uint64_t> toReturn;
TensorType dx_out_resultTy = op.getResult(0).getType().template cast<TensorType>(); TensorType dx_out_resultTy =
op.getResult(0).getType().template cast<TensorType>();
uint64_t dx_out_volume = getTensorVolume(dx_out_resultTy); uint64_t dx_out_volume = getTensorVolume(dx_out_resultTy);
TensorType weightTy = op.getOperand(2).getType().template cast<TensorType>(); TensorType weightTy = op.getOperand(2).getType().template cast<TensorType>();
@ -119,7 +122,6 @@ std::map<std::string, uint64_t> getConv2dBackwardStatistics(T op, uint64_t group
return toReturn; return toReturn;
} }
// Return a model of the number of bytes needed to represent the operand of // Return a model of the number of bytes needed to represent the operand of
// the given convolution-like operation with the given index. The shape is // the given convolution-like operation with the given index. The shape is
// assumed to be in NCHW order with a simple tiled model of data reuse. TODO: // assumed to be in NCHW order with a simple tiled model of data reuse. TODO:
@ -222,8 +224,7 @@ uint64_t getConv2dResultTransferVolume(T *o, unsigned int idx, bool write) {
} }
// Return the op statistics for matrixmultiply-like operations. // Return the op statistics for matrixmultiply-like operations.
template<typename T> template <typename T> std::map<std::string, uint64_t> getMMOpStatistics(T op) {
std::map<std::string, uint64_t> getMMOpStatistics(T op) {
std::map<std::string, uint64_t> toReturn; std::map<std::string, uint64_t> toReturn;

View File

@ -20,7 +20,7 @@ namespace aten {
// #define GEN_PASS_CLASSES // #define GEN_PASS_CLASSES
// #include "npcomp/Dialect/ATen/ATenPasses.h.inc" // #include "npcomp/Dialect/ATen/ATenPasses.h.inc"
void registerATenPasses(); void registerATenPasses();
} // namespace aten } // namespace aten
} // namespace NPCOMP } // namespace NPCOMP
} // namespace mlir } // namespace mlir

View File

@ -44,14 +44,16 @@ void mlir::npcomp::python::defineBackendIREEModule(py::module m) {
); );
}); });
m.def("build_flow_transform_pass_pipeline", m.def(
"build_flow_transform_pass_pipeline",
[](PyPassManager &pm) { [](PyPassManager &pm) {
mlir::iree_compiler::IREE::Flow::buildFlowTransformPassPipeline( mlir::iree_compiler::IREE::Flow::buildFlowTransformPassPipeline(
pm.passManager); pm.passManager);
}, },
py::arg("pm"), py::arg("pm"),
py::doc("Builds a pass pipeline for top-level Flow import")); py::doc("Builds a pass pipeline for top-level Flow import"));
m.def("build_hal_transform_pass_pipeline", m.def(
"build_hal_transform_pass_pipeline",
[](PyPassManager &pm, std::vector<std::string> targetBackends) { [](PyPassManager &pm, std::vector<std::string> targetBackends) {
mlir::iree_compiler::IREE::HAL::TargetOptions options; mlir::iree_compiler::IREE::HAL::TargetOptions options;
if (targetBackends.empty()) { if (targetBackends.empty()) {
@ -60,12 +62,13 @@ void mlir::npcomp::python::defineBackendIREEModule(py::module m) {
} else { } else {
options.targets = std::move(targetBackends); options.targets = std::move(targetBackends);
} }
iree_compiler::IREE::HAL::buildHALTransformPassPipeline( iree_compiler::IREE::HAL::buildHALTransformPassPipeline(pm.passManager,
pm.passManager, options); options);
}, },
py::arg("pm"), py::arg("target_backends") = std::vector<std::string>(), py::arg("pm"), py::arg("target_backends") = std::vector<std::string>(),
py::doc("Builds a pass pipeline for top-level Flow import")); py::doc("Builds a pass pipeline for top-level Flow import"));
m.def("build_vm_transform_pass_pipeline", m.def(
"build_vm_transform_pass_pipeline",
[](PyPassManager &pm) { [](PyPassManager &pm) {
mlir::iree_compiler::IREE::VM::buildVMTransformPassPipeline( mlir::iree_compiler::IREE::VM::buildVMTransformPassPipeline(
pm.passManager); pm.passManager);

View File

@ -89,19 +89,20 @@ void npcomp::python::defineBackendRefJitModule(py::module m) {
JITModule::buildBackendCompilationPipeline(pm.passManager); JITModule::buildBackendCompilationPipeline(pm.passManager);
}); });
py::class_<JITModule>(m, "JITModule") py::class_<JITModule>(m, "JITModule")
.def_static("from_compiled_module", .def_static(
"from_compiled_module",
[](PyModuleOp module, std::vector<std::string> pySharedLibs) [](PyModuleOp module, std::vector<std::string> pySharedLibs)
-> std::unique_ptr<JITModule> { -> std::unique_ptr<JITModule> {
SmallVector<StringRef, 4> sharedLibs(pySharedLibs.begin(), SmallVector<StringRef, 4> sharedLibs(pySharedLibs.begin(),
pySharedLibs.end()); pySharedLibs.end());
auto jitModule = auto jitModule = checkError(
checkError(JITModule::fromCompiledModule( JITModule::fromCompiledModule(module.moduleOp, sharedLibs),
module.moduleOp, sharedLibs),
"error creating JITModule: "); "error creating JITModule: ");
return jitModule; return jitModule;
}, },
py::arg("module"), py::arg("shared_libs")) py::arg("module"), py::arg("shared_libs"))
.def("invoke", .def(
"invoke",
[](JITModule &self, std::string functionName, [](JITModule &self, std::string functionName,
std::vector<py::buffer> inputs) { std::vector<py::buffer> inputs) {
// Prepare inputs. // Prepare inputs.

View File

@ -55,7 +55,7 @@ std::map<std::string, uint64_t> AdaptiveAvgPool2dBackwardOp::getStatistics() {
return toReturn; return toReturn;
} }
// add // add
std::map<std::string, uint64_t> AddOp::getStatistics() { std::map<std::string, uint64_t> AddOp::getStatistics() {
std::map<std::string, uint64_t> toReturn; std::map<std::string, uint64_t> toReturn;
@ -222,7 +222,8 @@ uint64_t ConvolutionOp::getResultTransferVolume(unsigned int idx, bool write) {
std::map<std::string, uint64_t> ConvolutionBackwardOp::getStatistics() { std::map<std::string, uint64_t> ConvolutionBackwardOp::getStatistics() {
return getConv2dBackwardStatistics(*this, 1); return getConv2dBackwardStatistics(*this, 1);
} }
std::map<std::string, uint64_t> ConvolutionBackwardOverrideableOp::getStatistics() { std::map<std::string, uint64_t>
ConvolutionBackwardOverrideableOp::getStatistics() {
auto co = cast<mlir::NPCOMP::aten::ConstantOp>(groups().getDefiningOp()); auto co = cast<mlir::NPCOMP::aten::ConstantOp>(groups().getDefiningOp());
auto ia = co.template getAttrOfType<IntegerAttr>("value"); auto ia = co.template getAttrOfType<IntegerAttr>("value");
uint64_t groups = ia.getValue().getZExtValue(); uint64_t groups = ia.getValue().getZExtValue();
@ -463,7 +464,7 @@ std::map<std::string, uint64_t> MeanOp::getStatistics() {
// getMMOpStatistics(*this); // getMMOpStatistics(*this);
// } // }
std::map<std::string, uint64_t> MmOp::getStatistics() { std::map<std::string, uint64_t> MmOp::getStatistics() {
return getMMOpStatistics(*this ); return getMMOpStatistics(*this);
} }
// mul // mul

View File

@ -61,7 +61,8 @@ namespace {
static Value typeCast(PatternRewriter &builder, Value val, Type destTy) { static Value typeCast(PatternRewriter &builder, Value val, Type destTy) {
if (val.getType() == destTy) if (val.getType() == destTy)
return val; return val;
return builder.create<mlir::NPCOMP::aten::TypeCastOp>(val.getLoc(), destTy, val) return builder
.create<mlir::NPCOMP::aten::TypeCastOp>(val.getLoc(), destTy, val)
.getResult(); .getResult();
} }
@ -69,11 +70,11 @@ static Value typeCast(PatternRewriter &builder, Value val, Type destTy) {
/// unknown shape. /// unknown shape.
static MemRefType getShapeErasedMemRefType(MemRefType type) { static MemRefType getShapeErasedMemRefType(MemRefType type) {
std::vector<int64_t> shape = type.getShape(); std::vector<int64_t> shape = type.getShape();
for(int i = 0; i < shape.size(); i++) { for (int i = 0; i < shape.size(); i++) {
shape[i] = -1; shape[i] = -1;
} }
return MemRefType::get(shape, type.getElementType(), return MemRefType::get(shape, type.getElementType(), type.getAffineMaps(),
type.getAffineMaps(), type.getMemorySpace()); type.getMemorySpace());
} }
/// Create a type cast to memref /// Create a type cast to memref
@ -82,9 +83,7 @@ static Value memRefTypeCast(PatternRewriter &builder, Value val) {
if (auto memrefTy = type.dyn_cast<MemRefType>()) { if (auto memrefTy = type.dyn_cast<MemRefType>()) {
MemRefType newType = getShapeErasedMemRefType(memrefTy); MemRefType newType = getShapeErasedMemRefType(memrefTy);
return builder.create<MemRefCastOp>(val.getLoc(), return builder.create<MemRefCastOp>(val.getLoc(), val, newType).getResult();
val, newType)
.getResult();
} }
if (auto tensorTy = type.dyn_cast<TensorType>()) { if (auto tensorTy = type.dyn_cast<TensorType>()) {
auto memRefType = mlir::MemRefType::get(tensorTy.getShape(), auto memRefType = mlir::MemRefType::get(tensorTy.getShape(),
@ -186,7 +185,7 @@ static std::string getSimplyMangledFuncName(std::string prefix,
ret = ret + sep + getSimplyMangledType(t); ret = ret + sep + getSimplyMangledType(t);
for (const Type t : operTy) { for (const Type t : operTy) {
std::string s = getSimplyMangledType(t); std::string s = getSimplyMangledType(t);
if(s.size() > 0) if (s.size() > 0)
ret = ret + sep + getSimplyMangledType(t); ret = ret + sep + getSimplyMangledType(t);
} }
ret += "_out"; ret += "_out";
@ -199,20 +198,17 @@ static std::string getSimplyMangledFuncName(std::string prefix,
return getSimplyMangledFuncName(prefix, fnTy.getInputs(), fnTy.getResults()); return getSimplyMangledFuncName(prefix, fnTy.getInputs(), fnTy.getResults());
} }
std::string getMangledFuncName(std::string prefix, std::string getMangledFuncName(std::string prefix, FunctionType fnTy) {
FunctionType fnTy) {
return getSimplyMangledFuncName(prefix, fnTy); return getSimplyMangledFuncName(prefix, fnTy);
} }
std::string getMangledFuncName(std::string prefix, std::string getMangledFuncName(std::string prefix, ArrayRef<Type> opTys,
ArrayRef<Type> opTys,
ArrayRef<Type> retTys) { ArrayRef<Type> retTys) {
return getSimplyMangledFuncName(prefix, opTys, retTys); return getSimplyMangledFuncName(prefix, opTys, retTys);
} }
static FuncOp getATenFn(ModuleOp module, std::string mangledFunctionName, static FuncOp getATenFn(ModuleOp module, std::string mangledFunctionName,
ArrayRef<Value> operands, ArrayRef<Value> operands, ArrayRef<Type> retTys) {
ArrayRef<Type> retTys) {
Builder builder(module); Builder builder(module);
SmallVector<Type, 8> tys; SmallVector<Type, 8> tys;
@ -242,8 +238,8 @@ static FuncOp getATenFn(ModuleOp module, std::string mangledFunctionName,
class AddOpConversion_affine : public ConversionPattern { class AddOpConversion_affine : public ConversionPattern {
public: public:
explicit AddOpConversion_affine(MLIRContext *context) explicit AddOpConversion_affine(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::AddOp::getOperationName(), 1, context) { : ConversionPattern(mlir::NPCOMP::aten::AddOp::getOperationName(), 1,
} context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -310,18 +306,14 @@ public:
} }
}; };
// Replace the given operation with a call to the given function. // Replace the given operation with a call to the given function.
// The function is assumed to accept memrefs and scalar types and return // The function is assumed to accept memrefs and scalar types and return
// Memrefs. Here the result types are converted back to the result types of op, // Memrefs. Here the result types are converted back to the result types of op,
// but operands are NOT converted. This allows non-standard mappings from // but operands are NOT converted. This allows non-standard mappings from
// operand types to function types. // operand types to function types.
LogicalResult LogicalResult rewriteWithVoidFunctionCallExplicit(
rewriteWithVoidFunctionCallExplicit(Operation *op, Operation *op, ArrayRef<Value> callops, ArrayRef<Value> operands,
ArrayRef<Value> callops, ConversionPatternRewriter &rewriter, std::string functionName) {
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter,
std::string functionName) {
auto loc = op->getLoc(); auto loc = op->getLoc();
edsc::ScopedContext scope(rewriter, loc); edsc::ScopedContext scope(rewriter, loc);
@ -330,7 +322,7 @@ rewriteWithVoidFunctionCallExplicit(Operation *op,
SmallVector<Type, 8> opTys; SmallVector<Type, 8> opTys;
// Shape erased versions of the original operation types. // Shape erased versions of the original operation types.
SmallVector<Type, 8> erasedOpTys; SmallVector<Type, 8> erasedOpTys;
for (const Value &o: callops) { for (const Value &o : callops) {
Type t = o.getType(); Type t = o.getType();
opTys.push_back(t); opTys.push_back(t);
if (t.isa<MemRefType>()) if (t.isa<MemRefType>())
@ -347,21 +339,20 @@ rewriteWithVoidFunctionCallExplicit(Operation *op,
// Erased version of the return type. This is the return types of the // Erased version of the return type. This is the return types of the
// generated function call. // generated function call.
SmallVector<Type, 8> erasedRetTys; SmallVector<Type, 8> erasedRetTys;
for (const auto &o: op->getResults()) { for (const auto &o : op->getResults()) {
Type t = o.getType(); Type t = o.getType();
if (t.isa<TensorType>()) { if (t.isa<TensorType>()) {
TensorType tensorResultTy = t.cast<TensorType>(); TensorType tensorResultTy = t.cast<TensorType>();
MemRefType memRefResultTy = MemRefType memRefResultTy = mlir::MemRefType::get(
mlir::MemRefType::get(tensorResultTy.getShape(), tensorResultTy.getShape(), tensorResultTy.getElementType(), {}, 0);
tensorResultTy.getElementType(), {}, 0); MemRefType erasedMemRefResultTy =
MemRefType erasedMemRefResultTy = getShapeErasedMemRefType(memRefResultTy); getShapeErasedMemRefType(memRefResultTy);
retTys.push_back(memRefResultTy); retTys.push_back(memRefResultTy);
// assume memRefResultTy has known shape, so we don't need any // assume memRefResultTy has known shape, so we don't need any
// dynamic dimensions for the alloc. // dynamic dimensions for the alloc.
assert(memRefResultTy.hasStaticShape()); assert(memRefResultTy.hasStaticShape());
Value allocVal = rewriter.create<AllocOp>(op->getLoc(), Value allocVal = rewriter.create<AllocOp>(op->getLoc(), memRefResultTy);
memRefResultTy);
Value castVal = memRefTypeCast(rewriter, allocVal); Value castVal = memRefTypeCast(rewriter, allocVal);
newOps.push_back(castVal); newOps.push_back(castVal);
newResults.push_back(allocVal); newResults.push_back(allocVal);
@ -371,14 +362,13 @@ rewriteWithVoidFunctionCallExplicit(Operation *op,
} }
SmallVector<Type, 8> empty; SmallVector<Type, 8> empty;
std::string mangledFunctionName = getMangledFuncName(functionName, opTys, retTys); std::string mangledFunctionName =
getMangledFuncName(functionName, opTys, retTys);
FuncOp funcOp = getATenFn(op->getParentOfType<ModuleOp>(), FuncOp funcOp = getATenFn(op->getParentOfType<ModuleOp>(),
mangledFunctionName, mangledFunctionName, newOps, empty);
newOps,
empty);
auto new_call = callOperation(empty, auto new_call =
rewriter.getSymbolRefAttr(funcOp), newOps); callOperation(empty, rewriter.getSymbolRefAttr(funcOp), newOps);
rewriter.replaceOp(op, newResults); rewriter.replaceOp(op, newResults);
return success(); return success();
@ -389,8 +379,7 @@ rewriteWithVoidFunctionCallExplicit(Operation *op,
// Memrefs. Other operand types (e.g. aten.list and tensor<> are converted // Memrefs. Other operand types (e.g. aten.list and tensor<> are converted
// appropriately. The called function passes results of the original function // appropriately. The called function passes results of the original function
// as memref arguments at the end of the original set of operands. // as memref arguments at the end of the original set of operands.
LogicalResult LogicalResult rewriteWithFunctionCall(Operation *op, ArrayRef<Value> operands,
rewriteWithFunctionCall(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter, ConversionPatternRewriter &rewriter,
std::string functionName) { std::string functionName) {
auto loc = op->getLoc(); auto loc = op->getLoc();
@ -398,7 +387,7 @@ rewriteWithFunctionCall(Operation *op, ArrayRef<Value> operands,
// Convert the arguments to the original call. // Convert the arguments to the original call.
SmallVector<Value, 8> callops; SmallVector<Value, 8> callops;
for (auto &o: operands) { for (auto &o : operands) {
Type t = o.getType(); Type t = o.getType();
if (t.isa<MemRefType>()) { if (t.isa<MemRefType>()) {
// Cast it to some memref type that we accept // Cast it to some memref type that we accept
@ -421,22 +410,22 @@ rewriteWithFunctionCall(Operation *op, ArrayRef<Value> operands,
return failure(); return failure();
} }
} }
return rewriteWithVoidFunctionCallExplicit(op, callops, operands, rewriter, functionName); return rewriteWithVoidFunctionCallExplicit(op, callops, operands, rewriter,
functionName);
} }
/// Lower Add /// Lower Add
template<typename Op> template <typename Op>
class ATenFunctionCallConversion : public ConversionPattern { class ATenFunctionCallConversion : public ConversionPattern {
public: public:
explicit ATenFunctionCallConversion(MLIRContext *context) explicit ATenFunctionCallConversion(MLIRContext *context)
: ConversionPattern(Op::getOperationName(), 1, context) { : ConversionPattern(Op::getOperationName(), 1, context) {}
}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
return rewriteWithFunctionCall(op, operands, rewriter, Op::getFunctionConversionName()); return rewriteWithFunctionCall(op, operands, rewriter,
Op::getFunctionConversionName());
} }
}; };
@ -444,8 +433,8 @@ public:
class ConstantOpConversion : public ConversionPattern { class ConstantOpConversion : public ConversionPattern {
public: public:
explicit ConstantOpConversion(MLIRContext *context) explicit ConstantOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::ConstantOp::getOperationName(), 1, context) { : ConversionPattern(mlir::NPCOMP::aten::ConstantOp::getOperationName(), 1,
} context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -459,14 +448,15 @@ public:
Type t = result.getType(); Type t = result.getType();
if (t.isa<IntegerType>()) { if (t.isa<IntegerType>()) {
auto it = t.cast<IntegerType>(); auto it = t.cast<IntegerType>();
if(it.getWidth() > 1) { if (it.getWidth() > 1) {
auto a = op->getAttrOfType<IntegerAttr>("value"); auto a = op->getAttrOfType<IntegerAttr>("value");
SmallVector<Value, 8> newValues {rewriter.create<mlir::ConstantOp>(loc, a)}; SmallVector<Value, 8> newValues{
rewriter.create<mlir::ConstantOp>(loc, a)};
rewriter.replaceOp(op, newValues); rewriter.replaceOp(op, newValues);
return success(); return success();
} else { } else {
auto a = op->getAttrOfType<BoolAttr>("value"); auto a = op->getAttrOfType<BoolAttr>("value");
SmallVector<Value, 8> newValues {constInt(a.getValue(), it.getWidth())}; SmallVector<Value, 8> newValues{constInt(a.getValue(), it.getWidth())};
rewriter.replaceOp(op, newValues); rewriter.replaceOp(op, newValues);
return success(); return success();
} }
@ -485,8 +475,8 @@ public:
class AddOpConversion : public ConversionPattern { class AddOpConversion : public ConversionPattern {
public: public:
explicit AddOpConversion(MLIRContext *context) explicit AddOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::AddOp::getOperationName(), 1, context) { : ConversionPattern(mlir::NPCOMP::aten::AddOp::getOperationName(), 1,
} context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -513,8 +503,8 @@ public:
class AsStridedOpConversion : public ConversionPattern { class AsStridedOpConversion : public ConversionPattern {
public: public:
explicit AsStridedOpConversion(MLIRContext *context) explicit AsStridedOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::AsStridedOp::getOperationName(), 1, : ConversionPattern(mlir::NPCOMP::aten::AsStridedOp::getOperationName(),
context) {} 1, context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -527,7 +517,8 @@ public:
// construct the shape argument // construct the shape argument
std::vector<Value> shape; std::vector<Value> shape;
std::vector<int64_t> result_shape; std::vector<int64_t> result_shape;
auto co0 = cast<mlir::NPCOMP::aten::ConstantOp>(operands[1].getDefiningOp()); auto co0 =
cast<mlir::NPCOMP::aten::ConstantOp>(operands[1].getDefiningOp());
DenseElementsAttr a0 = DenseElementsAttr a0 =
co0.template getAttrOfType<DenseElementsAttr>("value"); co0.template getAttrOfType<DenseElementsAttr>("value");
for (auto i : a0.getAttributeValues()) for (auto i : a0.getAttributeValues())
@ -539,7 +530,8 @@ public:
// construct the stride argument // construct the stride argument
std::vector<Value> stride; std::vector<Value> stride;
auto co1 = cast<mlir::NPCOMP::aten::ConstantOp>(operands[2].getDefiningOp()); auto co1 =
cast<mlir::NPCOMP::aten::ConstantOp>(operands[2].getDefiningOp());
DenseElementsAttr a1 = DenseElementsAttr a1 =
co1.template getAttrOfType<DenseElementsAttr>("value"); co1.template getAttrOfType<DenseElementsAttr>("value");
for (auto i : a1.getAttributeValues()) for (auto i : a1.getAttributeValues())
@ -551,19 +543,21 @@ public:
APInt offset(32, 0); APInt offset(32, 0);
if (operands.size() > 3) { if (operands.size() > 3) {
auto co2 = cast<mlir::NPCOMP::aten::ConstantOp>(operands[3].getDefiningOp()); auto co2 =
cast<mlir::NPCOMP::aten::ConstantOp>(operands[3].getDefiningOp());
auto ia2 = co2.getAttrOfType<IntegerAttr>("value"); auto ia2 = co2.getAttrOfType<IntegerAttr>("value");
offset = ia2.getValue(); offset = ia2.getValue();
} }
SmallVector<Value, 8> callops{xVal, shape[0], SmallVector<Value, 8> callops{
xVal, shape[0],
shape[1], shape[2], shape[1], shape[2],
shape[3], stride[0], shape[3], stride[0],
stride[1], stride[2], stride[1], stride[2],
stride[3], constInt(offset.getSExtValue(), 32)}; stride[3], constInt(offset.getSExtValue(), 32)};
return rewriteWithVoidFunctionCallExplicit(op, callops, operands, rewriter,
return rewriteWithVoidFunctionCallExplicit(op, callops, operands, rewriter, "as_strided"); "as_strided");
} }
}; };
@ -571,8 +565,8 @@ public:
class BatchNormOpConversion : public ConversionPattern { class BatchNormOpConversion : public ConversionPattern {
public: public:
explicit BatchNormOpConversion(MLIRContext *context) explicit BatchNormOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::BatchNormOp::getOperationName(), 1, : ConversionPattern(mlir::NPCOMP::aten::BatchNormOp::getOperationName(),
context) {} 1, context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -585,8 +579,8 @@ public:
class ConvolutionOpConversion : public ConversionPattern { class ConvolutionOpConversion : public ConversionPattern {
public: public:
explicit ConvolutionOpConversion(MLIRContext *context) explicit ConvolutionOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::ConvolutionOp::getOperationName(), 1, : ConversionPattern(mlir::NPCOMP::aten::ConvolutionOp::getOperationName(),
context) {} 1, context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -614,8 +608,8 @@ public:
class DivOpConversion : public ConversionPattern { class DivOpConversion : public ConversionPattern {
public: public:
explicit DivOpConversion(MLIRContext *context) explicit DivOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::DivOp::getOperationName(), 1, context) { : ConversionPattern(mlir::NPCOMP::aten::DivOp::getOperationName(), 1,
} context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -627,8 +621,8 @@ public:
class LogSoftmaxOpConversion : public ConversionPattern { class LogSoftmaxOpConversion : public ConversionPattern {
public: public:
explicit LogSoftmaxOpConversion(MLIRContext *context) explicit LogSoftmaxOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::LogSoftmaxOp::getOperationName(), 1, : ConversionPattern(mlir::NPCOMP::aten::LogSoftmaxOp::getOperationName(),
context) {} 1, context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -657,8 +651,8 @@ public:
class MaxPoolOpConversion : public ConversionPattern { class MaxPoolOpConversion : public ConversionPattern {
public: public:
explicit MaxPoolOpConversion(MLIRContext *context) explicit MaxPoolOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::MaxPool2dOp::getOperationName(), 1, : ConversionPattern(mlir::NPCOMP::aten::MaxPool2dOp::getOperationName(),
context) {} 1, context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -678,7 +672,8 @@ public:
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
return rewriteWithFunctionCall(op, operands, rewriter, "max_pool2d_with_indices"); return rewriteWithFunctionCall(op, operands, rewriter,
"max_pool2d_with_indices");
} }
}; };
@ -686,14 +681,15 @@ public:
class MaxPool2dWithIndicesBackwardOpConversion : public ConversionPattern { class MaxPool2dWithIndicesBackwardOpConversion : public ConversionPattern {
public: public:
explicit MaxPool2dWithIndicesBackwardOpConversion(MLIRContext *context) explicit MaxPool2dWithIndicesBackwardOpConversion(MLIRContext *context)
: ConversionPattern( : ConversionPattern(mlir::NPCOMP::aten::MaxPool2dWithIndicesBackwardOp::
mlir::NPCOMP::aten::MaxPool2dWithIndicesBackwardOp::getOperationName(), 1, getOperationName(),
context) {} 1, context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
return rewriteWithFunctionCall(op, operands, rewriter, "max_pool2d_with_indices_backward"); return rewriteWithFunctionCall(op, operands, rewriter,
"max_pool2d_with_indices_backward");
} }
}; };
@ -701,7 +697,8 @@ public:
class MMOpConversion : public ConversionPattern { class MMOpConversion : public ConversionPattern {
public: public:
explicit MMOpConversion(MLIRContext *context) explicit MMOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::MmOp::getOperationName(), 1, context) {} : ConversionPattern(mlir::NPCOMP::aten::MmOp::getOperationName(), 1,
context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -714,8 +711,8 @@ public:
class MulOpConversion : public ConversionPattern { class MulOpConversion : public ConversionPattern {
public: public:
explicit MulOpConversion(MLIRContext *context) explicit MulOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::MulOp::getOperationName(), 1, context) { : ConversionPattern(mlir::NPCOMP::aten::MulOp::getOperationName(), 1,
} context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -728,8 +725,9 @@ public:
class NativeBatchNormOpConversion : public ConversionPattern { class NativeBatchNormOpConversion : public ConversionPattern {
public: public:
explicit NativeBatchNormOpConversion(MLIRContext *context) explicit NativeBatchNormOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::NativeBatchNormOp::getOperationName(), : ConversionPattern(
1, context) {} mlir::NPCOMP::aten::NativeBatchNormOp::getOperationName(), 1,
context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -742,13 +740,15 @@ public:
class NllLoss2dBackwardOpConversion : public ConversionPattern { class NllLoss2dBackwardOpConversion : public ConversionPattern {
public: public:
explicit NllLoss2dBackwardOpConversion(MLIRContext *context) explicit NllLoss2dBackwardOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::NllLoss2dBackwardOp::getOperationName(), : ConversionPattern(
1, context) {} mlir::NPCOMP::aten::NllLoss2dBackwardOp::getOperationName(), 1,
context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
return rewriteWithFunctionCall(op, operands, rewriter, "nll_loss2d_backward"); return rewriteWithFunctionCall(op, operands, rewriter,
"nll_loss2d_backward");
} }
}; };
@ -756,13 +756,15 @@ public:
class NllLoss2dForwardOpConversion : public ConversionPattern { class NllLoss2dForwardOpConversion : public ConversionPattern {
public: public:
explicit NllLoss2dForwardOpConversion(MLIRContext *context) explicit NllLoss2dForwardOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::NllLoss2dForwardOp::getOperationName(), : ConversionPattern(
1, context) {} mlir::NPCOMP::aten::NllLoss2dForwardOp::getOperationName(), 1,
context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
return rewriteWithFunctionCall(op, operands, rewriter, "nll_loss2d_forward"); return rewriteWithFunctionCall(op, operands, rewriter,
"nll_loss2d_forward");
} }
}; };
@ -770,8 +772,9 @@ public:
class NllLossBackwardOpConversion : public ConversionPattern { class NllLossBackwardOpConversion : public ConversionPattern {
public: public:
explicit NllLossBackwardOpConversion(MLIRContext *context) explicit NllLossBackwardOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::NllLossBackwardOp::getOperationName(), : ConversionPattern(
1, context) {} mlir::NPCOMP::aten::NllLossBackwardOp::getOperationName(), 1,
context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -784,13 +787,15 @@ public:
class NllLossForwardOpConversion : public ConversionPattern { class NllLossForwardOpConversion : public ConversionPattern {
public: public:
explicit NllLossForwardOpConversion(MLIRContext *context) explicit NllLossForwardOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::NllLossForwardOp::getOperationName(), 1, : ConversionPattern(
mlir::NPCOMP::aten::NllLossForwardOp::getOperationName(), 1,
context) {} context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
return rewriteWithFunctionCall(op, operands, rewriter, "nll_loss_forward"); } return rewriteWithFunctionCall(op, operands, rewriter, "nll_loss_forward");
}
}; };
/// Lower ReLU /// Lower ReLU
@ -811,13 +816,15 @@ public:
class ThresholdBackwardOpConversion : public ConversionPattern { class ThresholdBackwardOpConversion : public ConversionPattern {
public: public:
explicit ThresholdBackwardOpConversion(MLIRContext *context) explicit ThresholdBackwardOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::ThresholdBackwardOp::getOperationName(), : ConversionPattern(
1, context) {} mlir::NPCOMP::aten::ThresholdBackwardOp::getOperationName(), 1,
context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
return rewriteWithFunctionCall(op, operands, rewriter, "threshold_backward"); return rewriteWithFunctionCall(op, operands, rewriter,
"threshold_backward");
} }
}; };
@ -825,7 +832,8 @@ public:
class TransposeOpConversion : public ConversionPattern { class TransposeOpConversion : public ConversionPattern {
public: public:
explicit TransposeOpConversion(MLIRContext *context) explicit TransposeOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::TOp::getOperationName(), 1, context) {} : ConversionPattern(mlir::NPCOMP::aten::TOp::getOperationName(), 1,
context) {}
LogicalResult LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -851,19 +859,20 @@ public:
// construct the shape argument // construct the shape argument
SmallVector<Value, 8> shape; SmallVector<Value, 8> shape;
auto co = dyn_cast<mlir::NPCOMP::aten::ConstantOp>(operands[1].getDefiningOp()); auto co =
dyn_cast<mlir::NPCOMP::aten::ConstantOp>(operands[1].getDefiningOp());
DenseElementsAttr a = co.template getAttrOfType<DenseElementsAttr>("value"); DenseElementsAttr a = co.template getAttrOfType<DenseElementsAttr>("value");
for (auto i : a.getAttributeValues()) for (auto i : a.getAttributeValues())
shape.push_back(rewriter.create<mlir::ConstantOp>(co.getLoc(), i)); shape.push_back(rewriter.create<mlir::ConstantOp>(co.getLoc(), i));
// pad out the shape with -1 to make it 4d // pad out the shape with -1 to make it 4d
while (shape.size() < 4) while (shape.size() < 4)
shape.push_back(constInt(-1, 32)); shape.push_back(constInt(-1, 32));
SmallVector<Value, 8> callops{xVal, shape[0], shape[1], shape[2], shape[3]}; SmallVector<Value, 8> callops{xVal, shape[0], shape[1], shape[2], shape[3]};
return rewriteWithVoidFunctionCallExplicit(op, callops, operands, rewriter, "view"); return rewriteWithVoidFunctionCallExplicit(op, callops, operands, rewriter,
"view");
} }
}; };
@ -896,9 +905,8 @@ struct ATenLoweringPass
// c++ patterns // c++ patterns
acapPatterns.insert< acapPatterns.insert<
ConstantOpConversion, ConstantOpConversion, AddOpConversion, ConvolutionOpConversion,
AddOpConversion, ConvolutionOpConversion, ReLUOpConversion, ReLUOpConversion, TransposeOpConversion, BatchNormOpConversion,
TransposeOpConversion, BatchNormOpConversion,
NativeBatchNormOpConversion, MaxPoolOpConversion, NativeBatchNormOpConversion, MaxPoolOpConversion,
MaxPool2dWithIndicesOpConversion, AddmmOpConversion, ViewOpConversion, MaxPool2dWithIndicesOpConversion, AddmmOpConversion, ViewOpConversion,
MulOpConversion, MMOpConversion, AsStridedOpConversion, MulOpConversion, MMOpConversion, AsStridedOpConversion,

View File

@ -7,8 +7,8 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "npcomp/Dialect/ATen/ATenToStd.h" #include "npcomp/Dialect/ATen/ATenToStd.h"
#include "npcomp/Dialect/ATen/ATenDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "npcomp/Dialect/ATen/ATenDialect.h"
using namespace mlir; using namespace mlir;
using namespace mlir::NPCOMP; using namespace mlir::NPCOMP;

View File

@ -115,8 +115,8 @@ std::string LivenessReport::emitJSONReport() {
for (auto v : vlist) { for (auto v : vlist) {
int64_t vol = getTensorVolume(v.getType()); int64_t vol = getTensorVolume(v.getType());
if (v.getDefiningOp()) { if (v.getDefiningOp()) {
if (auto a = v.getDefiningOp()->getAttrOfType<StringAttr>( if (auto a =
"layer_name")) { v.getDefiningOp()->getAttrOfType<StringAttr>("layer_name")) {
auto definingOp = v.getDefiningOp(); auto definingOp = v.getDefiningOp();
auto ld = layerDetail.getInteger(a.getValue().str()); auto ld = layerDetail.getInteger(a.getValue().str());
if (ld) if (ld)

View File

@ -32,7 +32,8 @@ public:
auto op = opBuilder.create<scf::YieldOp>(loc, yields); auto op = opBuilder.create<scf::YieldOp>(loc, yields);
return op.getOperation(); return op.getOperation();
}) })
.def("scf_if_op", .def(
"scf_if_op",
[](ScfDialectHelper &self, std::vector<PyType> pyResultTypes, [](ScfDialectHelper &self, std::vector<PyType> pyResultTypes,
PyValue cond, bool withElseRegion) { PyValue cond, bool withElseRegion) {
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true); OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);

View File

@ -65,7 +65,8 @@ void PyIpListWrapper<ListTy, ItemWrapperTy>::bind(py::module m,
"front", "front",
[](ThisTy &self) { return ItemWrapperTy(self.list.front()); }) [](ThisTy &self) { return ItemWrapperTy(self.list.front()); })
.def("__len__", [](ThisTy &self) { return self.list.size(); }) .def("__len__", [](ThisTy &self) { return self.list.size(); })
.def("__iter__", .def(
"__iter__",
[](ThisTy &self) { [](ThisTy &self) {
PyItemIterator begin(self.list.begin()); PyItemIterator begin(self.list.begin());
PyItemIterator end(self.list.end()); PyItemIterator end(self.list.end());
@ -194,10 +195,10 @@ void PyDialectHelper::bind(py::module m) {
[](PyDialectHelper &self) -> std::shared_ptr<PyContext> { [](PyDialectHelper &self) -> std::shared_ptr<PyContext> {
return self.context.shared_from_this(); return self.context.shared_from_this();
}) })
.def("op", .def(
"op",
[](PyDialectHelper &self, const std::string &opNameStr, [](PyDialectHelper &self, const std::string &opNameStr,
std::vector<PyType> pyResultTypes, std::vector<PyType> pyResultTypes, std::vector<PyValue> pyOperands,
std::vector<PyValue> pyOperands,
llvm::Optional<PyAttribute> attrs) -> PyOperationRef { llvm::Optional<PyAttribute> attrs) -> PyOperationRef {
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(false); OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(false);
Location loc = self.pyOpBuilder.getCurrentLoc(); Location loc = self.pyOpBuilder.getCurrentLoc();
@ -222,7 +223,8 @@ void PyDialectHelper::bind(py::module m) {
}, },
py::arg("op_name"), py::arg("result_types"), py::arg("operands"), py::arg("op_name"), py::arg("result_types"), py::arg("operands"),
py::arg("attrs") = llvm::Optional<PyAttribute>()) py::arg("attrs") = llvm::Optional<PyAttribute>())
.def("func_op", .def(
"func_op",
[](PyDialectHelper &self, const std::string &name, PyType type, [](PyDialectHelper &self, const std::string &name, PyType type,
bool createEntryBlock, llvm::Optional<PyAttribute> attrs) { bool createEntryBlock, llvm::Optional<PyAttribute> attrs) {
auto functionType = type.type.dyn_cast_or_null<FunctionType>(); auto functionType = type.type.dyn_cast_or_null<FunctionType>();
@ -258,7 +260,8 @@ void PyDialectHelper::bind(py::module m) {
R"(Creates a new `func` op, optionally creating an entry block. R"(Creates a new `func` op, optionally creating an entry block.
If an entry block is created, the builder will be positioned If an entry block is created, the builder will be positioned
to its start.)") to its start.)")
.def("select_op", .def(
"select_op",
[](PyDialectHelper &self, PyValue conditionValue, PyValue trueValue, [](PyDialectHelper &self, PyValue conditionValue, PyValue trueValue,
PyValue falseValue) -> PyOperationRef { PyValue falseValue) -> PyOperationRef {
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true); OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
@ -288,7 +291,8 @@ void PyDialectHelper::bind(py::module m) {
[](PyDialectHelper &self) -> PyType { [](PyDialectHelper &self) -> PyType {
return IndexType::get(self.getContext()); return IndexType::get(self.getContext());
}) })
.def("integer_type", .def(
"integer_type",
[](PyDialectHelper &self, unsigned width) -> PyType { [](PyDialectHelper &self, unsigned width) -> PyType {
return IntegerType::get(width, self.getContext()); return IntegerType::get(width, self.getContext());
}, },
@ -319,7 +323,8 @@ void PyDialectHelper::bind(py::module m) {
return FloatType::get(StandardTypes::F64, return FloatType::get(StandardTypes::F64,
self.getContext()); self.getContext());
}) })
.def("tensor_type", .def(
"tensor_type",
[](PyDialectHelper &self, PyType elementType, [](PyDialectHelper &self, PyType elementType,
llvm::Optional<std::vector<int64_t>> shape) -> PyType { llvm::Optional<std::vector<int64_t>> shape) -> PyType {
if (!elementType.type) { if (!elementType.type) {
@ -367,17 +372,20 @@ void defineMlirIrModule(py::module m) {
m.doc() = "Python bindings for constructs in the mlir/IR library"; m.doc() = "Python bindings for constructs in the mlir/IR library";
// Globals. // Globals.
m.def("emit_error", m.def(
"emit_error",
[](PyAttribute loc, std::string message) { [](PyAttribute loc, std::string message) {
emitDiagnostic(DiagnosticSeverity::Error, loc, message); emitDiagnostic(DiagnosticSeverity::Error, loc, message);
}, },
py::arg("loc"), py::arg("message")); py::arg("loc"), py::arg("message"));
m.def("emit_warning", m.def(
"emit_warning",
[](PyAttribute loc, std::string message) { [](PyAttribute loc, std::string message) {
emitDiagnostic(DiagnosticSeverity::Warning, loc, message); emitDiagnostic(DiagnosticSeverity::Warning, loc, message);
}, },
py::arg("loc"), py::arg("message")); py::arg("loc"), py::arg("message"));
m.def("emit_remark", m.def(
"emit_remark",
[](PyAttribute loc, std::string message) { [](PyAttribute loc, std::string message) {
emitDiagnostic(DiagnosticSeverity::Remark, loc, message); emitDiagnostic(DiagnosticSeverity::Remark, loc, message);
}, },
@ -426,7 +434,8 @@ void PyContext::bind(py::module m) {
return PyModuleOp(self.shared_from_this(), m); return PyModuleOp(self.shared_from_this(), m);
}) })
.def("parse_asm", &PyContext::parseAsm) .def("parse_asm", &PyContext::parseAsm)
.def("new_builder", .def(
"new_builder",
[](PyContext &self) { [](PyContext &self) {
// Note: we collapse the Builder and OpBuilder into one because // Note: we collapse the Builder and OpBuilder into one because
// there is little reason to expose the inheritance hierarchy to // there is little reason to expose the inheritance hierarchy to
@ -438,7 +447,8 @@ void PyContext::bind(py::module m) {
[](PyContext &self, std::string s) -> PyIdentifier { [](PyContext &self, std::string s) -> PyIdentifier {
return Identifier::get(s, &self.context); return Identifier::get(s, &self.context);
}) })
.def("file_line_col_loc_attr", .def(
"file_line_col_loc_attr",
[](PyContext &self, PyIdentifier filename, unsigned line, [](PyContext &self, PyIdentifier filename, unsigned line,
unsigned column) -> PyAttribute { unsigned column) -> PyAttribute {
return static_cast<LocationAttr>(FileLineColLoc::get( return static_cast<LocationAttr>(FileLineColLoc::get(
@ -456,7 +466,8 @@ void PyContext::bind(py::module m) {
} }
return PyType(t); return PyType(t);
}) })
.def("integer_attr", .def(
"integer_attr",
[](PyContext &self, PyType type, int64_t value) -> PyAttribute { [](PyContext &self, PyType type, int64_t value) -> PyAttribute {
if (!type.type.isa<IntegerType>()) { if (!type.type.isa<IntegerType>()) {
throw py::raiseValueError("Expected IntegerType"); throw py::raiseValueError("Expected IntegerType");
@ -512,7 +523,8 @@ void PyContext::bind(py::module m) {
} }
return ArrayAttr::get(attrs, &self.context); return ArrayAttr::get(attrs, &self.context);
}) })
.def("dense_elements_attr", .def(
"dense_elements_attr",
[](PyContext &self, py::buffer array) -> PyAttribute { [](PyContext &self, py::buffer array) -> PyAttribute {
// Request a contiguous view. // Request a contiguous view.
int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT; int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
@ -522,8 +534,7 @@ void PyContext::bind(py::module m) {
throw py::error_already_set(); throw py::error_already_set();
} }
py::buffer_info array_info(view); py::buffer_info array_info(view);
return createDenseElementsAttrFromBuffer(&self.context, return createDenseElementsAttrFromBuffer(&self.context, array_info);
array_info);
}, },
py::arg("array")) py::arg("array"))
.def_property_readonly("unit_attr", [](PyContext &self) -> PyAttribute { .def_property_readonly("unit_attr", [](PyContext &self) -> PyAttribute {
@ -905,26 +916,28 @@ void PyBaseOpBuilder::bind(py::module m) {
void PyOpBuilder::bind(py::module m) { void PyOpBuilder::bind(py::module m) {
py::class_<PyOpBuilder, PyBaseOpBuilder>(m, "OpBuilder") py::class_<PyOpBuilder, PyBaseOpBuilder>(m, "OpBuilder")
.def(py::init<PyContext &>(), py::keep_alive<1, 2>()) .def(py::init<PyContext &>(), py::keep_alive<1, 2>())
.def_property("current_loc", .def_property(
"current_loc",
[](PyOpBuilder &self) -> PyAttribute { [](PyOpBuilder &self) -> PyAttribute {
return static_cast<Attribute>(self.getCurrentLoc()); return static_cast<Attribute>(self.getCurrentLoc());
}, },
[](PyOpBuilder &self, PyAttribute attr) { [](PyOpBuilder &self, PyAttribute attr) {
auto loc_attr = auto loc_attr = attr.attr.dyn_cast_or_null<LocationAttr>();
attr.attr.dyn_cast_or_null<LocationAttr>();
if (!loc_attr) { if (!loc_attr) {
throw py::raiseValueError("Expected a LocationAttr"); throw py::raiseValueError("Expected a LocationAttr");
} }
self.setCurrentLoc(Location(loc_attr)); self.setCurrentLoc(Location(loc_attr));
}) })
.def_property("insertion_point", .def_property(
"insertion_point",
[](PyOpBuilder &self) { [](PyOpBuilder &self) {
return self.getBuilder(true).saveInsertionPoint(); return self.getBuilder(true).saveInsertionPoint();
}, },
[](PyOpBuilder &self, OpBuilder::InsertPoint ip) { [](PyOpBuilder &self, OpBuilder::InsertPoint ip) {
self.getBuilder(false).restoreInsertionPoint(ip); self.getBuilder(false).restoreInsertionPoint(ip);
}) })
.def("set_file_line_col", .def(
"set_file_line_col",
[](PyOpBuilder &self, PyIdentifier filename, unsigned line, [](PyOpBuilder &self, PyIdentifier filename, unsigned line,
unsigned column) { unsigned column) {
Location loc = FileLineColLoc::get(filename.identifier, line, Location loc = FileLineColLoc::get(filename.identifier, line,
@ -935,29 +948,34 @@ void PyOpBuilder::bind(py::module m) {
"Shortcut to set a FileLineCol current location") "Shortcut to set a FileLineCol current location")
.def("clear_insertion_point", .def("clear_insertion_point",
[](PyOpBuilder &self) { self.builder.clearInsertionPoint(); }) [](PyOpBuilder &self) { self.builder.clearInsertionPoint(); })
.def("insert_op_before", .def(
"insert_op_before",
[](PyOpBuilder &self, PyBaseOperation &pyOp) { [](PyOpBuilder &self, PyBaseOperation &pyOp) {
Operation *op = pyOp.getOperation(); Operation *op = pyOp.getOperation();
self.builder.setInsertionPoint(op); self.builder.setInsertionPoint(op);
}, },
"Sets the insertion point to just before the specified op.") "Sets the insertion point to just before the specified op.")
.def("insert_op_after", .def(
"insert_op_after",
[](PyOpBuilder &self, PyBaseOperation &pyOp) { [](PyOpBuilder &self, PyBaseOperation &pyOp) {
Operation *op = pyOp.getOperation(); Operation *op = pyOp.getOperation();
self.builder.setInsertionPointAfter(op); self.builder.setInsertionPointAfter(op);
}, },
"Sets the insertion point to just after the specified op.") "Sets the insertion point to just after the specified op.")
.def("insert_block_start", .def(
"insert_block_start",
[](PyOpBuilder &self, PyBlockRef block) { [](PyOpBuilder &self, PyBlockRef block) {
self.builder.setInsertionPointToStart(&block.block); self.builder.setInsertionPointToStart(&block.block);
}, },
"Sets the insertion point to the start of the block.") "Sets the insertion point to the start of the block.")
.def("insert_block_end", .def(
"insert_block_end",
[](PyOpBuilder &self, PyBlockRef block) { [](PyOpBuilder &self, PyBlockRef block) {
self.builder.setInsertionPointToEnd(&block.block); self.builder.setInsertionPointToEnd(&block.block);
}, },
"Sets the insertion point to the end of the block.") "Sets the insertion point to the end of the block.")
.def("insert_before_terminator", .def(
"insert_before_terminator",
[](PyOpBuilder &self, PyBlockRef block) { [](PyOpBuilder &self, PyBlockRef block) {
auto *terminator = block.block.getTerminator(); auto *terminator = block.block.getTerminator();
if (!terminator) { if (!terminator) {

View File

@ -32,7 +32,8 @@ void PyPassManager::bind(py::module m) {
py::class_<PyPassManager>(m, "PassManager") py::class_<PyPassManager>(m, "PassManager")
.def(py::init<std::shared_ptr<PyContext>, bool>(), py::arg("context"), .def(py::init<std::shared_ptr<PyContext>, bool>(), py::arg("context"),
py::arg("verifyModules") = true) py::arg("verifyModules") = true)
.def("enableCrashReproducerGeneration", .def(
"enableCrashReproducerGeneration",
[](PyPassManager &self, std::string outputFile, [](PyPassManager &self, std::string outputFile,
bool genLocalReproducer) { bool genLocalReproducer) {
self.passManager.enableCrashReproducerGeneration( self.passManager.enableCrashReproducerGeneration(

View File

@ -48,7 +48,8 @@ public:
return Basicpy::NoneType::get( return Basicpy::NoneType::get(
self.getContext()); self.getContext());
}) })
.def("basicpy_SlotObject_type", .def(
"basicpy_SlotObject_type",
[](BasicpyDialectHelper &self, std::string className, [](BasicpyDialectHelper &self, std::string className,
py::args pySlotTypes) -> PyType { py::args pySlotTypes) -> PyType {
SmallVector<Type, 4> slotTypes; SmallVector<Type, 4> slotTypes;