Format sources.

pull/35/head
Stella Laurenzo 2020-08-27 14:47:49 -07:00
parent de38caa547
commit fc4f374345
13 changed files with 537 additions and 502 deletions
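The hunks below are almost entirely mechanical re-wrapping to the LLVM coding style's 80-column limit, presumably applied with clang-format. Two patterns recur; here is a minimal illustrative sketch with a hypothetical function name (not taken from the diff):

// Before: the template head and parameter list run past 80 columns.
template <class T> std::map<std::string, uint64_t> computeStats(T *op, uint64_t groups);

// After: the template head gets its own line, and arguments that still
// overflow wrap onto a continuation line aligned to the open parenthesis.
template <class T>
std::map<std::string, uint64_t> computeStats(T *op, uint64_t groups);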


@ -14,7 +14,7 @@
namespace mlir {
namespace NPCOMP {
#include "npcomp/Dialect/ATen/ATenOpInterfaces.h.inc"
} // namespace aten
} // namespace NPCOMP
} // namespace mlir
#endif


@ -11,9 +11,9 @@
#include "npcomp/Dialect/ATen/ATenDialect.h"
#include "llvm/Support/Debug.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "aten-op-stats"
@ -25,7 +25,8 @@ namespace NPCOMP {
namespace aten {
// Return the op statistics for conv2d-like operations.
template <class T> std::map<std::string, uint64_t> getConv2dStatistics(T *o, uint64_t groups) {
template <class T>
std::map<std::string, uint64_t> getConv2dStatistics(T *o, uint64_t groups) {
std::map<std::string, uint64_t> toReturn;
@ -66,11 +67,13 @@ template <class T> std::map<std::string, uint64_t> getConv2dStatistics(T *o, uin
}
// Return the op statistics for conv2dBackward-like operations.
template<typename T>
std::map<std::string, uint64_t> getConv2dBackwardStatistics(T op, uint64_t groups) {
template <typename T>
std::map<std::string, uint64_t> getConv2dBackwardStatistics(T op,
uint64_t groups) {
std::map<std::string, uint64_t> toReturn;
TensorType dx_out_resultTy = op.getResult(0).getType().template cast<TensorType>();
TensorType dx_out_resultTy =
op.getResult(0).getType().template cast<TensorType>();
uint64_t dx_out_volume = getTensorVolume(dx_out_resultTy);
TensorType weightTy = op.getOperand(2).getType().template cast<TensorType>();
@ -79,7 +82,7 @@ std::map<std::string, uint64_t> getConv2dBackwardStatistics(T op, uint64_t group
uint64_t kernel_width = weightTy.getShape()[2];
uint64_t kernel_height = weightTy.getShape()[3];
uint64_t MACs_per_loss =
uint64_t MACs_per_loss =
(loss_in_depth / groups) * kernel_height * kernel_width;
uint64_t total_MACs = dx_out_volume * MACs_per_loss;
@ -119,7 +122,6 @@ std::map<std::string, uint64_t> getConv2dBackwardStatistics(T op, uint64_t group
return toReturn;
}
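The MAC model these helpers implement is simple enough to state as a standalone function. An illustrative sketch using the quantities visible above (in the real code the dimension values are read from the op's tensor types):

#include <cstdint>

// One multiply-accumulate per dx element, per input channel in its group,
// per kernel tap; i.e. total_MACs = dx_out_volume * MACs_per_loss above.
uint64_t conv2dBackwardMACs(uint64_t dxOutVolume, uint64_t lossInDepth,
                            uint64_t kernelH, uint64_t kernelW,
                            uint64_t groups) {
  uint64_t macsPerLoss = (lossInDepth / groups) * kernelH * kernelW;
  return dxOutVolume * macsPerLoss;
}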
// Return a model of the number of bytes needed to represent the operand of
// the given convolution-like operation with the given index. The shape is
// assumed to be in NCHW order with a simple tiled model of data reuse. TODO:
@ -222,8 +224,7 @@ uint64_t getConv2dResultTransferVolume(T *o, unsigned int idx, bool write) {
}
// Return the op statistics for matrixmultiply-like operations.
template<typename T>
std::map<std::string, uint64_t> getMMOpStatistics(T op) {
template <typename T> std::map<std::string, uint64_t> getMMOpStatistics(T op) {
std::map<std::string, uint64_t> toReturn;
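The hunk cuts off before the body, so the counting logic of getMMOpStatistics is not shown here. For an MxK by KxN matrix multiply the usual model is one MAC per output element per reduction step; a hypothetical sketch (not the elided body):

#include <cstdint>

uint64_t matmulMACs(uint64_t m, uint64_t n, uint64_t k) {
  return m * n * k; // each of the m*n outputs accumulates k products
}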


@ -20,7 +20,7 @@ namespace aten {
// #define GEN_PASS_CLASSES
// #include "npcomp/Dialect/ATen/ATenPasses.h.inc"
void registerATenPasses();
void registerATenPasses();
} // namespace aten
} // namespace NPCOMP
} // namespace mlir


@ -44,33 +44,36 @@ void mlir::npcomp::python::defineBackendIREEModule(py::module m) {
);
});
m.def("build_flow_transform_pass_pipeline",
[](PyPassManager &pm) {
mlir::iree_compiler::IREE::Flow::buildFlowTransformPassPipeline(
pm.passManager);
},
py::arg("pm"),
py::doc("Builds a pass pipeline for top-level Flow import"));
m.def("build_hal_transform_pass_pipeline",
[](PyPassManager &pm, std::vector<std::string> targetBackends) {
mlir::iree_compiler::IREE::HAL::TargetOptions options;
if (targetBackends.empty()) {
options.targets =
mlir::iree_compiler::IREE::HAL::getRegisteredTargetBackends();
} else {
options.targets = std::move(targetBackends);
}
iree_compiler::IREE::HAL::buildHALTransformPassPipeline(
pm.passManager, options);
},
py::arg("pm"), py::arg("target_backends") = std::vector<std::string>(),
py::doc("Builds a pass pipeline for top-level Flow import"));
m.def("build_vm_transform_pass_pipeline",
[](PyPassManager &pm) {
mlir::iree_compiler::IREE::VM::buildVMTransformPassPipeline(
pm.passManager);
},
py::arg("pm"), py::doc("Builds the VM transformation pipeline"));
m.def(
"build_flow_transform_pass_pipeline",
[](PyPassManager &pm) {
mlir::iree_compiler::IREE::Flow::buildFlowTransformPassPipeline(
pm.passManager);
},
py::arg("pm"),
py::doc("Builds a pass pipeline for top-level Flow import"));
m.def(
"build_hal_transform_pass_pipeline",
[](PyPassManager &pm, std::vector<std::string> targetBackends) {
mlir::iree_compiler::IREE::HAL::TargetOptions options;
if (targetBackends.empty()) {
options.targets =
mlir::iree_compiler::IREE::HAL::getRegisteredTargetBackends();
} else {
options.targets = std::move(targetBackends);
}
iree_compiler::IREE::HAL::buildHALTransformPassPipeline(pm.passManager,
options);
},
py::arg("pm"), py::arg("target_backends") = std::vector<std::string>(),
py::doc("Builds a pass pipeline for top-level Flow import"));
m.def(
"build_vm_transform_pass_pipeline",
[](PyPassManager &pm) {
mlir::iree_compiler::IREE::VM::buildVMTransformPassPipeline(
pm.passManager);
},
py::arg("pm"), py::doc("Builds the VM transformation pipeline"));
m.def("translate_to_vm_bytecode", [](PyModuleOp &module) {
// TODO: Make the options parameterizable.
mlir::iree_compiler::IREE::VM::BytecodeTargetOptions options;
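The common thread in this file is clang-format's treatment of m.def calls whose arguments include a lambda: rather than aligning everything to the opening parenthesis, it breaks immediately after m.def( and indents each argument by four spaces. A self-contained, hypothetical binding written in the post-format style:

#include <pybind11/pybind11.h>
namespace py = pybind11;

PYBIND11_MODULE(example, m) {
  m.def(
      "scale",
      [](int value, int factor) { return value * factor; },
      py::arg("value"), py::arg("factor") = 2,
      py::doc("Multiplies value by factor."));
}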


@ -89,38 +89,39 @@ void npcomp::python::defineBackendRefJitModule(py::module m) {
JITModule::buildBackendCompilationPipeline(pm.passManager);
});
py::class_<JITModule>(m, "JITModule")
.def_static("from_compiled_module",
[](PyModuleOp module, std::vector<std::string> pySharedLibs)
-> std::unique_ptr<JITModule> {
SmallVector<StringRef, 4> sharedLibs(pySharedLibs.begin(),
pySharedLibs.end());
auto jitModule =
checkError(JITModule::fromCompiledModule(
module.moduleOp, sharedLibs),
"error creating JITModule: ");
return jitModule;
},
py::arg("module"), py::arg("shared_libs"))
.def("invoke",
[](JITModule &self, std::string functionName,
std::vector<py::buffer> inputs) {
// Prepare inputs.
llvm::SmallVector<Ref<Tensor>, 4> inputTensors;
inputTensors.reserve(inputs.size());
for (py::buffer &inputBuffer : inputs) {
inputTensors.push_back(copyBufferToTensor(inputBuffer));
}
.def_static(
"from_compiled_module",
[](PyModuleOp module, std::vector<std::string> pySharedLibs)
-> std::unique_ptr<JITModule> {
SmallVector<StringRef, 4> sharedLibs(pySharedLibs.begin(),
pySharedLibs.end());
auto jitModule = checkError(
JITModule::fromCompiledModule(module.moduleOp, sharedLibs),
"error creating JITModule: ");
return jitModule;
},
py::arg("module"), py::arg("shared_libs"))
.def(
"invoke",
[](JITModule &self, std::string functionName,
std::vector<py::buffer> inputs) {
// Prepare inputs.
llvm::SmallVector<Ref<Tensor>, 4> inputTensors;
inputTensors.reserve(inputs.size());
for (py::buffer &inputBuffer : inputs) {
inputTensors.push_back(copyBufferToTensor(inputBuffer));
}
auto outputs = checkError(self.invoke(functionName, inputTensors),
"error invoking JIT function: ");
std::vector<py::array> outputArrays;
outputArrays.reserve(outputs.size());
for (Ref<Tensor> &outputTensor : outputs) {
outputArrays.push_back(wrapTensorAsArray(outputTensor));
}
return outputArrays;
},
py::arg("function_name"), py::arg("inputs"));
auto outputs = checkError(self.invoke(functionName, inputTensors),
"error invoking JIT function: ");
std::vector<py::array> outputArrays;
outputArrays.reserve(outputs.size());
for (Ref<Tensor> &outputTensor : outputs) {
outputArrays.push_back(wrapTensorAsArray(outputTensor));
}
return outputArrays;
},
py::arg("function_name"), py::arg("inputs"));
// A Ref<Tensor> needs to be bound because we use it as a base for the
// ndarray (the array retains a reference to it). Users should not encounter


@ -55,7 +55,7 @@ std::map<std::string, uint64_t> AdaptiveAvgPool2dBackwardOp::getStatistics() {
return toReturn;
}
// add
// add
std::map<std::string, uint64_t> AddOp::getStatistics() {
std::map<std::string, uint64_t> toReturn;
@ -222,7 +222,8 @@ uint64_t ConvolutionOp::getResultTransferVolume(unsigned int idx, bool write) {
std::map<std::string, uint64_t> ConvolutionBackwardOp::getStatistics() {
return getConv2dBackwardStatistics(*this, 1);
}
std::map<std::string, uint64_t> ConvolutionBackwardOverrideableOp::getStatistics() {
std::map<std::string, uint64_t>
ConvolutionBackwardOverrideableOp::getStatistics() {
auto co = cast<mlir::NPCOMP::aten::ConstantOp>(groups().getDefiningOp());
auto ia = co.template getAttrOfType<IntegerAttr>("value");
uint64_t groups = ia.getValue().getZExtValue();
@ -463,7 +464,7 @@ std::map<std::string, uint64_t> MeanOp::getStatistics() {
// getMMOpStatistics(*this);
// }
std::map<std::string, uint64_t> MmOp::getStatistics() {
return getMMOpStatistics(*this );
return getMMOpStatistics(*this);
}
// mul


@ -61,7 +61,8 @@ namespace {
static Value typeCast(PatternRewriter &builder, Value val, Type destTy) {
if (val.getType() == destTy)
return val;
return builder.create<mlir::NPCOMP::aten::TypeCastOp>(val.getLoc(), destTy, val)
return builder
.create<mlir::NPCOMP::aten::TypeCastOp>(val.getLoc(), destTy, val)
.getResult();
}
@ -69,11 +70,11 @@ static Value typeCast(PatternRewriter &builder, Value val, Type destTy) {
/// unknown shape.
static MemRefType getShapeErasedMemRefType(MemRefType type) {
std::vector<int64_t> shape = type.getShape();
for(int i = 0; i < shape.size(); i++) {
for (int i = 0; i < shape.size(); i++) {
shape[i] = -1;
}
return MemRefType::get(shape, type.getElementType(),
type.getAffineMaps(), type.getMemorySpace());
return MemRefType::get(shape, type.getElementType(), type.getAffineMaps(),
type.getMemorySpace());
}
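A concrete example of the erasure (assuming an initialized MLIRContext named ctx and the helper above): a static 4x8 float memref becomes fully dynamic, since -1 is the encoding for a dynamic dimension here.

MemRefType staticTy = MemRefType::get({4, 8}, FloatType::getF32(&ctx));
MemRefType erasedTy = getShapeErasedMemRefType(staticTy);
// erasedTy now prints as memref<?x?xf32>.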
/// Create a type cast to memref
@ -82,14 +83,12 @@ static Value memRefTypeCast(PatternRewriter &builder, Value val) {
if (auto memrefTy = type.dyn_cast<MemRefType>()) {
MemRefType newType = getShapeErasedMemRefType(memrefTy);
return builder.create<MemRefCastOp>(val.getLoc(),
val, newType)
.getResult();
return builder.create<MemRefCastOp>(val.getLoc(), val, newType).getResult();
}
if (auto tensorTy = type.dyn_cast<TensorType>()) {
auto memRefType = mlir::MemRefType::get(tensorTy.getShape(),
tensorTy.getElementType(), {}, 0);
return typeCast(builder, val, memRefType);
auto memRefType = mlir::MemRefType::get(tensorTy.getShape(),
tensorTy.getElementType(), {}, 0);
return typeCast(builder, val, memRefType);
}
return val;
}
@ -186,7 +185,7 @@ static std::string getSimplyMangledFuncName(std::string prefix,
ret = ret + sep + getSimplyMangledType(t);
for (const Type t : operTy) {
std::string s = getSimplyMangledType(t);
if(s.size() > 0)
if (s.size() > 0)
ret = ret + sep + getSimplyMangledType(t);
}
ret += "_out";
@ -194,25 +193,22 @@ static std::string getSimplyMangledFuncName(std::string prefix,
return ret;
}
static std::string getSimplyMangledFuncName(std::string prefix,
FunctionType fnTy) {
FunctionType fnTy) {
return getSimplyMangledFuncName(prefix, fnTy.getInputs(), fnTy.getResults());
}
std::string getMangledFuncName(std::string prefix,
FunctionType fnTy) {
std::string getMangledFuncName(std::string prefix, FunctionType fnTy) {
return getSimplyMangledFuncName(prefix, fnTy);
}
std::string getMangledFuncName(std::string prefix,
ArrayRef<Type> opTys,
std::string getMangledFuncName(std::string prefix, ArrayRef<Type> opTys,
ArrayRef<Type> retTys) {
return getSimplyMangledFuncName(prefix, opTys, retTys);
}
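Concretely, the mangling concatenates the prefix, one mangled-type token per result and per non-empty operand type, and an "_out" suffix. Neither getSimplyMangledType nor sep is shown in this hunk, so the separator and type spellings below are hypothetical:

#include <string>
#include <vector>

// Schematic re-implementation with an assumed "_" separator.
std::string mangle(const std::string &prefix,
                   const std::vector<std::string> &typeTokens) {
  std::string ret = prefix;
  for (const std::string &tok : typeTokens)
    if (!tok.empty())
      ret += "_" + tok;
  return ret + "_out";
}
// e.g. mangle("add", {"2x2xf32", "2x2xf32"}) yields "add_2x2xf32_2x2xf32_out".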
static FuncOp getATenFn(ModuleOp module, std::string mangledFunctionName,
ArrayRef<Value> operands,
ArrayRef<Type> retTys) {
ArrayRef<Value> operands, ArrayRef<Type> retTys) {
Builder builder(module);
SmallVector<Type, 8> tys;
@ -242,8 +238,8 @@ static FuncOp getATenFn(ModuleOp module, std::string mangledFunctionName,
class AddOpConversion_affine : public ConversionPattern {
public:
explicit AddOpConversion_affine(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::AddOp::getOperationName(), 1, context) {
}
: ConversionPattern(mlir::NPCOMP::aten::AddOp::getOperationName(), 1,
context) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -310,78 +306,72 @@ public:
}
};
// Replace the given operation with a call to the given function.
// The function is assumed to accept memrefs and scalar types and return
// Memrefs. Here the result types are converted back to the result types of op,
// but operands are NOT converted. This allows non-standard mappings from
// operand types to function types.
LogicalResult
rewriteWithVoidFunctionCallExplicit(Operation *op,
ArrayRef<Value> callops,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter,
std::string functionName) {
LogicalResult rewriteWithVoidFunctionCallExplicit(
Operation *op, ArrayRef<Value> callops, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter, std::string functionName) {
auto loc = op->getLoc();
edsc::ScopedContext scope(rewriter, loc);
auto loc = op->getLoc();
edsc::ScopedContext scope(rewriter, loc);
// The original operation types.
SmallVector<Type, 8> opTys;
// Shape erased versions of the original operation types.
SmallVector<Type, 8> erasedOpTys;
for (const Value &o: callops) {
Type t = o.getType();
opTys.push_back(t);
if (t.isa<MemRefType>())
erasedOpTys.push_back(getShapeErasedMemRefType(t.cast<MemRefType>()));
else
erasedOpTys.push_back(t);
// The original operation types.
SmallVector<Type, 8> opTys;
// Shape erased versions of the original operation types.
SmallVector<Type, 8> erasedOpTys;
for (const Value &o : callops) {
Type t = o.getType();
opTys.push_back(t);
if (t.isa<MemRefType>())
erasedOpTys.push_back(getShapeErasedMemRefType(t.cast<MemRefType>()));
else
erasedOpTys.push_back(t);
}
std::vector<Value> newOps = callops;
SmallVector<Value, 8> newResults;
// Result types of the original operation, converted to memrefs.
SmallVector<Type, 8> retTys;
// Erased version of the return type. This is the return types of the
// generated function call.
SmallVector<Type, 8> erasedRetTys;
for (const auto &o : op->getResults()) {
Type t = o.getType();
if (t.isa<TensorType>()) {
TensorType tensorResultTy = t.cast<TensorType>();
MemRefType memRefResultTy = mlir::MemRefType::get(
tensorResultTy.getShape(), tensorResultTy.getElementType(), {}, 0);
MemRefType erasedMemRefResultTy =
getShapeErasedMemRefType(memRefResultTy);
retTys.push_back(memRefResultTy);
// assume memRefResultTy has known shape, so we don't need any
// dynamic dimensions for the alloc.
assert(memRefResultTy.hasStaticShape());
Value allocVal = rewriter.create<AllocOp>(op->getLoc(), memRefResultTy);
Value castVal = memRefTypeCast(rewriter, allocVal);
newOps.push_back(castVal);
newResults.push_back(allocVal);
} else {
return failure();
}
}
std::vector<Value> newOps = callops;
SmallVector<Value, 8> newResults;
SmallVector<Type, 8> empty;
std::string mangledFunctionName =
getMangledFuncName(functionName, opTys, retTys);
FuncOp funcOp = getATenFn(op->getParentOfType<ModuleOp>(),
mangledFunctionName, newOps, empty);
// Result types of the original operation, converted to memrefs.
SmallVector<Type, 8> retTys;
// Erased version of the return type. This is the return types of the
// generated function call.
SmallVector<Type, 8> erasedRetTys;
for (const auto &o: op->getResults()) {
Type t = o.getType();
if (t.isa<TensorType>()) {
TensorType tensorResultTy = t.cast<TensorType>();
MemRefType memRefResultTy =
mlir::MemRefType::get(tensorResultTy.getShape(),
tensorResultTy.getElementType(), {}, 0);
MemRefType erasedMemRefResultTy = getShapeErasedMemRefType(memRefResultTy);
retTys.push_back(memRefResultTy);
auto new_call =
callOperation(empty, rewriter.getSymbolRefAttr(funcOp), newOps);
// assume memRefResultTy has known shape, so we don't need any
// dynamic dimensions for the alloc.
assert(memRefResultTy.hasStaticShape());
Value allocVal = rewriter.create<AllocOp>(op->getLoc(),
memRefResultTy);
Value castVal = memRefTypeCast(rewriter, allocVal);
newOps.push_back(castVal);
newResults.push_back(allocVal);
} else {
return failure();
}
}
SmallVector<Type, 8> empty;
std::string mangledFunctionName = getMangledFuncName(functionName, opTys, retTys);
FuncOp funcOp = getATenFn(op->getParentOfType<ModuleOp>(),
mangledFunctionName,
newOps,
empty);
auto new_call = callOperation(empty,
rewriter.getSymbolRefAttr(funcOp), newOps);
rewriter.replaceOp(op, newResults);
return success();
rewriter.replaceOp(op, newResults);
return success();
}
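Schematically (hypothetical op, types, and mangled name, and ignoring the shape-erasing casts shown earlier), the rewrite turns a tensor-producing op such as

  %r = "aten.add"(%a, %b) : (memref<2x2xf32>, memref<2x2xf32>) -> tensor<2x2xf32>

into an allocation plus a void call that writes the result through the appended out-argument, with the op's result replaced by the allocation:

  %r = alloc() : memref<2x2xf32>
  call @add_2x2xf32_2x2xf32_out(%a, %b, %r) : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()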
// Replace the given operation with a call to the given function.
@ -389,54 +379,53 @@ rewriteWithVoidFunctionCallExplicit(Operation *op,
// Memrefs. Other operand types (e.g. aten.list and tensor<> are converted
// appropriately. The called function passes results of the original function
// as memref arguments at the end of the original set of operands.
LogicalResult
rewriteWithFunctionCall(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter,
std::string functionName) {
auto loc = op->getLoc();
edsc::ScopedContext scope(rewriter, loc);
LogicalResult rewriteWithFunctionCall(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter,
std::string functionName) {
auto loc = op->getLoc();
edsc::ScopedContext scope(rewriter, loc);
// Convert the arguments to the original call.
SmallVector<Value, 8> callops;
for (auto &o: operands) {
Type t = o.getType();
if (t.isa<MemRefType>()) {
// Cast it to some memref type that we accept
callops.push_back(memRefTypeCast(rewriter, o));
} else if (t.isa<IntegerType>() || t.isa<FloatType>()) {
callops.push_back(o);
} else if (t.isa<ATenListType>()) {
// FIXME: lots of assumptions here.
auto unpack = [](auto &op, auto &v) -> void {
auto co = cast<mlir::NPCOMP::aten::ConstantOp>(op.getDefiningOp());
DenseElementsAttr a =
co.template getAttrOfType<DenseElementsAttr>("value");
for (auto i : a.getIntValues())
v.push_back(i.getSExtValue());
};
std::vector<uint64_t> values;
unpack(o, values);
callops.push_back(constInt(values[0], 32));
} else {
return failure();
}
// Convert the arguments to the original call.
SmallVector<Value, 8> callops;
for (auto &o : operands) {
Type t = o.getType();
if (t.isa<MemRefType>()) {
// Cast it to some memref type that we accept
callops.push_back(memRefTypeCast(rewriter, o));
} else if (t.isa<IntegerType>() || t.isa<FloatType>()) {
callops.push_back(o);
} else if (t.isa<ATenListType>()) {
// FIXME: lots of assumptions here.
auto unpack = [](auto &op, auto &v) -> void {
auto co = cast<mlir::NPCOMP::aten::ConstantOp>(op.getDefiningOp());
DenseElementsAttr a =
co.template getAttrOfType<DenseElementsAttr>("value");
for (auto i : a.getIntValues())
v.push_back(i.getSExtValue());
};
std::vector<uint64_t> values;
unpack(o, values);
callops.push_back(constInt(values[0], 32));
} else {
return failure();
}
return rewriteWithVoidFunctionCallExplicit(op, callops, operands, rewriter, functionName);
}
return rewriteWithVoidFunctionCallExplicit(op, callops, operands, rewriter,
functionName);
}
/// Lower Add
template<typename Op>
template <typename Op>
class ATenFunctionCallConversion : public ConversionPattern {
public:
explicit ATenFunctionCallConversion(MLIRContext *context)
: ConversionPattern(Op::getOperationName(), 1, context) {
}
: ConversionPattern(Op::getOperationName(), 1, context) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
return rewriteWithFunctionCall(op, operands, rewriter, Op::getFunctionConversionName());
return rewriteWithFunctionCall(op, operands, rewriter,
Op::getFunctionConversionName());
}
};
@ -444,8 +433,8 @@ public:
class ConstantOpConversion : public ConversionPattern {
public:
explicit ConstantOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::ConstantOp::getOperationName(), 1, context) {
}
: ConversionPattern(mlir::NPCOMP::aten::ConstantOp::getOperationName(), 1,
context) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -459,14 +448,15 @@ public:
Type t = result.getType();
if (t.isa<IntegerType>()) {
auto it = t.cast<IntegerType>();
if(it.getWidth() > 1) {
if (it.getWidth() > 1) {
auto a = op->getAttrOfType<IntegerAttr>("value");
SmallVector<Value, 8> newValues {rewriter.create<mlir::ConstantOp>(loc, a)};
SmallVector<Value, 8> newValues{
rewriter.create<mlir::ConstantOp>(loc, a)};
rewriter.replaceOp(op, newValues);
return success();
} else {
auto a = op->getAttrOfType<BoolAttr>("value");
SmallVector<Value, 8> newValues {constInt(a.getValue(), it.getWidth())};
SmallVector<Value, 8> newValues{constInt(a.getValue(), it.getWidth())};
rewriter.replaceOp(op, newValues);
return success();
}
@ -485,8 +475,8 @@ public:
class AddOpConversion : public ConversionPattern {
public:
explicit AddOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::AddOp::getOperationName(), 1, context) {
}
: ConversionPattern(mlir::NPCOMP::aten::AddOp::getOperationName(), 1,
context) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -513,8 +503,8 @@ public:
class AsStridedOpConversion : public ConversionPattern {
public:
explicit AsStridedOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::AsStridedOp::getOperationName(), 1,
context) {}
: ConversionPattern(mlir::NPCOMP::aten::AsStridedOp::getOperationName(),
1, context) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -527,7 +517,8 @@ public:
// construct the shape argument
std::vector<Value> shape;
std::vector<int64_t> result_shape;
auto co0 = cast<mlir::NPCOMP::aten::ConstantOp>(operands[1].getDefiningOp());
auto co0 =
cast<mlir::NPCOMP::aten::ConstantOp>(operands[1].getDefiningOp());
DenseElementsAttr a0 =
co0.template getAttrOfType<DenseElementsAttr>("value");
for (auto i : a0.getAttributeValues())
@ -539,7 +530,8 @@ public:
// construct the stride argument
std::vector<Value> stride;
auto co1 = cast<mlir::NPCOMP::aten::ConstantOp>(operands[2].getDefiningOp());
auto co1 =
cast<mlir::NPCOMP::aten::ConstantOp>(operands[2].getDefiningOp());
DenseElementsAttr a1 =
co1.template getAttrOfType<DenseElementsAttr>("value");
for (auto i : a1.getAttributeValues())
@ -551,19 +543,21 @@ public:
APInt offset(32, 0);
if (operands.size() > 3) {
auto co2 = cast<mlir::NPCOMP::aten::ConstantOp>(operands[3].getDefiningOp());
auto co2 =
cast<mlir::NPCOMP::aten::ConstantOp>(operands[3].getDefiningOp());
auto ia2 = co2.getAttrOfType<IntegerAttr>("value");
offset = ia2.getValue();
}
SmallVector<Value, 8> callops{xVal, shape[0],
shape[1], shape[2],
shape[3], stride[0],
stride[1], stride[2],
stride[3], constInt(offset.getSExtValue(), 32)};
SmallVector<Value, 8> callops{
xVal, shape[0],
shape[1], shape[2],
shape[3], stride[0],
stride[1], stride[2],
stride[3], constInt(offset.getSExtValue(), 32)};
return rewriteWithVoidFunctionCallExplicit(op, callops, operands, rewriter, "as_strided");
return rewriteWithVoidFunctionCallExplicit(op, callops, operands, rewriter,
"as_strided");
}
};
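The net effect of this pattern: as_strided's shape and stride operands, each an aten.constant holding a DenseElementsAttr, are flattened into ten scalar call operands: the input memref, four shape values, four stride values, and the storage offset. A hypothetical 4-D view therefore lowers to a call shaped like as_strided(%x, s0, s1, s2, s3, st0, st1, st2, st3, offset), passing the view parameters by value rather than as list objects.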
@ -571,8 +565,8 @@ public:
class BatchNormOpConversion : public ConversionPattern {
public:
explicit BatchNormOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::BatchNormOp::getOperationName(), 1,
context) {}
: ConversionPattern(mlir::NPCOMP::aten::BatchNormOp::getOperationName(),
1, context) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -585,8 +579,8 @@ public:
class ConvolutionOpConversion : public ConversionPattern {
public:
explicit ConvolutionOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::ConvolutionOp::getOperationName(), 1,
context) {}
: ConversionPattern(mlir::NPCOMP::aten::ConvolutionOp::getOperationName(),
1, context) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -614,8 +608,8 @@ public:
class DivOpConversion : public ConversionPattern {
public:
explicit DivOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::DivOp::getOperationName(), 1, context) {
}
: ConversionPattern(mlir::NPCOMP::aten::DivOp::getOperationName(), 1,
context) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -627,8 +621,8 @@ public:
class LogSoftmaxOpConversion : public ConversionPattern {
public:
explicit LogSoftmaxOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::LogSoftmaxOp::getOperationName(), 1,
context) {}
: ConversionPattern(mlir::NPCOMP::aten::LogSoftmaxOp::getOperationName(),
1, context) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -657,8 +651,8 @@ public:
class MaxPoolOpConversion : public ConversionPattern {
public:
explicit MaxPoolOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::MaxPool2dOp::getOperationName(), 1,
context) {}
: ConversionPattern(mlir::NPCOMP::aten::MaxPool2dOp::getOperationName(),
1, context) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -678,7 +672,8 @@ public:
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
return rewriteWithFunctionCall(op, operands, rewriter, "max_pool2d_with_indices");
return rewriteWithFunctionCall(op, operands, rewriter,
"max_pool2d_with_indices");
}
};
@ -686,14 +681,15 @@ public:
class MaxPool2dWithIndicesBackwardOpConversion : public ConversionPattern {
public:
explicit MaxPool2dWithIndicesBackwardOpConversion(MLIRContext *context)
: ConversionPattern(
mlir::NPCOMP::aten::MaxPool2dWithIndicesBackwardOp::getOperationName(), 1,
context) {}
: ConversionPattern(mlir::NPCOMP::aten::MaxPool2dWithIndicesBackwardOp::
getOperationName(),
1, context) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
return rewriteWithFunctionCall(op, operands, rewriter, "max_pool2d_with_indices_backward");
return rewriteWithFunctionCall(op, operands, rewriter,
"max_pool2d_with_indices_backward");
}
};
@ -701,7 +697,8 @@ public:
class MMOpConversion : public ConversionPattern {
public:
explicit MMOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::MmOp::getOperationName(), 1, context) {}
: ConversionPattern(mlir::NPCOMP::aten::MmOp::getOperationName(), 1,
context) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -714,8 +711,8 @@ public:
class MulOpConversion : public ConversionPattern {
public:
explicit MulOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::MulOp::getOperationName(), 1, context) {
}
: ConversionPattern(mlir::NPCOMP::aten::MulOp::getOperationName(), 1,
context) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -728,8 +725,9 @@ public:
class NativeBatchNormOpConversion : public ConversionPattern {
public:
explicit NativeBatchNormOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::NativeBatchNormOp::getOperationName(),
1, context) {}
: ConversionPattern(
mlir::NPCOMP::aten::NativeBatchNormOp::getOperationName(), 1,
context) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -742,13 +740,15 @@ public:
class NllLoss2dBackwardOpConversion : public ConversionPattern {
public:
explicit NllLoss2dBackwardOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::NllLoss2dBackwardOp::getOperationName(),
1, context) {}
: ConversionPattern(
mlir::NPCOMP::aten::NllLoss2dBackwardOp::getOperationName(), 1,
context) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
return rewriteWithFunctionCall(op, operands, rewriter, "nll_loss2d_backward");
return rewriteWithFunctionCall(op, operands, rewriter,
"nll_loss2d_backward");
}
};
@ -756,13 +756,15 @@ public:
class NllLoss2dForwardOpConversion : public ConversionPattern {
public:
explicit NllLoss2dForwardOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::NllLoss2dForwardOp::getOperationName(),
1, context) {}
: ConversionPattern(
mlir::NPCOMP::aten::NllLoss2dForwardOp::getOperationName(), 1,
context) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
return rewriteWithFunctionCall(op, operands, rewriter, "nll_loss2d_forward");
return rewriteWithFunctionCall(op, operands, rewriter,
"nll_loss2d_forward");
}
};
@ -770,8 +772,9 @@ public:
class NllLossBackwardOpConversion : public ConversionPattern {
public:
explicit NllLossBackwardOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::NllLossBackwardOp::getOperationName(),
1, context) {}
: ConversionPattern(
mlir::NPCOMP::aten::NllLossBackwardOp::getOperationName(), 1,
context) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -784,13 +787,15 @@ public:
class NllLossForwardOpConversion : public ConversionPattern {
public:
explicit NllLossForwardOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::NllLossForwardOp::getOperationName(), 1,
context) {}
: ConversionPattern(
mlir::NPCOMP::aten::NllLossForwardOp::getOperationName(), 1,
context) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
return rewriteWithFunctionCall(op, operands, rewriter, "nll_loss_forward"); }
return rewriteWithFunctionCall(op, operands, rewriter, "nll_loss_forward");
}
};
/// Lower ReLU
@ -811,13 +816,15 @@ public:
class ThresholdBackwardOpConversion : public ConversionPattern {
public:
explicit ThresholdBackwardOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::ThresholdBackwardOp::getOperationName(),
1, context) {}
: ConversionPattern(
mlir::NPCOMP::aten::ThresholdBackwardOp::getOperationName(), 1,
context) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
return rewriteWithFunctionCall(op, operands, rewriter, "threshold_backward");
return rewriteWithFunctionCall(op, operands, rewriter,
"threshold_backward");
}
};
@ -825,7 +832,8 @@ public:
class TransposeOpConversion : public ConversionPattern {
public:
explicit TransposeOpConversion(MLIRContext *context)
: ConversionPattern(mlir::NPCOMP::aten::TOp::getOperationName(), 1, context) {}
: ConversionPattern(mlir::NPCOMP::aten::TOp::getOperationName(), 1,
context) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@ -849,21 +857,22 @@ public:
Value xVal = memRefTypeCast(rewriter, operands[0]);
// construct the shape argument
// construct the shape argument
SmallVector<Value, 8> shape;
auto co = dyn_cast<mlir::NPCOMP::aten::ConstantOp>(operands[1].getDefiningOp());
auto co =
dyn_cast<mlir::NPCOMP::aten::ConstantOp>(operands[1].getDefiningOp());
DenseElementsAttr a = co.template getAttrOfType<DenseElementsAttr>("value");
for (auto i : a.getAttributeValues())
shape.push_back(rewriter.create<mlir::ConstantOp>(co.getLoc(), i));
// pad out the shape with -1 to make it 4d
while (shape.size() < 4)
shape.push_back(constInt(-1, 32));
SmallVector<Value, 8> callops{xVal, shape[0], shape[1], shape[2], shape[3]};
return rewriteWithVoidFunctionCallExplicit(op, callops, operands, rewriter, "view");
return rewriteWithVoidFunctionCallExplicit(op, callops, operands, rewriter,
"view");
}
};
@ -896,9 +905,8 @@ struct ATenLoweringPass
// c++ patterns
acapPatterns.insert<
ConstantOpConversion,
AddOpConversion, ConvolutionOpConversion, ReLUOpConversion,
TransposeOpConversion, BatchNormOpConversion,
ConstantOpConversion, AddOpConversion, ConvolutionOpConversion,
ReLUOpConversion, TransposeOpConversion, BatchNormOpConversion,
NativeBatchNormOpConversion, MaxPoolOpConversion,
MaxPool2dWithIndicesOpConversion, AddmmOpConversion, ViewOpConversion,
MulOpConversion, MMOpConversion, AsStridedOpConversion,


@ -7,8 +7,8 @@
//===----------------------------------------------------------------------===//
#include "npcomp/Dialect/ATen/ATenToStd.h"
#include "npcomp/Dialect/ATen/ATenDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "npcomp/Dialect/ATen/ATenDialect.h"
using namespace mlir;
using namespace mlir::NPCOMP;


@ -115,8 +115,8 @@ std::string LivenessReport::emitJSONReport() {
for (auto v : vlist) {
int64_t vol = getTensorVolume(v.getType());
if (v.getDefiningOp()) {
if (auto a = v.getDefiningOp()->getAttrOfType<StringAttr>(
"layer_name")) {
if (auto a =
v.getDefiningOp()->getAttrOfType<StringAttr>("layer_name")) {
auto definingOp = v.getDefiningOp();
auto ld = layerDetail.getInteger(a.getValue().str());
if (ld)


@ -32,28 +32,29 @@ public:
auto op = opBuilder.create<scf::YieldOp>(loc, yields);
return op.getOperation();
})
.def("scf_if_op",
[](ScfDialectHelper &self, std::vector<PyType> pyResultTypes,
PyValue cond, bool withElseRegion) {
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
Location loc = self.pyOpBuilder.getCurrentLoc();
llvm::SmallVector<Type, 4> resultTypes(pyResultTypes.begin(),
pyResultTypes.end());
auto op = opBuilder.create<scf::IfOp>(loc, resultTypes, cond,
withElseRegion);
if (withElseRegion) {
return py::make_tuple(
PyOperationRef(op),
op.getThenBodyBuilder().saveInsertionPoint(),
op.getElseBodyBuilder().saveInsertionPoint());
} else {
return py::make_tuple(
PyOperationRef(op),
op.getThenBodyBuilder().saveInsertionPoint());
}
},
py::arg("result_types"), py::arg("cond"),
py::arg("with_else_region") = false);
.def(
"scf_if_op",
[](ScfDialectHelper &self, std::vector<PyType> pyResultTypes,
PyValue cond, bool withElseRegion) {
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
Location loc = self.pyOpBuilder.getCurrentLoc();
llvm::SmallVector<Type, 4> resultTypes(pyResultTypes.begin(),
pyResultTypes.end());
auto op = opBuilder.create<scf::IfOp>(loc, resultTypes, cond,
withElseRegion);
if (withElseRegion) {
return py::make_tuple(
PyOperationRef(op),
op.getThenBodyBuilder().saveInsertionPoint(),
op.getElseBodyBuilder().saveInsertionPoint());
} else {
return py::make_tuple(
PyOperationRef(op),
op.getThenBodyBuilder().saveInsertionPoint());
}
},
py::arg("result_types"), py::arg("cond"),
py::arg("with_else_region") = false);
}
};


@ -65,13 +65,14 @@ void PyIpListWrapper<ListTy, ItemWrapperTy>::bind(py::module m,
"front",
[](ThisTy &self) { return ItemWrapperTy(self.list.front()); })
.def("__len__", [](ThisTy &self) { return self.list.size(); })
.def("__iter__",
[](ThisTy &self) {
PyItemIterator begin(self.list.begin());
PyItemIterator end(self.list.end());
return py::make_iterator(begin, end);
},
py::keep_alive<0, 1>());
.def(
"__iter__",
[](ThisTy &self) {
PyItemIterator begin(self.list.begin());
PyItemIterator end(self.list.end());
return py::make_iterator(begin, end);
},
py::keep_alive<0, 1>());
}
//===----------------------------------------------------------------------===//
@ -194,79 +195,81 @@ void PyDialectHelper::bind(py::module m) {
[](PyDialectHelper &self) -> std::shared_ptr<PyContext> {
return self.context.shared_from_this();
})
.def("op",
[](PyDialectHelper &self, const std::string &opNameStr,
std::vector<PyType> pyResultTypes,
std::vector<PyValue> pyOperands,
llvm::Optional<PyAttribute> attrs) -> PyOperationRef {
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(false);
Location loc = self.pyOpBuilder.getCurrentLoc();
OperationName opName(opNameStr, opBuilder.getContext());
SmallVector<Type, 4> types(pyResultTypes.begin(),
pyResultTypes.end());
SmallVector<Value, 4> operands(pyOperands.begin(),
pyOperands.end());
MutableDictionaryAttr attrList;
if (attrs) {
auto dictAttrs = attrs->attr.dyn_cast<DictionaryAttr>();
if (!dictAttrs) {
throw py::raiseValueError(
"Expected `attrs` to be a DictionaryAttr");
}
attrList = MutableDictionaryAttr(dictAttrs);
}
Operation *op =
Operation::create(loc, opName, types, operands, attrList);
opBuilder.insert(op);
return op;
},
py::arg("op_name"), py::arg("result_types"), py::arg("operands"),
py::arg("attrs") = llvm::Optional<PyAttribute>())
.def("func_op",
[](PyDialectHelper &self, const std::string &name, PyType type,
bool createEntryBlock, llvm::Optional<PyAttribute> attrs) {
auto functionType = type.type.dyn_cast_or_null<FunctionType>();
if (!functionType) {
throw py::raiseValueError("Illegal function type");
}
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
Location loc = self.pyOpBuilder.getCurrentLoc();
// TODO: Dedup attr creation from op().
MutableDictionaryAttr attrList;
if (attrs) {
auto dictAttrs = attrs->attr.dyn_cast<DictionaryAttr>();
if (!dictAttrs) {
throw py::raiseValueError(
"Expected `attrs` to be a DictionaryAttr");
}
attrList = MutableDictionaryAttr(dictAttrs);
}
FuncOp op =
opBuilder.create<FuncOp>(loc, StringRef(name), functionType,
/*attrs=*/attrList.getAttrs());
if (createEntryBlock) {
Block *entryBlock = new Block();
entryBlock->addArguments(functionType.getInputs());
op.getBody().push_back(entryBlock);
opBuilder.setInsertionPointToStart(entryBlock);
}
return PyOperationRef(op);
},
py::arg("name"), py::arg("type"),
py::arg("create_entry_block") = false,
py::arg("attrs") = llvm::Optional<PyAttribute>(),
R"(Creates a new `func` op, optionally creating an entry block.
.def(
"op",
[](PyDialectHelper &self, const std::string &opNameStr,
std::vector<PyType> pyResultTypes, std::vector<PyValue> pyOperands,
llvm::Optional<PyAttribute> attrs) -> PyOperationRef {
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(false);
Location loc = self.pyOpBuilder.getCurrentLoc();
OperationName opName(opNameStr, opBuilder.getContext());
SmallVector<Type, 4> types(pyResultTypes.begin(),
pyResultTypes.end());
SmallVector<Value, 4> operands(pyOperands.begin(),
pyOperands.end());
MutableDictionaryAttr attrList;
if (attrs) {
auto dictAttrs = attrs->attr.dyn_cast<DictionaryAttr>();
if (!dictAttrs) {
throw py::raiseValueError(
"Expected `attrs` to be a DictionaryAttr");
}
attrList = MutableDictionaryAttr(dictAttrs);
}
Operation *op =
Operation::create(loc, opName, types, operands, attrList);
opBuilder.insert(op);
return op;
},
py::arg("op_name"), py::arg("result_types"), py::arg("operands"),
py::arg("attrs") = llvm::Optional<PyAttribute>())
.def(
"func_op",
[](PyDialectHelper &self, const std::string &name, PyType type,
bool createEntryBlock, llvm::Optional<PyAttribute> attrs) {
auto functionType = type.type.dyn_cast_or_null<FunctionType>();
if (!functionType) {
throw py::raiseValueError("Illegal function type");
}
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
Location loc = self.pyOpBuilder.getCurrentLoc();
// TODO: Dedup attr creation from op().
MutableDictionaryAttr attrList;
if (attrs) {
auto dictAttrs = attrs->attr.dyn_cast<DictionaryAttr>();
if (!dictAttrs) {
throw py::raiseValueError(
"Expected `attrs` to be a DictionaryAttr");
}
attrList = MutableDictionaryAttr(dictAttrs);
}
FuncOp op =
opBuilder.create<FuncOp>(loc, StringRef(name), functionType,
/*attrs=*/attrList.getAttrs());
if (createEntryBlock) {
Block *entryBlock = new Block();
entryBlock->addArguments(functionType.getInputs());
op.getBody().push_back(entryBlock);
opBuilder.setInsertionPointToStart(entryBlock);
}
return PyOperationRef(op);
},
py::arg("name"), py::arg("type"),
py::arg("create_entry_block") = false,
py::arg("attrs") = llvm::Optional<PyAttribute>(),
R"(Creates a new `func` op, optionally creating an entry block.
If an entry block is created, the builder will be positioned
to its start.)")
.def("select_op",
[](PyDialectHelper &self, PyValue conditionValue, PyValue trueValue,
PyValue falseValue) -> PyOperationRef {
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
Location loc = self.pyOpBuilder.getCurrentLoc();
return PyOperationRef(opBuilder.create<SelectOp>(
loc, conditionValue, trueValue, falseValue));
},
py::arg("condition"), py::arg("true_value"), py::arg("false_value"))
.def(
"select_op",
[](PyDialectHelper &self, PyValue conditionValue, PyValue trueValue,
PyValue falseValue) -> PyOperationRef {
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
Location loc = self.pyOpBuilder.getCurrentLoc();
return PyOperationRef(opBuilder.create<SelectOp>(
loc, conditionValue, trueValue, falseValue));
},
py::arg("condition"), py::arg("true_value"), py::arg("false_value"))
.def("return_op",
[](PyDialectHelper &self, std::vector<PyValue> pyOperands) {
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
@ -288,11 +291,12 @@ void PyDialectHelper::bind(py::module m) {
[](PyDialectHelper &self) -> PyType {
return IndexType::get(self.getContext());
})
.def("integer_type",
[](PyDialectHelper &self, unsigned width) -> PyType {
return IntegerType::get(width, self.getContext());
},
py::arg("width") = 32)
.def(
"integer_type",
[](PyDialectHelper &self, unsigned width) -> PyType {
return IntegerType::get(width, self.getContext());
},
py::arg("width") = 32)
.def_property_readonly("i1_type",
[](PyDialectHelper &self) -> PyType {
return IntegerType::get(1, self.getContext());
@ -319,20 +323,21 @@ void PyDialectHelper::bind(py::module m) {
return FloatType::get(StandardTypes::F64,
self.getContext());
})
.def("tensor_type",
[](PyDialectHelper &self, PyType elementType,
llvm::Optional<std::vector<int64_t>> shape) -> PyType {
if (!elementType.type) {
throw py::raiseValueError("Null element type");
}
if (shape) {
return RankedTensorType::get(*shape, elementType.type);
} else {
return UnrankedTensorType::get(elementType.type);
}
},
py::arg("element_type"),
py::arg("shape") = llvm::Optional<std::vector<int64_t>>())
.def(
"tensor_type",
[](PyDialectHelper &self, PyType elementType,
llvm::Optional<std::vector<int64_t>> shape) -> PyType {
if (!elementType.type) {
throw py::raiseValueError("Null element type");
}
if (shape) {
return RankedTensorType::get(*shape, elementType.type);
} else {
return UnrankedTensorType::get(elementType.type);
}
},
py::arg("element_type"),
py::arg("shape") = llvm::Optional<std::vector<int64_t>>())
.def("function_type",
[](PyDialectHelper &self, std::vector<PyType> inputs,
std::vector<PyType> results) -> PyType {
@ -367,21 +372,24 @@ void defineMlirIrModule(py::module m) {
m.doc() = "Python bindings for constructs in the mlir/IR library";
// Globals.
m.def("emit_error",
[](PyAttribute loc, std::string message) {
emitDiagnostic(DiagnosticSeverity::Error, loc, message);
},
py::arg("loc"), py::arg("message"));
m.def("emit_warning",
[](PyAttribute loc, std::string message) {
emitDiagnostic(DiagnosticSeverity::Warning, loc, message);
},
py::arg("loc"), py::arg("message"));
m.def("emit_remark",
[](PyAttribute loc, std::string message) {
emitDiagnostic(DiagnosticSeverity::Remark, loc, message);
},
py::arg("loc"), py::arg("message"));
m.def(
"emit_error",
[](PyAttribute loc, std::string message) {
emitDiagnostic(DiagnosticSeverity::Error, loc, message);
},
py::arg("loc"), py::arg("message"));
m.def(
"emit_warning",
[](PyAttribute loc, std::string message) {
emitDiagnostic(DiagnosticSeverity::Warning, loc, message);
},
py::arg("loc"), py::arg("message"));
m.def(
"emit_remark",
[](PyAttribute loc, std::string message) {
emitDiagnostic(DiagnosticSeverity::Remark, loc, message);
},
py::arg("loc"), py::arg("message"));
// Python only types.
PyDialectHelper::bind(m);
@ -426,25 +434,27 @@ void PyContext::bind(py::module m) {
return PyModuleOp(self.shared_from_this(), m);
})
.def("parse_asm", &PyContext::parseAsm)
.def("new_builder",
[](PyContext &self) {
// Note: we collapse the Builder and OpBuilder into one because
// there is little reason to expose the inheritance hierarchy to
// Python.
return PyOpBuilder(self);
},
py::keep_alive<0, 1>())
.def(
"new_builder",
[](PyContext &self) {
// Note: we collapse the Builder and OpBuilder into one because
// there is little reason to expose the inheritance hierarchy to
// Python.
return PyOpBuilder(self);
},
py::keep_alive<0, 1>())
.def("identifier",
[](PyContext &self, std::string s) -> PyIdentifier {
return Identifier::get(s, &self.context);
})
.def("file_line_col_loc_attr",
[](PyContext &self, PyIdentifier filename, unsigned line,
unsigned column) -> PyAttribute {
return static_cast<LocationAttr>(FileLineColLoc::get(
filename.identifier, line, column, &self.context));
},
py::arg("filename"), py::arg("line"), py::arg("column"))
.def(
"file_line_col_loc_attr",
[](PyContext &self, PyIdentifier filename, unsigned line,
unsigned column) -> PyAttribute {
return static_cast<LocationAttr>(FileLineColLoc::get(
filename.identifier, line, column, &self.context));
},
py::arg("filename"), py::arg("line"), py::arg("column"))
// Salient functions from Builder.
.def("parse_type",
[](PyContext &self, const std::string &asmText) {
@ -456,14 +466,15 @@ void PyContext::bind(py::module m) {
}
return PyType(t);
})
.def("integer_attr",
[](PyContext &self, PyType type, int64_t value) -> PyAttribute {
if (!type.type.isa<IntegerType>()) {
throw py::raiseValueError("Expected IntegerType");
}
return IntegerAttr::get(type.type, value);
},
py::arg("type"), py::arg("value"))
.def(
"integer_attr",
[](PyContext &self, PyType type, int64_t value) -> PyAttribute {
if (!type.type.isa<IntegerType>()) {
throw py::raiseValueError("Expected IntegerType");
}
return IntegerAttr::get(type.type, value);
},
py::arg("type"), py::arg("value"))
.def("float_attr",
[](PyContext &self, PyType type, double value) -> PyAttribute {
if (!type.type.isa<FloatType>()) {
@ -512,20 +523,20 @@ void PyContext::bind(py::module m) {
}
return ArrayAttr::get(attrs, &self.context);
})
.def("dense_elements_attr",
[](PyContext &self, py::buffer array) -> PyAttribute {
// Request a contiguous view.
int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
Py_buffer *view = new Py_buffer();
if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
delete view;
throw py::error_already_set();
}
py::buffer_info array_info(view);
return createDenseElementsAttrFromBuffer(&self.context,
array_info);
},
py::arg("array"))
.def(
"dense_elements_attr",
[](PyContext &self, py::buffer array) -> PyAttribute {
// Request a contiguous view.
int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
Py_buffer *view = new Py_buffer();
if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
delete view;
throw py::error_already_set();
}
py::buffer_info array_info(view);
return createDenseElementsAttrFromBuffer(&self.context, array_info);
},
py::arg("array"))
.def_property_readonly("unit_attr", [](PyContext &self) -> PyAttribute {
return UnitAttr::get(&self.context);
});
@ -905,67 +916,74 @@ void PyBaseOpBuilder::bind(py::module m) {
void PyOpBuilder::bind(py::module m) {
py::class_<PyOpBuilder, PyBaseOpBuilder>(m, "OpBuilder")
.def(py::init<PyContext &>(), py::keep_alive<1, 2>())
.def_property("current_loc",
[](PyOpBuilder &self) -> PyAttribute {
return static_cast<Attribute>(self.getCurrentLoc());
},
[](PyOpBuilder &self, PyAttribute attr) {
auto loc_attr =
attr.attr.dyn_cast_or_null<LocationAttr>();
if (!loc_attr) {
throw py::raiseValueError("Expected a LocationAttr");
}
self.setCurrentLoc(Location(loc_attr));
})
.def_property("insertion_point",
[](PyOpBuilder &self) {
return self.getBuilder(true).saveInsertionPoint();
},
[](PyOpBuilder &self, OpBuilder::InsertPoint ip) {
self.getBuilder(false).restoreInsertionPoint(ip);
})
.def("set_file_line_col",
[](PyOpBuilder &self, PyIdentifier filename, unsigned line,
unsigned column) {
Location loc = FileLineColLoc::get(filename.identifier, line,
column, self.getContext());
self.setCurrentLoc(loc);
},
py::arg("filename"), py::arg("line"), py::arg("column"),
"Shortcut to set a FileLineCol current location")
.def_property(
"current_loc",
[](PyOpBuilder &self) -> PyAttribute {
return static_cast<Attribute>(self.getCurrentLoc());
},
[](PyOpBuilder &self, PyAttribute attr) {
auto loc_attr = attr.attr.dyn_cast_or_null<LocationAttr>();
if (!loc_attr) {
throw py::raiseValueError("Expected a LocationAttr");
}
self.setCurrentLoc(Location(loc_attr));
})
.def_property(
"insertion_point",
[](PyOpBuilder &self) {
return self.getBuilder(true).saveInsertionPoint();
},
[](PyOpBuilder &self, OpBuilder::InsertPoint ip) {
self.getBuilder(false).restoreInsertionPoint(ip);
})
.def(
"set_file_line_col",
[](PyOpBuilder &self, PyIdentifier filename, unsigned line,
unsigned column) {
Location loc = FileLineColLoc::get(filename.identifier, line,
column, self.getContext());
self.setCurrentLoc(loc);
},
py::arg("filename"), py::arg("line"), py::arg("column"),
"Shortcut to set a FileLineCol current location")
.def("clear_insertion_point",
[](PyOpBuilder &self) { self.builder.clearInsertionPoint(); })
.def("insert_op_before",
[](PyOpBuilder &self, PyBaseOperation &pyOp) {
Operation *op = pyOp.getOperation();
self.builder.setInsertionPoint(op);
},
"Sets the insertion point to just before the specified op.")
.def("insert_op_after",
[](PyOpBuilder &self, PyBaseOperation &pyOp) {
Operation *op = pyOp.getOperation();
self.builder.setInsertionPointAfter(op);
},
"Sets the insertion point to just after the specified op.")
.def("insert_block_start",
[](PyOpBuilder &self, PyBlockRef block) {
self.builder.setInsertionPointToStart(&block.block);
},
"Sets the insertion point to the start of the block.")
.def("insert_block_end",
[](PyOpBuilder &self, PyBlockRef block) {
self.builder.setInsertionPointToEnd(&block.block);
},
"Sets the insertion point to the end of the block.")
.def("insert_before_terminator",
[](PyOpBuilder &self, PyBlockRef block) {
auto *terminator = block.block.getTerminator();
if (!terminator) {
throw py::raiseValueError("Block has no terminator");
}
self.builder.setInsertionPoint(terminator);
},
"Sets the insertion point to just before the block terminator.");
.def(
"insert_op_before",
[](PyOpBuilder &self, PyBaseOperation &pyOp) {
Operation *op = pyOp.getOperation();
self.builder.setInsertionPoint(op);
},
"Sets the insertion point to just before the specified op.")
.def(
"insert_op_after",
[](PyOpBuilder &self, PyBaseOperation &pyOp) {
Operation *op = pyOp.getOperation();
self.builder.setInsertionPointAfter(op);
},
"Sets the insertion point to just after the specified op.")
.def(
"insert_block_start",
[](PyOpBuilder &self, PyBlockRef block) {
self.builder.setInsertionPointToStart(&block.block);
},
"Sets the insertion point to the start of the block.")
.def(
"insert_block_end",
[](PyOpBuilder &self, PyBlockRef block) {
self.builder.setInsertionPointToEnd(&block.block);
},
"Sets the insertion point to the end of the block.")
.def(
"insert_before_terminator",
[](PyOpBuilder &self, PyBlockRef block) {
auto *terminator = block.block.getTerminator();
if (!terminator) {
throw py::raiseValueError("Block has no terminator");
}
self.builder.setInsertionPoint(terminator);
},
"Sets the insertion point to just before the block terminator.");
}
} // namespace mlir


@ -32,13 +32,14 @@ void PyPassManager::bind(py::module m) {
py::class_<PyPassManager>(m, "PassManager")
.def(py::init<std::shared_ptr<PyContext>, bool>(), py::arg("context"),
py::arg("verifyModules") = true)
.def("enableCrashReproducerGeneration",
[](PyPassManager &self, std::string outputFile,
bool genLocalReproducer) {
self.passManager.enableCrashReproducerGeneration(
outputFile, genLocalReproducer);
},
py::arg("outputFile"), py::arg("genLocalReproducer") = false)
.def(
"enableCrashReproducerGeneration",
[](PyPassManager &self, std::string outputFile,
bool genLocalReproducer) {
self.passManager.enableCrashReproducerGeneration(
outputFile, genLocalReproducer);
},
py::arg("outputFile"), py::arg("genLocalReproducer") = false)
.def("__len__",
[](PyPassManager &self) { return self.passManager.size(); })
.def("__str__",


@ -48,18 +48,19 @@ public:
return Basicpy::NoneType::get(
self.getContext());
})
.def("basicpy_SlotObject_type",
[](BasicpyDialectHelper &self, std::string className,
py::args pySlotTypes) -> PyType {
SmallVector<Type, 4> slotTypes;
for (auto pySlotType : pySlotTypes) {
slotTypes.push_back(pySlotType.cast<PyType>());
}
auto classNameAttr =
StringAttr::get(className, self.getContext());
return Basicpy::SlotObjectType::get(classNameAttr, slotTypes);
},
py::arg("className"))
.def(
"basicpy_SlotObject_type",
[](BasicpyDialectHelper &self, std::string className,
py::args pySlotTypes) -> PyType {
SmallVector<Type, 4> slotTypes;
for (auto pySlotType : pySlotTypes) {
slotTypes.push_back(pySlotType.cast<PyType>());
}
auto classNameAttr =
StringAttr::get(className, self.getContext());
return Basicpy::SlotObjectType::get(classNameAttr, slotTypes);
},
py::arg("className"))
.def_property_readonly("basicpy_StrType",
[](BasicpyDialectHelper &self) -> PyType {
return Basicpy::StrType::get(