mirror of https://github.com/llvm/torch-mlir
Format sources.
parent
de38caa547
commit
fc4f374345
|
@ -14,7 +14,7 @@
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace NPCOMP {
|
namespace NPCOMP {
|
||||||
#include "npcomp/Dialect/ATen/ATenOpInterfaces.h.inc"
|
#include "npcomp/Dialect/ATen/ATenOpInterfaces.h.inc"
|
||||||
} // namespace aten
|
|
||||||
} // namespace NPCOMP
|
} // namespace NPCOMP
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -11,9 +11,9 @@
|
||||||
|
|
||||||
#include "npcomp/Dialect/ATen/ATenDialect.h"
|
#include "npcomp/Dialect/ATen/ATenDialect.h"
|
||||||
|
|
||||||
#include "llvm/Support/Debug.h"
|
|
||||||
#include "mlir/IR/StandardTypes.h"
|
#include "mlir/IR/StandardTypes.h"
|
||||||
#include "mlir/IR/Types.h"
|
#include "mlir/IR/Types.h"
|
||||||
|
#include "llvm/Support/Debug.h"
|
||||||
|
|
||||||
#define DEBUG_TYPE "aten-op-stats"
|
#define DEBUG_TYPE "aten-op-stats"
|
||||||
|
|
||||||
|
@ -25,7 +25,8 @@ namespace NPCOMP {
|
||||||
namespace aten {
|
namespace aten {
|
||||||
|
|
||||||
// Return the op statistics for conv2d-like operations.
|
// Return the op statistics for conv2d-like operations.
|
||||||
template <class T> std::map<std::string, uint64_t> getConv2dStatistics(T *o, uint64_t groups) {
|
template <class T>
|
||||||
|
std::map<std::string, uint64_t> getConv2dStatistics(T *o, uint64_t groups) {
|
||||||
|
|
||||||
std::map<std::string, uint64_t> toReturn;
|
std::map<std::string, uint64_t> toReturn;
|
||||||
|
|
||||||
|
@ -66,11 +67,13 @@ template <class T> std::map<std::string, uint64_t> getConv2dStatistics(T *o, uin
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return the op statistics for conv2dBackward-like operations.
|
// Return the op statistics for conv2dBackward-like operations.
|
||||||
template<typename T>
|
template <typename T>
|
||||||
std::map<std::string, uint64_t> getConv2dBackwardStatistics(T op, uint64_t groups) {
|
std::map<std::string, uint64_t> getConv2dBackwardStatistics(T op,
|
||||||
|
uint64_t groups) {
|
||||||
|
|
||||||
std::map<std::string, uint64_t> toReturn;
|
std::map<std::string, uint64_t> toReturn;
|
||||||
TensorType dx_out_resultTy = op.getResult(0).getType().template cast<TensorType>();
|
TensorType dx_out_resultTy =
|
||||||
|
op.getResult(0).getType().template cast<TensorType>();
|
||||||
uint64_t dx_out_volume = getTensorVolume(dx_out_resultTy);
|
uint64_t dx_out_volume = getTensorVolume(dx_out_resultTy);
|
||||||
|
|
||||||
TensorType weightTy = op.getOperand(2).getType().template cast<TensorType>();
|
TensorType weightTy = op.getOperand(2).getType().template cast<TensorType>();
|
||||||
|
@ -79,7 +82,7 @@ std::map<std::string, uint64_t> getConv2dBackwardStatistics(T op, uint64_t group
|
||||||
uint64_t kernel_width = weightTy.getShape()[2];
|
uint64_t kernel_width = weightTy.getShape()[2];
|
||||||
uint64_t kernel_height = weightTy.getShape()[3];
|
uint64_t kernel_height = weightTy.getShape()[3];
|
||||||
|
|
||||||
uint64_t MACs_per_loss =
|
uint64_t MACs_per_loss =
|
||||||
(loss_in_depth / groups) * kernel_height * kernel_width;
|
(loss_in_depth / groups) * kernel_height * kernel_width;
|
||||||
|
|
||||||
uint64_t total_MACs = dx_out_volume * MACs_per_loss;
|
uint64_t total_MACs = dx_out_volume * MACs_per_loss;
|
||||||
|
@ -119,7 +122,6 @@ std::map<std::string, uint64_t> getConv2dBackwardStatistics(T op, uint64_t group
|
||||||
return toReturn;
|
return toReturn;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// Return a model of the number of bytes needed to represent the operand of
|
// Return a model of the number of bytes needed to represent the operand of
|
||||||
// the given convolution-like operation with the given index. The shape is
|
// the given convolution-like operation with the given index. The shape is
|
||||||
// assumed to be in NCHW order with a simple tiled model of data reuse. TODO:
|
// assumed to be in NCHW order with a simple tiled model of data reuse. TODO:
|
||||||
|
@ -222,8 +224,7 @@ uint64_t getConv2dResultTransferVolume(T *o, unsigned int idx, bool write) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return the op statistics for matrixmultiply-like operations.
|
// Return the op statistics for matrixmultiply-like operations.
|
||||||
template<typename T>
|
template <typename T> std::map<std::string, uint64_t> getMMOpStatistics(T op) {
|
||||||
std::map<std::string, uint64_t> getMMOpStatistics(T op) {
|
|
||||||
|
|
||||||
std::map<std::string, uint64_t> toReturn;
|
std::map<std::string, uint64_t> toReturn;
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ namespace aten {
|
||||||
// #define GEN_PASS_CLASSES
|
// #define GEN_PASS_CLASSES
|
||||||
// #include "npcomp/Dialect/ATen/ATenPasses.h.inc"
|
// #include "npcomp/Dialect/ATen/ATenPasses.h.inc"
|
||||||
|
|
||||||
void registerATenPasses();
|
void registerATenPasses();
|
||||||
} // namespace aten
|
} // namespace aten
|
||||||
} // namespace NPCOMP
|
} // namespace NPCOMP
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -44,33 +44,36 @@ void mlir::npcomp::python::defineBackendIREEModule(py::module m) {
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
m.def("build_flow_transform_pass_pipeline",
|
m.def(
|
||||||
[](PyPassManager &pm) {
|
"build_flow_transform_pass_pipeline",
|
||||||
mlir::iree_compiler::IREE::Flow::buildFlowTransformPassPipeline(
|
[](PyPassManager &pm) {
|
||||||
pm.passManager);
|
mlir::iree_compiler::IREE::Flow::buildFlowTransformPassPipeline(
|
||||||
},
|
pm.passManager);
|
||||||
py::arg("pm"),
|
},
|
||||||
py::doc("Builds a pass pipeline for top-level Flow import"));
|
py::arg("pm"),
|
||||||
m.def("build_hal_transform_pass_pipeline",
|
py::doc("Builds a pass pipeline for top-level Flow import"));
|
||||||
[](PyPassManager &pm, std::vector<std::string> targetBackends) {
|
m.def(
|
||||||
mlir::iree_compiler::IREE::HAL::TargetOptions options;
|
"build_hal_transform_pass_pipeline",
|
||||||
if (targetBackends.empty()) {
|
[](PyPassManager &pm, std::vector<std::string> targetBackends) {
|
||||||
options.targets =
|
mlir::iree_compiler::IREE::HAL::TargetOptions options;
|
||||||
mlir::iree_compiler::IREE::HAL::getRegisteredTargetBackends();
|
if (targetBackends.empty()) {
|
||||||
} else {
|
options.targets =
|
||||||
options.targets = std::move(targetBackends);
|
mlir::iree_compiler::IREE::HAL::getRegisteredTargetBackends();
|
||||||
}
|
} else {
|
||||||
iree_compiler::IREE::HAL::buildHALTransformPassPipeline(
|
options.targets = std::move(targetBackends);
|
||||||
pm.passManager, options);
|
}
|
||||||
},
|
iree_compiler::IREE::HAL::buildHALTransformPassPipeline(pm.passManager,
|
||||||
py::arg("pm"), py::arg("target_backends") = std::vector<std::string>(),
|
options);
|
||||||
py::doc("Builds a pass pipeline for top-level Flow import"));
|
},
|
||||||
m.def("build_vm_transform_pass_pipeline",
|
py::arg("pm"), py::arg("target_backends") = std::vector<std::string>(),
|
||||||
[](PyPassManager &pm) {
|
py::doc("Builds a pass pipeline for top-level Flow import"));
|
||||||
mlir::iree_compiler::IREE::VM::buildVMTransformPassPipeline(
|
m.def(
|
||||||
pm.passManager);
|
"build_vm_transform_pass_pipeline",
|
||||||
},
|
[](PyPassManager &pm) {
|
||||||
py::arg("pm"), py::doc("Builds the VM transformation pipeline"));
|
mlir::iree_compiler::IREE::VM::buildVMTransformPassPipeline(
|
||||||
|
pm.passManager);
|
||||||
|
},
|
||||||
|
py::arg("pm"), py::doc("Builds the VM transformation pipeline"));
|
||||||
m.def("translate_to_vm_bytecode", [](PyModuleOp &module) {
|
m.def("translate_to_vm_bytecode", [](PyModuleOp &module) {
|
||||||
// TODO: Make the options parameterizable.
|
// TODO: Make the options parameterizable.
|
||||||
mlir::iree_compiler::IREE::VM::BytecodeTargetOptions options;
|
mlir::iree_compiler::IREE::VM::BytecodeTargetOptions options;
|
||||||
|
|
|
@ -89,38 +89,39 @@ void npcomp::python::defineBackendRefJitModule(py::module m) {
|
||||||
JITModule::buildBackendCompilationPipeline(pm.passManager);
|
JITModule::buildBackendCompilationPipeline(pm.passManager);
|
||||||
});
|
});
|
||||||
py::class_<JITModule>(m, "JITModule")
|
py::class_<JITModule>(m, "JITModule")
|
||||||
.def_static("from_compiled_module",
|
.def_static(
|
||||||
[](PyModuleOp module, std::vector<std::string> pySharedLibs)
|
"from_compiled_module",
|
||||||
-> std::unique_ptr<JITModule> {
|
[](PyModuleOp module, std::vector<std::string> pySharedLibs)
|
||||||
SmallVector<StringRef, 4> sharedLibs(pySharedLibs.begin(),
|
-> std::unique_ptr<JITModule> {
|
||||||
pySharedLibs.end());
|
SmallVector<StringRef, 4> sharedLibs(pySharedLibs.begin(),
|
||||||
auto jitModule =
|
pySharedLibs.end());
|
||||||
checkError(JITModule::fromCompiledModule(
|
auto jitModule = checkError(
|
||||||
module.moduleOp, sharedLibs),
|
JITModule::fromCompiledModule(module.moduleOp, sharedLibs),
|
||||||
"error creating JITModule: ");
|
"error creating JITModule: ");
|
||||||
return jitModule;
|
return jitModule;
|
||||||
},
|
},
|
||||||
py::arg("module"), py::arg("shared_libs"))
|
py::arg("module"), py::arg("shared_libs"))
|
||||||
.def("invoke",
|
.def(
|
||||||
[](JITModule &self, std::string functionName,
|
"invoke",
|
||||||
std::vector<py::buffer> inputs) {
|
[](JITModule &self, std::string functionName,
|
||||||
// Prepare inputs.
|
std::vector<py::buffer> inputs) {
|
||||||
llvm::SmallVector<Ref<Tensor>, 4> inputTensors;
|
// Prepare inputs.
|
||||||
inputTensors.reserve(inputs.size());
|
llvm::SmallVector<Ref<Tensor>, 4> inputTensors;
|
||||||
for (py::buffer &inputBuffer : inputs) {
|
inputTensors.reserve(inputs.size());
|
||||||
inputTensors.push_back(copyBufferToTensor(inputBuffer));
|
for (py::buffer &inputBuffer : inputs) {
|
||||||
}
|
inputTensors.push_back(copyBufferToTensor(inputBuffer));
|
||||||
|
}
|
||||||
|
|
||||||
auto outputs = checkError(self.invoke(functionName, inputTensors),
|
auto outputs = checkError(self.invoke(functionName, inputTensors),
|
||||||
"error invoking JIT function: ");
|
"error invoking JIT function: ");
|
||||||
std::vector<py::array> outputArrays;
|
std::vector<py::array> outputArrays;
|
||||||
outputArrays.reserve(outputs.size());
|
outputArrays.reserve(outputs.size());
|
||||||
for (Ref<Tensor> &outputTensor : outputs) {
|
for (Ref<Tensor> &outputTensor : outputs) {
|
||||||
outputArrays.push_back(wrapTensorAsArray(outputTensor));
|
outputArrays.push_back(wrapTensorAsArray(outputTensor));
|
||||||
}
|
}
|
||||||
return outputArrays;
|
return outputArrays;
|
||||||
},
|
},
|
||||||
py::arg("function_name"), py::arg("inputs"));
|
py::arg("function_name"), py::arg("inputs"));
|
||||||
|
|
||||||
// A Ref<Tensor> needs to be bound because we use it as a base for the
|
// A Ref<Tensor> needs to be bound because we use it as a base for the
|
||||||
// ndarray (the array retains a reference to it). Users should not encounter
|
// ndarray (the array retains a reference to it). Users should not encounter
|
||||||
|
|
|
@ -55,7 +55,7 @@ std::map<std::string, uint64_t> AdaptiveAvgPool2dBackwardOp::getStatistics() {
|
||||||
return toReturn;
|
return toReturn;
|
||||||
}
|
}
|
||||||
|
|
||||||
// add
|
// add
|
||||||
std::map<std::string, uint64_t> AddOp::getStatistics() {
|
std::map<std::string, uint64_t> AddOp::getStatistics() {
|
||||||
|
|
||||||
std::map<std::string, uint64_t> toReturn;
|
std::map<std::string, uint64_t> toReturn;
|
||||||
|
@ -222,7 +222,8 @@ uint64_t ConvolutionOp::getResultTransferVolume(unsigned int idx, bool write) {
|
||||||
std::map<std::string, uint64_t> ConvolutionBackwardOp::getStatistics() {
|
std::map<std::string, uint64_t> ConvolutionBackwardOp::getStatistics() {
|
||||||
return getConv2dBackwardStatistics(*this, 1);
|
return getConv2dBackwardStatistics(*this, 1);
|
||||||
}
|
}
|
||||||
std::map<std::string, uint64_t> ConvolutionBackwardOverrideableOp::getStatistics() {
|
std::map<std::string, uint64_t>
|
||||||
|
ConvolutionBackwardOverrideableOp::getStatistics() {
|
||||||
auto co = cast<mlir::NPCOMP::aten::ConstantOp>(groups().getDefiningOp());
|
auto co = cast<mlir::NPCOMP::aten::ConstantOp>(groups().getDefiningOp());
|
||||||
auto ia = co.template getAttrOfType<IntegerAttr>("value");
|
auto ia = co.template getAttrOfType<IntegerAttr>("value");
|
||||||
uint64_t groups = ia.getValue().getZExtValue();
|
uint64_t groups = ia.getValue().getZExtValue();
|
||||||
|
@ -463,7 +464,7 @@ std::map<std::string, uint64_t> MeanOp::getStatistics() {
|
||||||
// getMMOpStatistics(*this);
|
// getMMOpStatistics(*this);
|
||||||
// }
|
// }
|
||||||
std::map<std::string, uint64_t> MmOp::getStatistics() {
|
std::map<std::string, uint64_t> MmOp::getStatistics() {
|
||||||
return getMMOpStatistics(*this );
|
return getMMOpStatistics(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
// mul
|
// mul
|
||||||
|
|
|
@ -61,7 +61,8 @@ namespace {
|
||||||
static Value typeCast(PatternRewriter &builder, Value val, Type destTy) {
|
static Value typeCast(PatternRewriter &builder, Value val, Type destTy) {
|
||||||
if (val.getType() == destTy)
|
if (val.getType() == destTy)
|
||||||
return val;
|
return val;
|
||||||
return builder.create<mlir::NPCOMP::aten::TypeCastOp>(val.getLoc(), destTy, val)
|
return builder
|
||||||
|
.create<mlir::NPCOMP::aten::TypeCastOp>(val.getLoc(), destTy, val)
|
||||||
.getResult();
|
.getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -69,11 +70,11 @@ static Value typeCast(PatternRewriter &builder, Value val, Type destTy) {
|
||||||
/// unknown shape.
|
/// unknown shape.
|
||||||
static MemRefType getShapeErasedMemRefType(MemRefType type) {
|
static MemRefType getShapeErasedMemRefType(MemRefType type) {
|
||||||
std::vector<int64_t> shape = type.getShape();
|
std::vector<int64_t> shape = type.getShape();
|
||||||
for(int i = 0; i < shape.size(); i++) {
|
for (int i = 0; i < shape.size(); i++) {
|
||||||
shape[i] = -1;
|
shape[i] = -1;
|
||||||
}
|
}
|
||||||
return MemRefType::get(shape, type.getElementType(),
|
return MemRefType::get(shape, type.getElementType(), type.getAffineMaps(),
|
||||||
type.getAffineMaps(), type.getMemorySpace());
|
type.getMemorySpace());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a type cast to memref
|
/// Create a type cast to memref
|
||||||
|
@ -82,14 +83,12 @@ static Value memRefTypeCast(PatternRewriter &builder, Value val) {
|
||||||
|
|
||||||
if (auto memrefTy = type.dyn_cast<MemRefType>()) {
|
if (auto memrefTy = type.dyn_cast<MemRefType>()) {
|
||||||
MemRefType newType = getShapeErasedMemRefType(memrefTy);
|
MemRefType newType = getShapeErasedMemRefType(memrefTy);
|
||||||
return builder.create<MemRefCastOp>(val.getLoc(),
|
return builder.create<MemRefCastOp>(val.getLoc(), val, newType).getResult();
|
||||||
val, newType)
|
|
||||||
.getResult();
|
|
||||||
}
|
}
|
||||||
if (auto tensorTy = type.dyn_cast<TensorType>()) {
|
if (auto tensorTy = type.dyn_cast<TensorType>()) {
|
||||||
auto memRefType = mlir::MemRefType::get(tensorTy.getShape(),
|
auto memRefType = mlir::MemRefType::get(tensorTy.getShape(),
|
||||||
tensorTy.getElementType(), {}, 0);
|
tensorTy.getElementType(), {}, 0);
|
||||||
return typeCast(builder, val, memRefType);
|
return typeCast(builder, val, memRefType);
|
||||||
}
|
}
|
||||||
return val;
|
return val;
|
||||||
}
|
}
|
||||||
|
@ -186,7 +185,7 @@ static std::string getSimplyMangledFuncName(std::string prefix,
|
||||||
ret = ret + sep + getSimplyMangledType(t);
|
ret = ret + sep + getSimplyMangledType(t);
|
||||||
for (const Type t : operTy) {
|
for (const Type t : operTy) {
|
||||||
std::string s = getSimplyMangledType(t);
|
std::string s = getSimplyMangledType(t);
|
||||||
if(s.size() > 0)
|
if (s.size() > 0)
|
||||||
ret = ret + sep + getSimplyMangledType(t);
|
ret = ret + sep + getSimplyMangledType(t);
|
||||||
}
|
}
|
||||||
ret += "_out";
|
ret += "_out";
|
||||||
|
@ -194,25 +193,22 @@ static std::string getSimplyMangledFuncName(std::string prefix,
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
static std::string getSimplyMangledFuncName(std::string prefix,
|
static std::string getSimplyMangledFuncName(std::string prefix,
|
||||||
FunctionType fnTy) {
|
FunctionType fnTy) {
|
||||||
|
|
||||||
return getSimplyMangledFuncName(prefix, fnTy.getInputs(), fnTy.getResults());
|
return getSimplyMangledFuncName(prefix, fnTy.getInputs(), fnTy.getResults());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string getMangledFuncName(std::string prefix,
|
std::string getMangledFuncName(std::string prefix, FunctionType fnTy) {
|
||||||
FunctionType fnTy) {
|
|
||||||
return getSimplyMangledFuncName(prefix, fnTy);
|
return getSimplyMangledFuncName(prefix, fnTy);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string getMangledFuncName(std::string prefix,
|
std::string getMangledFuncName(std::string prefix, ArrayRef<Type> opTys,
|
||||||
ArrayRef<Type> opTys,
|
|
||||||
ArrayRef<Type> retTys) {
|
ArrayRef<Type> retTys) {
|
||||||
return getSimplyMangledFuncName(prefix, opTys, retTys);
|
return getSimplyMangledFuncName(prefix, opTys, retTys);
|
||||||
}
|
}
|
||||||
|
|
||||||
static FuncOp getATenFn(ModuleOp module, std::string mangledFunctionName,
|
static FuncOp getATenFn(ModuleOp module, std::string mangledFunctionName,
|
||||||
ArrayRef<Value> operands,
|
ArrayRef<Value> operands, ArrayRef<Type> retTys) {
|
||||||
ArrayRef<Type> retTys) {
|
|
||||||
Builder builder(module);
|
Builder builder(module);
|
||||||
|
|
||||||
SmallVector<Type, 8> tys;
|
SmallVector<Type, 8> tys;
|
||||||
|
@ -242,8 +238,8 @@ static FuncOp getATenFn(ModuleOp module, std::string mangledFunctionName,
|
||||||
class AddOpConversion_affine : public ConversionPattern {
|
class AddOpConversion_affine : public ConversionPattern {
|
||||||
public:
|
public:
|
||||||
explicit AddOpConversion_affine(MLIRContext *context)
|
explicit AddOpConversion_affine(MLIRContext *context)
|
||||||
: ConversionPattern(mlir::NPCOMP::aten::AddOp::getOperationName(), 1, context) {
|
: ConversionPattern(mlir::NPCOMP::aten::AddOp::getOperationName(), 1,
|
||||||
}
|
context) {}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
|
@ -310,78 +306,72 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
// Replace the given operation with a call to the given function.
|
// Replace the given operation with a call to the given function.
|
||||||
// The function is assumed to accept memrefs and scalar types and return
|
// The function is assumed to accept memrefs and scalar types and return
|
||||||
// Memrefs. Here the result types are converted back to the result types of op,
|
// Memrefs. Here the result types are converted back to the result types of op,
|
||||||
// but operands are NOT converted. This allows non-standard mappings from
|
// but operands are NOT converted. This allows non-standard mappings from
|
||||||
// operand types to function types.
|
// operand types to function types.
|
||||||
LogicalResult
|
LogicalResult rewriteWithVoidFunctionCallExplicit(
|
||||||
rewriteWithVoidFunctionCallExplicit(Operation *op,
|
Operation *op, ArrayRef<Value> callops, ArrayRef<Value> operands,
|
||||||
ArrayRef<Value> callops,
|
ConversionPatternRewriter &rewriter, std::string functionName) {
|
||||||
ArrayRef<Value> operands,
|
|
||||||
ConversionPatternRewriter &rewriter,
|
|
||||||
std::string functionName) {
|
|
||||||
|
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
edsc::ScopedContext scope(rewriter, loc);
|
edsc::ScopedContext scope(rewriter, loc);
|
||||||
|
|
||||||
// The original operation types.
|
// The original operation types.
|
||||||
SmallVector<Type, 8> opTys;
|
SmallVector<Type, 8> opTys;
|
||||||
// Shape erased versions of the original operation types.
|
// Shape erased versions of the original operation types.
|
||||||
SmallVector<Type, 8> erasedOpTys;
|
SmallVector<Type, 8> erasedOpTys;
|
||||||
for (const Value &o: callops) {
|
for (const Value &o : callops) {
|
||||||
Type t = o.getType();
|
Type t = o.getType();
|
||||||
opTys.push_back(t);
|
opTys.push_back(t);
|
||||||
if (t.isa<MemRefType>())
|
if (t.isa<MemRefType>())
|
||||||
erasedOpTys.push_back(getShapeErasedMemRefType(t.cast<MemRefType>()));
|
erasedOpTys.push_back(getShapeErasedMemRefType(t.cast<MemRefType>()));
|
||||||
else
|
else
|
||||||
erasedOpTys.push_back(t);
|
erasedOpTys.push_back(t);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Value> newOps = callops;
|
||||||
|
SmallVector<Value, 8> newResults;
|
||||||
|
|
||||||
|
// Result types of the original operation, converted to memrefs.
|
||||||
|
SmallVector<Type, 8> retTys;
|
||||||
|
// Erased version of the return type. This is the return types of the
|
||||||
|
// generated function call.
|
||||||
|
SmallVector<Type, 8> erasedRetTys;
|
||||||
|
for (const auto &o : op->getResults()) {
|
||||||
|
Type t = o.getType();
|
||||||
|
if (t.isa<TensorType>()) {
|
||||||
|
TensorType tensorResultTy = t.cast<TensorType>();
|
||||||
|
MemRefType memRefResultTy = mlir::MemRefType::get(
|
||||||
|
tensorResultTy.getShape(), tensorResultTy.getElementType(), {}, 0);
|
||||||
|
MemRefType erasedMemRefResultTy =
|
||||||
|
getShapeErasedMemRefType(memRefResultTy);
|
||||||
|
retTys.push_back(memRefResultTy);
|
||||||
|
|
||||||
|
// assume memRefResultTy has known shape, so we don't need any
|
||||||
|
// dynamic dimensions for the alloc.
|
||||||
|
assert(memRefResultTy.hasStaticShape());
|
||||||
|
Value allocVal = rewriter.create<AllocOp>(op->getLoc(), memRefResultTy);
|
||||||
|
Value castVal = memRefTypeCast(rewriter, allocVal);
|
||||||
|
newOps.push_back(castVal);
|
||||||
|
newResults.push_back(allocVal);
|
||||||
|
} else {
|
||||||
|
return failure();
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<Value> newOps = callops;
|
SmallVector<Type, 8> empty;
|
||||||
SmallVector<Value, 8> newResults;
|
std::string mangledFunctionName =
|
||||||
|
getMangledFuncName(functionName, opTys, retTys);
|
||||||
|
FuncOp funcOp = getATenFn(op->getParentOfType<ModuleOp>(),
|
||||||
|
mangledFunctionName, newOps, empty);
|
||||||
|
|
||||||
// Result types of the original operation, converted to memrefs.
|
auto new_call =
|
||||||
SmallVector<Type, 8> retTys;
|
callOperation(empty, rewriter.getSymbolRefAttr(funcOp), newOps);
|
||||||
// Erased version of the return type. This is the return types of the
|
|
||||||
// generated function call.
|
|
||||||
SmallVector<Type, 8> erasedRetTys;
|
|
||||||
for (const auto &o: op->getResults()) {
|
|
||||||
Type t = o.getType();
|
|
||||||
if (t.isa<TensorType>()) {
|
|
||||||
TensorType tensorResultTy = t.cast<TensorType>();
|
|
||||||
MemRefType memRefResultTy =
|
|
||||||
mlir::MemRefType::get(tensorResultTy.getShape(),
|
|
||||||
tensorResultTy.getElementType(), {}, 0);
|
|
||||||
MemRefType erasedMemRefResultTy = getShapeErasedMemRefType(memRefResultTy);
|
|
||||||
retTys.push_back(memRefResultTy);
|
|
||||||
|
|
||||||
// assume memRefResultTy has known shape, so we don't need any
|
rewriter.replaceOp(op, newResults);
|
||||||
// dynamic dimensions for the alloc.
|
return success();
|
||||||
assert(memRefResultTy.hasStaticShape());
|
|
||||||
Value allocVal = rewriter.create<AllocOp>(op->getLoc(),
|
|
||||||
memRefResultTy);
|
|
||||||
Value castVal = memRefTypeCast(rewriter, allocVal);
|
|
||||||
newOps.push_back(castVal);
|
|
||||||
newResults.push_back(allocVal);
|
|
||||||
} else {
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<Type, 8> empty;
|
|
||||||
std::string mangledFunctionName = getMangledFuncName(functionName, opTys, retTys);
|
|
||||||
FuncOp funcOp = getATenFn(op->getParentOfType<ModuleOp>(),
|
|
||||||
mangledFunctionName,
|
|
||||||
newOps,
|
|
||||||
empty);
|
|
||||||
|
|
||||||
auto new_call = callOperation(empty,
|
|
||||||
rewriter.getSymbolRefAttr(funcOp), newOps);
|
|
||||||
|
|
||||||
rewriter.replaceOp(op, newResults);
|
|
||||||
return success();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Replace the given operation with a call to the given function.
|
// Replace the given operation with a call to the given function.
|
||||||
|
@ -389,54 +379,53 @@ rewriteWithVoidFunctionCallExplicit(Operation *op,
|
||||||
// Memrefs. Other operand types (e.g. aten.list and tensor<> are converted
|
// Memrefs. Other operand types (e.g. aten.list and tensor<> are converted
|
||||||
// appropriately. The called function passes results of the original function
|
// appropriately. The called function passes results of the original function
|
||||||
// as memref arguments at the end of the original set of operands.
|
// as memref arguments at the end of the original set of operands.
|
||||||
LogicalResult
|
LogicalResult rewriteWithFunctionCall(Operation *op, ArrayRef<Value> operands,
|
||||||
rewriteWithFunctionCall(Operation *op, ArrayRef<Value> operands,
|
ConversionPatternRewriter &rewriter,
|
||||||
ConversionPatternRewriter &rewriter,
|
std::string functionName) {
|
||||||
std::string functionName) {
|
auto loc = op->getLoc();
|
||||||
auto loc = op->getLoc();
|
edsc::ScopedContext scope(rewriter, loc);
|
||||||
edsc::ScopedContext scope(rewriter, loc);
|
|
||||||
|
|
||||||
// Convert the arguments to the original call.
|
// Convert the arguments to the original call.
|
||||||
SmallVector<Value, 8> callops;
|
SmallVector<Value, 8> callops;
|
||||||
for (auto &o: operands) {
|
for (auto &o : operands) {
|
||||||
Type t = o.getType();
|
Type t = o.getType();
|
||||||
if (t.isa<MemRefType>()) {
|
if (t.isa<MemRefType>()) {
|
||||||
// Cast it to some memref type that we accept
|
// Cast it to some memref type that we accept
|
||||||
callops.push_back(memRefTypeCast(rewriter, o));
|
callops.push_back(memRefTypeCast(rewriter, o));
|
||||||
} else if (t.isa<IntegerType>() || t.isa<FloatType>()) {
|
} else if (t.isa<IntegerType>() || t.isa<FloatType>()) {
|
||||||
callops.push_back(o);
|
callops.push_back(o);
|
||||||
} else if (t.isa<ATenListType>()) {
|
} else if (t.isa<ATenListType>()) {
|
||||||
// FIXME: lots of assumptions here.
|
// FIXME: lots of assumptions here.
|
||||||
auto unpack = [](auto &op, auto &v) -> void {
|
auto unpack = [](auto &op, auto &v) -> void {
|
||||||
auto co = cast<mlir::NPCOMP::aten::ConstantOp>(op.getDefiningOp());
|
auto co = cast<mlir::NPCOMP::aten::ConstantOp>(op.getDefiningOp());
|
||||||
DenseElementsAttr a =
|
DenseElementsAttr a =
|
||||||
co.template getAttrOfType<DenseElementsAttr>("value");
|
co.template getAttrOfType<DenseElementsAttr>("value");
|
||||||
for (auto i : a.getIntValues())
|
for (auto i : a.getIntValues())
|
||||||
v.push_back(i.getSExtValue());
|
v.push_back(i.getSExtValue());
|
||||||
};
|
};
|
||||||
std::vector<uint64_t> values;
|
std::vector<uint64_t> values;
|
||||||
unpack(o, values);
|
unpack(o, values);
|
||||||
callops.push_back(constInt(values[0], 32));
|
callops.push_back(constInt(values[0], 32));
|
||||||
} else {
|
} else {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return rewriteWithVoidFunctionCallExplicit(op, callops, operands, rewriter, functionName);
|
}
|
||||||
|
return rewriteWithVoidFunctionCallExplicit(op, callops, operands, rewriter,
|
||||||
|
functionName);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/// Lower Add
|
/// Lower Add
|
||||||
template<typename Op>
|
template <typename Op>
|
||||||
class ATenFunctionCallConversion : public ConversionPattern {
|
class ATenFunctionCallConversion : public ConversionPattern {
|
||||||
public:
|
public:
|
||||||
explicit ATenFunctionCallConversion(MLIRContext *context)
|
explicit ATenFunctionCallConversion(MLIRContext *context)
|
||||||
: ConversionPattern(Op::getOperationName(), 1, context) {
|
: ConversionPattern(Op::getOperationName(), 1, context) {}
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
return rewriteWithFunctionCall(op, operands, rewriter, Op::getFunctionConversionName());
|
return rewriteWithFunctionCall(op, operands, rewriter,
|
||||||
|
Op::getFunctionConversionName());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -444,8 +433,8 @@ public:
|
||||||
class ConstantOpConversion : public ConversionPattern {
|
class ConstantOpConversion : public ConversionPattern {
|
||||||
public:
|
public:
|
||||||
explicit ConstantOpConversion(MLIRContext *context)
|
explicit ConstantOpConversion(MLIRContext *context)
|
||||||
: ConversionPattern(mlir::NPCOMP::aten::ConstantOp::getOperationName(), 1, context) {
|
: ConversionPattern(mlir::NPCOMP::aten::ConstantOp::getOperationName(), 1,
|
||||||
}
|
context) {}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
|
@ -459,14 +448,15 @@ public:
|
||||||
Type t = result.getType();
|
Type t = result.getType();
|
||||||
if (t.isa<IntegerType>()) {
|
if (t.isa<IntegerType>()) {
|
||||||
auto it = t.cast<IntegerType>();
|
auto it = t.cast<IntegerType>();
|
||||||
if(it.getWidth() > 1) {
|
if (it.getWidth() > 1) {
|
||||||
auto a = op->getAttrOfType<IntegerAttr>("value");
|
auto a = op->getAttrOfType<IntegerAttr>("value");
|
||||||
SmallVector<Value, 8> newValues {rewriter.create<mlir::ConstantOp>(loc, a)};
|
SmallVector<Value, 8> newValues{
|
||||||
|
rewriter.create<mlir::ConstantOp>(loc, a)};
|
||||||
rewriter.replaceOp(op, newValues);
|
rewriter.replaceOp(op, newValues);
|
||||||
return success();
|
return success();
|
||||||
} else {
|
} else {
|
||||||
auto a = op->getAttrOfType<BoolAttr>("value");
|
auto a = op->getAttrOfType<BoolAttr>("value");
|
||||||
SmallVector<Value, 8> newValues {constInt(a.getValue(), it.getWidth())};
|
SmallVector<Value, 8> newValues{constInt(a.getValue(), it.getWidth())};
|
||||||
rewriter.replaceOp(op, newValues);
|
rewriter.replaceOp(op, newValues);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -485,8 +475,8 @@ public:
|
||||||
class AddOpConversion : public ConversionPattern {
|
class AddOpConversion : public ConversionPattern {
|
||||||
public:
|
public:
|
||||||
explicit AddOpConversion(MLIRContext *context)
|
explicit AddOpConversion(MLIRContext *context)
|
||||||
: ConversionPattern(mlir::NPCOMP::aten::AddOp::getOperationName(), 1, context) {
|
: ConversionPattern(mlir::NPCOMP::aten::AddOp::getOperationName(), 1,
|
||||||
}
|
context) {}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
|
@ -513,8 +503,8 @@ public:
|
||||||
class AsStridedOpConversion : public ConversionPattern {
|
class AsStridedOpConversion : public ConversionPattern {
|
||||||
public:
|
public:
|
||||||
explicit AsStridedOpConversion(MLIRContext *context)
|
explicit AsStridedOpConversion(MLIRContext *context)
|
||||||
: ConversionPattern(mlir::NPCOMP::aten::AsStridedOp::getOperationName(), 1,
|
: ConversionPattern(mlir::NPCOMP::aten::AsStridedOp::getOperationName(),
|
||||||
context) {}
|
1, context) {}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
|
@ -527,7 +517,8 @@ public:
|
||||||
// construct the shape argument
|
// construct the shape argument
|
||||||
std::vector<Value> shape;
|
std::vector<Value> shape;
|
||||||
std::vector<int64_t> result_shape;
|
std::vector<int64_t> result_shape;
|
||||||
auto co0 = cast<mlir::NPCOMP::aten::ConstantOp>(operands[1].getDefiningOp());
|
auto co0 =
|
||||||
|
cast<mlir::NPCOMP::aten::ConstantOp>(operands[1].getDefiningOp());
|
||||||
DenseElementsAttr a0 =
|
DenseElementsAttr a0 =
|
||||||
co0.template getAttrOfType<DenseElementsAttr>("value");
|
co0.template getAttrOfType<DenseElementsAttr>("value");
|
||||||
for (auto i : a0.getAttributeValues())
|
for (auto i : a0.getAttributeValues())
|
||||||
|
@ -539,7 +530,8 @@ public:
|
||||||
|
|
||||||
// construct the stride argument
|
// construct the stride argument
|
||||||
std::vector<Value> stride;
|
std::vector<Value> stride;
|
||||||
auto co1 = cast<mlir::NPCOMP::aten::ConstantOp>(operands[2].getDefiningOp());
|
auto co1 =
|
||||||
|
cast<mlir::NPCOMP::aten::ConstantOp>(operands[2].getDefiningOp());
|
||||||
DenseElementsAttr a1 =
|
DenseElementsAttr a1 =
|
||||||
co1.template getAttrOfType<DenseElementsAttr>("value");
|
co1.template getAttrOfType<DenseElementsAttr>("value");
|
||||||
for (auto i : a1.getAttributeValues())
|
for (auto i : a1.getAttributeValues())
|
||||||
|
@ -551,19 +543,21 @@ public:
|
||||||
|
|
||||||
APInt offset(32, 0);
|
APInt offset(32, 0);
|
||||||
if (operands.size() > 3) {
|
if (operands.size() > 3) {
|
||||||
auto co2 = cast<mlir::NPCOMP::aten::ConstantOp>(operands[3].getDefiningOp());
|
auto co2 =
|
||||||
|
cast<mlir::NPCOMP::aten::ConstantOp>(operands[3].getDefiningOp());
|
||||||
auto ia2 = co2.getAttrOfType<IntegerAttr>("value");
|
auto ia2 = co2.getAttrOfType<IntegerAttr>("value");
|
||||||
offset = ia2.getValue();
|
offset = ia2.getValue();
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<Value, 8> callops{xVal, shape[0],
|
SmallVector<Value, 8> callops{
|
||||||
shape[1], shape[2],
|
xVal, shape[0],
|
||||||
shape[3], stride[0],
|
shape[1], shape[2],
|
||||||
stride[1], stride[2],
|
shape[3], stride[0],
|
||||||
stride[3], constInt(offset.getSExtValue(), 32)};
|
stride[1], stride[2],
|
||||||
|
stride[3], constInt(offset.getSExtValue(), 32)};
|
||||||
|
|
||||||
|
return rewriteWithVoidFunctionCallExplicit(op, callops, operands, rewriter,
|
||||||
return rewriteWithVoidFunctionCallExplicit(op, callops, operands, rewriter, "as_strided");
|
"as_strided");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -571,8 +565,8 @@ public:
|
||||||
class BatchNormOpConversion : public ConversionPattern {
|
class BatchNormOpConversion : public ConversionPattern {
|
||||||
public:
|
public:
|
||||||
explicit BatchNormOpConversion(MLIRContext *context)
|
explicit BatchNormOpConversion(MLIRContext *context)
|
||||||
: ConversionPattern(mlir::NPCOMP::aten::BatchNormOp::getOperationName(), 1,
|
: ConversionPattern(mlir::NPCOMP::aten::BatchNormOp::getOperationName(),
|
||||||
context) {}
|
1, context) {}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
|
@ -585,8 +579,8 @@ public:
|
||||||
class ConvolutionOpConversion : public ConversionPattern {
|
class ConvolutionOpConversion : public ConversionPattern {
|
||||||
public:
|
public:
|
||||||
explicit ConvolutionOpConversion(MLIRContext *context)
|
explicit ConvolutionOpConversion(MLIRContext *context)
|
||||||
: ConversionPattern(mlir::NPCOMP::aten::ConvolutionOp::getOperationName(), 1,
|
: ConversionPattern(mlir::NPCOMP::aten::ConvolutionOp::getOperationName(),
|
||||||
context) {}
|
1, context) {}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
|
@ -614,8 +608,8 @@ public:
|
||||||
class DivOpConversion : public ConversionPattern {
|
class DivOpConversion : public ConversionPattern {
|
||||||
public:
|
public:
|
||||||
explicit DivOpConversion(MLIRContext *context)
|
explicit DivOpConversion(MLIRContext *context)
|
||||||
: ConversionPattern(mlir::NPCOMP::aten::DivOp::getOperationName(), 1, context) {
|
: ConversionPattern(mlir::NPCOMP::aten::DivOp::getOperationName(), 1,
|
||||||
}
|
context) {}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
|
@ -627,8 +621,8 @@ public:
|
||||||
class LogSoftmaxOpConversion : public ConversionPattern {
|
class LogSoftmaxOpConversion : public ConversionPattern {
|
||||||
public:
|
public:
|
||||||
explicit LogSoftmaxOpConversion(MLIRContext *context)
|
explicit LogSoftmaxOpConversion(MLIRContext *context)
|
||||||
: ConversionPattern(mlir::NPCOMP::aten::LogSoftmaxOp::getOperationName(), 1,
|
: ConversionPattern(mlir::NPCOMP::aten::LogSoftmaxOp::getOperationName(),
|
||||||
context) {}
|
1, context) {}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
|
@ -657,8 +651,8 @@ public:
|
||||||
class MaxPoolOpConversion : public ConversionPattern {
|
class MaxPoolOpConversion : public ConversionPattern {
|
||||||
public:
|
public:
|
||||||
explicit MaxPoolOpConversion(MLIRContext *context)
|
explicit MaxPoolOpConversion(MLIRContext *context)
|
||||||
: ConversionPattern(mlir::NPCOMP::aten::MaxPool2dOp::getOperationName(), 1,
|
: ConversionPattern(mlir::NPCOMP::aten::MaxPool2dOp::getOperationName(),
|
||||||
context) {}
|
1, context) {}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
|
@ -678,7 +672,8 @@ public:
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
return rewriteWithFunctionCall(op, operands, rewriter, "max_pool2d_with_indices");
|
return rewriteWithFunctionCall(op, operands, rewriter,
|
||||||
|
"max_pool2d_with_indices");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -686,14 +681,15 @@ public:
|
||||||
class MaxPool2dWithIndicesBackwardOpConversion : public ConversionPattern {
|
class MaxPool2dWithIndicesBackwardOpConversion : public ConversionPattern {
|
||||||
public:
|
public:
|
||||||
explicit MaxPool2dWithIndicesBackwardOpConversion(MLIRContext *context)
|
explicit MaxPool2dWithIndicesBackwardOpConversion(MLIRContext *context)
|
||||||
: ConversionPattern(
|
: ConversionPattern(mlir::NPCOMP::aten::MaxPool2dWithIndicesBackwardOp::
|
||||||
mlir::NPCOMP::aten::MaxPool2dWithIndicesBackwardOp::getOperationName(), 1,
|
getOperationName(),
|
||||||
context) {}
|
1, context) {}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
return rewriteWithFunctionCall(op, operands, rewriter, "max_pool2d_with_indices_backward");
|
return rewriteWithFunctionCall(op, operands, rewriter,
|
||||||
|
"max_pool2d_with_indices_backward");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -701,7 +697,8 @@ public:
|
||||||
class MMOpConversion : public ConversionPattern {
|
class MMOpConversion : public ConversionPattern {
|
||||||
public:
|
public:
|
||||||
explicit MMOpConversion(MLIRContext *context)
|
explicit MMOpConversion(MLIRContext *context)
|
||||||
: ConversionPattern(mlir::NPCOMP::aten::MmOp::getOperationName(), 1, context) {}
|
: ConversionPattern(mlir::NPCOMP::aten::MmOp::getOperationName(), 1,
|
||||||
|
context) {}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
|
@ -714,8 +711,8 @@ public:
|
||||||
class MulOpConversion : public ConversionPattern {
|
class MulOpConversion : public ConversionPattern {
|
||||||
public:
|
public:
|
||||||
explicit MulOpConversion(MLIRContext *context)
|
explicit MulOpConversion(MLIRContext *context)
|
||||||
: ConversionPattern(mlir::NPCOMP::aten::MulOp::getOperationName(), 1, context) {
|
: ConversionPattern(mlir::NPCOMP::aten::MulOp::getOperationName(), 1,
|
||||||
}
|
context) {}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
|
@ -728,8 +725,9 @@ public:
|
||||||
class NativeBatchNormOpConversion : public ConversionPattern {
|
class NativeBatchNormOpConversion : public ConversionPattern {
|
||||||
public:
|
public:
|
||||||
explicit NativeBatchNormOpConversion(MLIRContext *context)
|
explicit NativeBatchNormOpConversion(MLIRContext *context)
|
||||||
: ConversionPattern(mlir::NPCOMP::aten::NativeBatchNormOp::getOperationName(),
|
: ConversionPattern(
|
||||||
1, context) {}
|
mlir::NPCOMP::aten::NativeBatchNormOp::getOperationName(), 1,
|
||||||
|
context) {}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
|
@ -742,13 +740,15 @@ public:
|
||||||
class NllLoss2dBackwardOpConversion : public ConversionPattern {
|
class NllLoss2dBackwardOpConversion : public ConversionPattern {
|
||||||
public:
|
public:
|
||||||
explicit NllLoss2dBackwardOpConversion(MLIRContext *context)
|
explicit NllLoss2dBackwardOpConversion(MLIRContext *context)
|
||||||
: ConversionPattern(mlir::NPCOMP::aten::NllLoss2dBackwardOp::getOperationName(),
|
: ConversionPattern(
|
||||||
1, context) {}
|
mlir::NPCOMP::aten::NllLoss2dBackwardOp::getOperationName(), 1,
|
||||||
|
context) {}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
return rewriteWithFunctionCall(op, operands, rewriter, "nll_loss2d_backward");
|
return rewriteWithFunctionCall(op, operands, rewriter,
|
||||||
|
"nll_loss2d_backward");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -756,13 +756,15 @@ public:
|
||||||
class NllLoss2dForwardOpConversion : public ConversionPattern {
|
class NllLoss2dForwardOpConversion : public ConversionPattern {
|
||||||
public:
|
public:
|
||||||
explicit NllLoss2dForwardOpConversion(MLIRContext *context)
|
explicit NllLoss2dForwardOpConversion(MLIRContext *context)
|
||||||
: ConversionPattern(mlir::NPCOMP::aten::NllLoss2dForwardOp::getOperationName(),
|
: ConversionPattern(
|
||||||
1, context) {}
|
mlir::NPCOMP::aten::NllLoss2dForwardOp::getOperationName(), 1,
|
||||||
|
context) {}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
return rewriteWithFunctionCall(op, operands, rewriter, "nll_loss2d_forward");
|
return rewriteWithFunctionCall(op, operands, rewriter,
|
||||||
|
"nll_loss2d_forward");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -770,8 +772,9 @@ public:
|
||||||
class NllLossBackwardOpConversion : public ConversionPattern {
|
class NllLossBackwardOpConversion : public ConversionPattern {
|
||||||
public:
|
public:
|
||||||
explicit NllLossBackwardOpConversion(MLIRContext *context)
|
explicit NllLossBackwardOpConversion(MLIRContext *context)
|
||||||
: ConversionPattern(mlir::NPCOMP::aten::NllLossBackwardOp::getOperationName(),
|
: ConversionPattern(
|
||||||
1, context) {}
|
mlir::NPCOMP::aten::NllLossBackwardOp::getOperationName(), 1,
|
||||||
|
context) {}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
|
@ -784,13 +787,15 @@ public:
|
||||||
class NllLossForwardOpConversion : public ConversionPattern {
|
class NllLossForwardOpConversion : public ConversionPattern {
|
||||||
public:
|
public:
|
||||||
explicit NllLossForwardOpConversion(MLIRContext *context)
|
explicit NllLossForwardOpConversion(MLIRContext *context)
|
||||||
: ConversionPattern(mlir::NPCOMP::aten::NllLossForwardOp::getOperationName(), 1,
|
: ConversionPattern(
|
||||||
context) {}
|
mlir::NPCOMP::aten::NllLossForwardOp::getOperationName(), 1,
|
||||||
|
context) {}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
return rewriteWithFunctionCall(op, operands, rewriter, "nll_loss_forward"); }
|
return rewriteWithFunctionCall(op, operands, rewriter, "nll_loss_forward");
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Lower ReLU
|
/// Lower ReLU
|
||||||
|
@ -811,13 +816,15 @@ public:
|
||||||
class ThresholdBackwardOpConversion : public ConversionPattern {
|
class ThresholdBackwardOpConversion : public ConversionPattern {
|
||||||
public:
|
public:
|
||||||
explicit ThresholdBackwardOpConversion(MLIRContext *context)
|
explicit ThresholdBackwardOpConversion(MLIRContext *context)
|
||||||
: ConversionPattern(mlir::NPCOMP::aten::ThresholdBackwardOp::getOperationName(),
|
: ConversionPattern(
|
||||||
1, context) {}
|
mlir::NPCOMP::aten::ThresholdBackwardOp::getOperationName(), 1,
|
||||||
|
context) {}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
return rewriteWithFunctionCall(op, operands, rewriter, "threshold_backward");
|
return rewriteWithFunctionCall(op, operands, rewriter,
|
||||||
|
"threshold_backward");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -825,7 +832,8 @@ public:
|
||||||
class TransposeOpConversion : public ConversionPattern {
|
class TransposeOpConversion : public ConversionPattern {
|
||||||
public:
|
public:
|
||||||
explicit TransposeOpConversion(MLIRContext *context)
|
explicit TransposeOpConversion(MLIRContext *context)
|
||||||
: ConversionPattern(mlir::NPCOMP::aten::TOp::getOperationName(), 1, context) {}
|
: ConversionPattern(mlir::NPCOMP::aten::TOp::getOperationName(), 1,
|
||||||
|
context) {}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
|
@ -849,21 +857,22 @@ public:
|
||||||
|
|
||||||
Value xVal = memRefTypeCast(rewriter, operands[0]);
|
Value xVal = memRefTypeCast(rewriter, operands[0]);
|
||||||
|
|
||||||
// construct the shape argument
|
// construct the shape argument
|
||||||
SmallVector<Value, 8> shape;
|
SmallVector<Value, 8> shape;
|
||||||
auto co = dyn_cast<mlir::NPCOMP::aten::ConstantOp>(operands[1].getDefiningOp());
|
auto co =
|
||||||
|
dyn_cast<mlir::NPCOMP::aten::ConstantOp>(operands[1].getDefiningOp());
|
||||||
DenseElementsAttr a = co.template getAttrOfType<DenseElementsAttr>("value");
|
DenseElementsAttr a = co.template getAttrOfType<DenseElementsAttr>("value");
|
||||||
for (auto i : a.getAttributeValues())
|
for (auto i : a.getAttributeValues())
|
||||||
shape.push_back(rewriter.create<mlir::ConstantOp>(co.getLoc(), i));
|
shape.push_back(rewriter.create<mlir::ConstantOp>(co.getLoc(), i));
|
||||||
|
|
||||||
|
|
||||||
// pad out the shape with -1 to make it 4d
|
// pad out the shape with -1 to make it 4d
|
||||||
while (shape.size() < 4)
|
while (shape.size() < 4)
|
||||||
shape.push_back(constInt(-1, 32));
|
shape.push_back(constInt(-1, 32));
|
||||||
|
|
||||||
SmallVector<Value, 8> callops{xVal, shape[0], shape[1], shape[2], shape[3]};
|
SmallVector<Value, 8> callops{xVal, shape[0], shape[1], shape[2], shape[3]};
|
||||||
|
|
||||||
return rewriteWithVoidFunctionCallExplicit(op, callops, operands, rewriter, "view");
|
return rewriteWithVoidFunctionCallExplicit(op, callops, operands, rewriter,
|
||||||
|
"view");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -896,9 +905,8 @@ struct ATenLoweringPass
|
||||||
|
|
||||||
// c++ patterns
|
// c++ patterns
|
||||||
acapPatterns.insert<
|
acapPatterns.insert<
|
||||||
ConstantOpConversion,
|
ConstantOpConversion, AddOpConversion, ConvolutionOpConversion,
|
||||||
AddOpConversion, ConvolutionOpConversion, ReLUOpConversion,
|
ReLUOpConversion, TransposeOpConversion, BatchNormOpConversion,
|
||||||
TransposeOpConversion, BatchNormOpConversion,
|
|
||||||
NativeBatchNormOpConversion, MaxPoolOpConversion,
|
NativeBatchNormOpConversion, MaxPoolOpConversion,
|
||||||
MaxPool2dWithIndicesOpConversion, AddmmOpConversion, ViewOpConversion,
|
MaxPool2dWithIndicesOpConversion, AddmmOpConversion, ViewOpConversion,
|
||||||
MulOpConversion, MMOpConversion, AsStridedOpConversion,
|
MulOpConversion, MMOpConversion, AsStridedOpConversion,
|
||||||
|
|
|
@ -7,8 +7,8 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "npcomp/Dialect/ATen/ATenToStd.h"
|
#include "npcomp/Dialect/ATen/ATenToStd.h"
|
||||||
#include "npcomp/Dialect/ATen/ATenDialect.h"
|
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "npcomp/Dialect/ATen/ATenDialect.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP;
|
using namespace mlir::NPCOMP;
|
||||||
|
|
|
@ -115,8 +115,8 @@ std::string LivenessReport::emitJSONReport() {
|
||||||
for (auto v : vlist) {
|
for (auto v : vlist) {
|
||||||
int64_t vol = getTensorVolume(v.getType());
|
int64_t vol = getTensorVolume(v.getType());
|
||||||
if (v.getDefiningOp()) {
|
if (v.getDefiningOp()) {
|
||||||
if (auto a = v.getDefiningOp()->getAttrOfType<StringAttr>(
|
if (auto a =
|
||||||
"layer_name")) {
|
v.getDefiningOp()->getAttrOfType<StringAttr>("layer_name")) {
|
||||||
auto definingOp = v.getDefiningOp();
|
auto definingOp = v.getDefiningOp();
|
||||||
auto ld = layerDetail.getInteger(a.getValue().str());
|
auto ld = layerDetail.getInteger(a.getValue().str());
|
||||||
if (ld)
|
if (ld)
|
||||||
|
|
|
@ -32,28 +32,29 @@ public:
|
||||||
auto op = opBuilder.create<scf::YieldOp>(loc, yields);
|
auto op = opBuilder.create<scf::YieldOp>(loc, yields);
|
||||||
return op.getOperation();
|
return op.getOperation();
|
||||||
})
|
})
|
||||||
.def("scf_if_op",
|
.def(
|
||||||
[](ScfDialectHelper &self, std::vector<PyType> pyResultTypes,
|
"scf_if_op",
|
||||||
PyValue cond, bool withElseRegion) {
|
[](ScfDialectHelper &self, std::vector<PyType> pyResultTypes,
|
||||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
PyValue cond, bool withElseRegion) {
|
||||||
Location loc = self.pyOpBuilder.getCurrentLoc();
|
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||||
llvm::SmallVector<Type, 4> resultTypes(pyResultTypes.begin(),
|
Location loc = self.pyOpBuilder.getCurrentLoc();
|
||||||
pyResultTypes.end());
|
llvm::SmallVector<Type, 4> resultTypes(pyResultTypes.begin(),
|
||||||
auto op = opBuilder.create<scf::IfOp>(loc, resultTypes, cond,
|
pyResultTypes.end());
|
||||||
withElseRegion);
|
auto op = opBuilder.create<scf::IfOp>(loc, resultTypes, cond,
|
||||||
if (withElseRegion) {
|
withElseRegion);
|
||||||
return py::make_tuple(
|
if (withElseRegion) {
|
||||||
PyOperationRef(op),
|
return py::make_tuple(
|
||||||
op.getThenBodyBuilder().saveInsertionPoint(),
|
PyOperationRef(op),
|
||||||
op.getElseBodyBuilder().saveInsertionPoint());
|
op.getThenBodyBuilder().saveInsertionPoint(),
|
||||||
} else {
|
op.getElseBodyBuilder().saveInsertionPoint());
|
||||||
return py::make_tuple(
|
} else {
|
||||||
PyOperationRef(op),
|
return py::make_tuple(
|
||||||
op.getThenBodyBuilder().saveInsertionPoint());
|
PyOperationRef(op),
|
||||||
}
|
op.getThenBodyBuilder().saveInsertionPoint());
|
||||||
},
|
}
|
||||||
py::arg("result_types"), py::arg("cond"),
|
},
|
||||||
py::arg("with_else_region") = false);
|
py::arg("result_types"), py::arg("cond"),
|
||||||
|
py::arg("with_else_region") = false);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -65,13 +65,14 @@ void PyIpListWrapper<ListTy, ItemWrapperTy>::bind(py::module m,
|
||||||
"front",
|
"front",
|
||||||
[](ThisTy &self) { return ItemWrapperTy(self.list.front()); })
|
[](ThisTy &self) { return ItemWrapperTy(self.list.front()); })
|
||||||
.def("__len__", [](ThisTy &self) { return self.list.size(); })
|
.def("__len__", [](ThisTy &self) { return self.list.size(); })
|
||||||
.def("__iter__",
|
.def(
|
||||||
[](ThisTy &self) {
|
"__iter__",
|
||||||
PyItemIterator begin(self.list.begin());
|
[](ThisTy &self) {
|
||||||
PyItemIterator end(self.list.end());
|
PyItemIterator begin(self.list.begin());
|
||||||
return py::make_iterator(begin, end);
|
PyItemIterator end(self.list.end());
|
||||||
},
|
return py::make_iterator(begin, end);
|
||||||
py::keep_alive<0, 1>());
|
},
|
||||||
|
py::keep_alive<0, 1>());
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -194,79 +195,81 @@ void PyDialectHelper::bind(py::module m) {
|
||||||
[](PyDialectHelper &self) -> std::shared_ptr<PyContext> {
|
[](PyDialectHelper &self) -> std::shared_ptr<PyContext> {
|
||||||
return self.context.shared_from_this();
|
return self.context.shared_from_this();
|
||||||
})
|
})
|
||||||
.def("op",
|
.def(
|
||||||
[](PyDialectHelper &self, const std::string &opNameStr,
|
"op",
|
||||||
std::vector<PyType> pyResultTypes,
|
[](PyDialectHelper &self, const std::string &opNameStr,
|
||||||
std::vector<PyValue> pyOperands,
|
std::vector<PyType> pyResultTypes, std::vector<PyValue> pyOperands,
|
||||||
llvm::Optional<PyAttribute> attrs) -> PyOperationRef {
|
llvm::Optional<PyAttribute> attrs) -> PyOperationRef {
|
||||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(false);
|
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(false);
|
||||||
Location loc = self.pyOpBuilder.getCurrentLoc();
|
Location loc = self.pyOpBuilder.getCurrentLoc();
|
||||||
OperationName opName(opNameStr, opBuilder.getContext());
|
OperationName opName(opNameStr, opBuilder.getContext());
|
||||||
SmallVector<Type, 4> types(pyResultTypes.begin(),
|
SmallVector<Type, 4> types(pyResultTypes.begin(),
|
||||||
pyResultTypes.end());
|
pyResultTypes.end());
|
||||||
SmallVector<Value, 4> operands(pyOperands.begin(),
|
SmallVector<Value, 4> operands(pyOperands.begin(),
|
||||||
pyOperands.end());
|
pyOperands.end());
|
||||||
MutableDictionaryAttr attrList;
|
MutableDictionaryAttr attrList;
|
||||||
if (attrs) {
|
if (attrs) {
|
||||||
auto dictAttrs = attrs->attr.dyn_cast<DictionaryAttr>();
|
auto dictAttrs = attrs->attr.dyn_cast<DictionaryAttr>();
|
||||||
if (!dictAttrs) {
|
if (!dictAttrs) {
|
||||||
throw py::raiseValueError(
|
throw py::raiseValueError(
|
||||||
"Expected `attrs` to be a DictionaryAttr");
|
"Expected `attrs` to be a DictionaryAttr");
|
||||||
}
|
}
|
||||||
attrList = MutableDictionaryAttr(dictAttrs);
|
attrList = MutableDictionaryAttr(dictAttrs);
|
||||||
}
|
}
|
||||||
Operation *op =
|
Operation *op =
|
||||||
Operation::create(loc, opName, types, operands, attrList);
|
Operation::create(loc, opName, types, operands, attrList);
|
||||||
opBuilder.insert(op);
|
opBuilder.insert(op);
|
||||||
return op;
|
return op;
|
||||||
},
|
},
|
||||||
py::arg("op_name"), py::arg("result_types"), py::arg("operands"),
|
py::arg("op_name"), py::arg("result_types"), py::arg("operands"),
|
||||||
py::arg("attrs") = llvm::Optional<PyAttribute>())
|
py::arg("attrs") = llvm::Optional<PyAttribute>())
|
||||||
.def("func_op",
|
.def(
|
||||||
[](PyDialectHelper &self, const std::string &name, PyType type,
|
"func_op",
|
||||||
bool createEntryBlock, llvm::Optional<PyAttribute> attrs) {
|
[](PyDialectHelper &self, const std::string &name, PyType type,
|
||||||
auto functionType = type.type.dyn_cast_or_null<FunctionType>();
|
bool createEntryBlock, llvm::Optional<PyAttribute> attrs) {
|
||||||
if (!functionType) {
|
auto functionType = type.type.dyn_cast_or_null<FunctionType>();
|
||||||
throw py::raiseValueError("Illegal function type");
|
if (!functionType) {
|
||||||
}
|
throw py::raiseValueError("Illegal function type");
|
||||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
}
|
||||||
Location loc = self.pyOpBuilder.getCurrentLoc();
|
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||||
// TODO: Dedup attr creation from op().
|
Location loc = self.pyOpBuilder.getCurrentLoc();
|
||||||
MutableDictionaryAttr attrList;
|
// TODO: Dedup attr creation from op().
|
||||||
if (attrs) {
|
MutableDictionaryAttr attrList;
|
||||||
auto dictAttrs = attrs->attr.dyn_cast<DictionaryAttr>();
|
if (attrs) {
|
||||||
if (!dictAttrs) {
|
auto dictAttrs = attrs->attr.dyn_cast<DictionaryAttr>();
|
||||||
throw py::raiseValueError(
|
if (!dictAttrs) {
|
||||||
"Expected `attrs` to be a DictionaryAttr");
|
throw py::raiseValueError(
|
||||||
}
|
"Expected `attrs` to be a DictionaryAttr");
|
||||||
attrList = MutableDictionaryAttr(dictAttrs);
|
}
|
||||||
}
|
attrList = MutableDictionaryAttr(dictAttrs);
|
||||||
FuncOp op =
|
}
|
||||||
opBuilder.create<FuncOp>(loc, StringRef(name), functionType,
|
FuncOp op =
|
||||||
/*attrs=*/attrList.getAttrs());
|
opBuilder.create<FuncOp>(loc, StringRef(name), functionType,
|
||||||
if (createEntryBlock) {
|
/*attrs=*/attrList.getAttrs());
|
||||||
Block *entryBlock = new Block();
|
if (createEntryBlock) {
|
||||||
entryBlock->addArguments(functionType.getInputs());
|
Block *entryBlock = new Block();
|
||||||
op.getBody().push_back(entryBlock);
|
entryBlock->addArguments(functionType.getInputs());
|
||||||
opBuilder.setInsertionPointToStart(entryBlock);
|
op.getBody().push_back(entryBlock);
|
||||||
}
|
opBuilder.setInsertionPointToStart(entryBlock);
|
||||||
return PyOperationRef(op);
|
}
|
||||||
},
|
return PyOperationRef(op);
|
||||||
py::arg("name"), py::arg("type"),
|
},
|
||||||
py::arg("create_entry_block") = false,
|
py::arg("name"), py::arg("type"),
|
||||||
py::arg("attrs") = llvm::Optional<PyAttribute>(),
|
py::arg("create_entry_block") = false,
|
||||||
R"(Creates a new `func` op, optionally creating an entry block.
|
py::arg("attrs") = llvm::Optional<PyAttribute>(),
|
||||||
|
R"(Creates a new `func` op, optionally creating an entry block.
|
||||||
If an entry block is created, the builder will be positioned
|
If an entry block is created, the builder will be positioned
|
||||||
to its start.)")
|
to its start.)")
|
||||||
.def("select_op",
|
.def(
|
||||||
[](PyDialectHelper &self, PyValue conditionValue, PyValue trueValue,
|
"select_op",
|
||||||
PyValue falseValue) -> PyOperationRef {
|
[](PyDialectHelper &self, PyValue conditionValue, PyValue trueValue,
|
||||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
PyValue falseValue) -> PyOperationRef {
|
||||||
Location loc = self.pyOpBuilder.getCurrentLoc();
|
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||||
return PyOperationRef(opBuilder.create<SelectOp>(
|
Location loc = self.pyOpBuilder.getCurrentLoc();
|
||||||
loc, conditionValue, trueValue, falseValue));
|
return PyOperationRef(opBuilder.create<SelectOp>(
|
||||||
},
|
loc, conditionValue, trueValue, falseValue));
|
||||||
py::arg("condition"), py::arg("true_value"), py::arg("false_value"))
|
},
|
||||||
|
py::arg("condition"), py::arg("true_value"), py::arg("false_value"))
|
||||||
.def("return_op",
|
.def("return_op",
|
||||||
[](PyDialectHelper &self, std::vector<PyValue> pyOperands) {
|
[](PyDialectHelper &self, std::vector<PyValue> pyOperands) {
|
||||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||||
|
@ -288,11 +291,12 @@ void PyDialectHelper::bind(py::module m) {
|
||||||
[](PyDialectHelper &self) -> PyType {
|
[](PyDialectHelper &self) -> PyType {
|
||||||
return IndexType::get(self.getContext());
|
return IndexType::get(self.getContext());
|
||||||
})
|
})
|
||||||
.def("integer_type",
|
.def(
|
||||||
[](PyDialectHelper &self, unsigned width) -> PyType {
|
"integer_type",
|
||||||
return IntegerType::get(width, self.getContext());
|
[](PyDialectHelper &self, unsigned width) -> PyType {
|
||||||
},
|
return IntegerType::get(width, self.getContext());
|
||||||
py::arg("width") = 32)
|
},
|
||||||
|
py::arg("width") = 32)
|
||||||
.def_property_readonly("i1_type",
|
.def_property_readonly("i1_type",
|
||||||
[](PyDialectHelper &self) -> PyType {
|
[](PyDialectHelper &self) -> PyType {
|
||||||
return IntegerType::get(1, self.getContext());
|
return IntegerType::get(1, self.getContext());
|
||||||
|
@ -319,20 +323,21 @@ void PyDialectHelper::bind(py::module m) {
|
||||||
return FloatType::get(StandardTypes::F64,
|
return FloatType::get(StandardTypes::F64,
|
||||||
self.getContext());
|
self.getContext());
|
||||||
})
|
})
|
||||||
.def("tensor_type",
|
.def(
|
||||||
[](PyDialectHelper &self, PyType elementType,
|
"tensor_type",
|
||||||
llvm::Optional<std::vector<int64_t>> shape) -> PyType {
|
[](PyDialectHelper &self, PyType elementType,
|
||||||
if (!elementType.type) {
|
llvm::Optional<std::vector<int64_t>> shape) -> PyType {
|
||||||
throw py::raiseValueError("Null element type");
|
if (!elementType.type) {
|
||||||
}
|
throw py::raiseValueError("Null element type");
|
||||||
if (shape) {
|
}
|
||||||
return RankedTensorType::get(*shape, elementType.type);
|
if (shape) {
|
||||||
} else {
|
return RankedTensorType::get(*shape, elementType.type);
|
||||||
return UnrankedTensorType::get(elementType.type);
|
} else {
|
||||||
}
|
return UnrankedTensorType::get(elementType.type);
|
||||||
},
|
}
|
||||||
py::arg("element_type"),
|
},
|
||||||
py::arg("shape") = llvm::Optional<std::vector<int64_t>>())
|
py::arg("element_type"),
|
||||||
|
py::arg("shape") = llvm::Optional<std::vector<int64_t>>())
|
||||||
.def("function_type",
|
.def("function_type",
|
||||||
[](PyDialectHelper &self, std::vector<PyType> inputs,
|
[](PyDialectHelper &self, std::vector<PyType> inputs,
|
||||||
std::vector<PyType> results) -> PyType {
|
std::vector<PyType> results) -> PyType {
|
||||||
|
@ -367,21 +372,24 @@ void defineMlirIrModule(py::module m) {
|
||||||
m.doc() = "Python bindings for constructs in the mlir/IR library";
|
m.doc() = "Python bindings for constructs in the mlir/IR library";
|
||||||
|
|
||||||
// Globals.
|
// Globals.
|
||||||
m.def("emit_error",
|
m.def(
|
||||||
[](PyAttribute loc, std::string message) {
|
"emit_error",
|
||||||
emitDiagnostic(DiagnosticSeverity::Error, loc, message);
|
[](PyAttribute loc, std::string message) {
|
||||||
},
|
emitDiagnostic(DiagnosticSeverity::Error, loc, message);
|
||||||
py::arg("loc"), py::arg("message"));
|
},
|
||||||
m.def("emit_warning",
|
py::arg("loc"), py::arg("message"));
|
||||||
[](PyAttribute loc, std::string message) {
|
m.def(
|
||||||
emitDiagnostic(DiagnosticSeverity::Warning, loc, message);
|
"emit_warning",
|
||||||
},
|
[](PyAttribute loc, std::string message) {
|
||||||
py::arg("loc"), py::arg("message"));
|
emitDiagnostic(DiagnosticSeverity::Warning, loc, message);
|
||||||
m.def("emit_remark",
|
},
|
||||||
[](PyAttribute loc, std::string message) {
|
py::arg("loc"), py::arg("message"));
|
||||||
emitDiagnostic(DiagnosticSeverity::Remark, loc, message);
|
m.def(
|
||||||
},
|
"emit_remark",
|
||||||
py::arg("loc"), py::arg("message"));
|
[](PyAttribute loc, std::string message) {
|
||||||
|
emitDiagnostic(DiagnosticSeverity::Remark, loc, message);
|
||||||
|
},
|
||||||
|
py::arg("loc"), py::arg("message"));
|
||||||
|
|
||||||
// Python only types.
|
// Python only types.
|
||||||
PyDialectHelper::bind(m);
|
PyDialectHelper::bind(m);
|
||||||
|
@ -426,25 +434,27 @@ void PyContext::bind(py::module m) {
|
||||||
return PyModuleOp(self.shared_from_this(), m);
|
return PyModuleOp(self.shared_from_this(), m);
|
||||||
})
|
})
|
||||||
.def("parse_asm", &PyContext::parseAsm)
|
.def("parse_asm", &PyContext::parseAsm)
|
||||||
.def("new_builder",
|
.def(
|
||||||
[](PyContext &self) {
|
"new_builder",
|
||||||
// Note: we collapse the Builder and OpBuilder into one because
|
[](PyContext &self) {
|
||||||
// there is little reason to expose the inheritance hierarchy to
|
// Note: we collapse the Builder and OpBuilder into one because
|
||||||
// Python.
|
// there is little reason to expose the inheritance hierarchy to
|
||||||
return PyOpBuilder(self);
|
// Python.
|
||||||
},
|
return PyOpBuilder(self);
|
||||||
py::keep_alive<0, 1>())
|
},
|
||||||
|
py::keep_alive<0, 1>())
|
||||||
.def("identifier",
|
.def("identifier",
|
||||||
[](PyContext &self, std::string s) -> PyIdentifier {
|
[](PyContext &self, std::string s) -> PyIdentifier {
|
||||||
return Identifier::get(s, &self.context);
|
return Identifier::get(s, &self.context);
|
||||||
})
|
})
|
||||||
.def("file_line_col_loc_attr",
|
.def(
|
||||||
[](PyContext &self, PyIdentifier filename, unsigned line,
|
"file_line_col_loc_attr",
|
||||||
unsigned column) -> PyAttribute {
|
[](PyContext &self, PyIdentifier filename, unsigned line,
|
||||||
return static_cast<LocationAttr>(FileLineColLoc::get(
|
unsigned column) -> PyAttribute {
|
||||||
filename.identifier, line, column, &self.context));
|
return static_cast<LocationAttr>(FileLineColLoc::get(
|
||||||
},
|
filename.identifier, line, column, &self.context));
|
||||||
py::arg("filename"), py::arg("line"), py::arg("column"))
|
},
|
||||||
|
py::arg("filename"), py::arg("line"), py::arg("column"))
|
||||||
// Salient functions from Builder.
|
// Salient functions from Builder.
|
||||||
.def("parse_type",
|
.def("parse_type",
|
||||||
[](PyContext &self, const std::string &asmText) {
|
[](PyContext &self, const std::string &asmText) {
|
||||||
|
@ -456,14 +466,15 @@ void PyContext::bind(py::module m) {
|
||||||
}
|
}
|
||||||
return PyType(t);
|
return PyType(t);
|
||||||
})
|
})
|
||||||
.def("integer_attr",
|
.def(
|
||||||
[](PyContext &self, PyType type, int64_t value) -> PyAttribute {
|
"integer_attr",
|
||||||
if (!type.type.isa<IntegerType>()) {
|
[](PyContext &self, PyType type, int64_t value) -> PyAttribute {
|
||||||
throw py::raiseValueError("Expected IntegerType");
|
if (!type.type.isa<IntegerType>()) {
|
||||||
}
|
throw py::raiseValueError("Expected IntegerType");
|
||||||
return IntegerAttr::get(type.type, value);
|
}
|
||||||
},
|
return IntegerAttr::get(type.type, value);
|
||||||
py::arg("type"), py::arg("value"))
|
},
|
||||||
|
py::arg("type"), py::arg("value"))
|
||||||
.def("float_attr",
|
.def("float_attr",
|
||||||
[](PyContext &self, PyType type, double value) -> PyAttribute {
|
[](PyContext &self, PyType type, double value) -> PyAttribute {
|
||||||
if (!type.type.isa<FloatType>()) {
|
if (!type.type.isa<FloatType>()) {
|
||||||
|
@ -512,20 +523,20 @@ void PyContext::bind(py::module m) {
|
||||||
}
|
}
|
||||||
return ArrayAttr::get(attrs, &self.context);
|
return ArrayAttr::get(attrs, &self.context);
|
||||||
})
|
})
|
||||||
.def("dense_elements_attr",
|
.def(
|
||||||
[](PyContext &self, py::buffer array) -> PyAttribute {
|
"dense_elements_attr",
|
||||||
// Request a contiguous view.
|
[](PyContext &self, py::buffer array) -> PyAttribute {
|
||||||
int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
|
// Request a contiguous view.
|
||||||
Py_buffer *view = new Py_buffer();
|
int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
|
||||||
if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
|
Py_buffer *view = new Py_buffer();
|
||||||
delete view;
|
if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
|
||||||
throw py::error_already_set();
|
delete view;
|
||||||
}
|
throw py::error_already_set();
|
||||||
py::buffer_info array_info(view);
|
}
|
||||||
return createDenseElementsAttrFromBuffer(&self.context,
|
py::buffer_info array_info(view);
|
||||||
array_info);
|
return createDenseElementsAttrFromBuffer(&self.context, array_info);
|
||||||
},
|
},
|
||||||
py::arg("array"))
|
py::arg("array"))
|
||||||
.def_property_readonly("unit_attr", [](PyContext &self) -> PyAttribute {
|
.def_property_readonly("unit_attr", [](PyContext &self) -> PyAttribute {
|
||||||
return UnitAttr::get(&self.context);
|
return UnitAttr::get(&self.context);
|
||||||
});
|
});
|
||||||
|
@ -905,67 +916,74 @@ void PyBaseOpBuilder::bind(py::module m) {
|
||||||
void PyOpBuilder::bind(py::module m) {
|
void PyOpBuilder::bind(py::module m) {
|
||||||
py::class_<PyOpBuilder, PyBaseOpBuilder>(m, "OpBuilder")
|
py::class_<PyOpBuilder, PyBaseOpBuilder>(m, "OpBuilder")
|
||||||
.def(py::init<PyContext &>(), py::keep_alive<1, 2>())
|
.def(py::init<PyContext &>(), py::keep_alive<1, 2>())
|
||||||
.def_property("current_loc",
|
.def_property(
|
||||||
[](PyOpBuilder &self) -> PyAttribute {
|
"current_loc",
|
||||||
return static_cast<Attribute>(self.getCurrentLoc());
|
[](PyOpBuilder &self) -> PyAttribute {
|
||||||
},
|
return static_cast<Attribute>(self.getCurrentLoc());
|
||||||
[](PyOpBuilder &self, PyAttribute attr) {
|
},
|
||||||
auto loc_attr =
|
[](PyOpBuilder &self, PyAttribute attr) {
|
||||||
attr.attr.dyn_cast_or_null<LocationAttr>();
|
auto loc_attr = attr.attr.dyn_cast_or_null<LocationAttr>();
|
||||||
if (!loc_attr) {
|
if (!loc_attr) {
|
||||||
throw py::raiseValueError("Expected a LocationAttr");
|
throw py::raiseValueError("Expected a LocationAttr");
|
||||||
}
|
}
|
||||||
self.setCurrentLoc(Location(loc_attr));
|
self.setCurrentLoc(Location(loc_attr));
|
||||||
})
|
})
|
||||||
.def_property("insertion_point",
|
.def_property(
|
||||||
[](PyOpBuilder &self) {
|
"insertion_point",
|
||||||
return self.getBuilder(true).saveInsertionPoint();
|
[](PyOpBuilder &self) {
|
||||||
},
|
return self.getBuilder(true).saveInsertionPoint();
|
||||||
[](PyOpBuilder &self, OpBuilder::InsertPoint ip) {
|
},
|
||||||
self.getBuilder(false).restoreInsertionPoint(ip);
|
[](PyOpBuilder &self, OpBuilder::InsertPoint ip) {
|
||||||
})
|
self.getBuilder(false).restoreInsertionPoint(ip);
|
||||||
.def("set_file_line_col",
|
})
|
||||||
[](PyOpBuilder &self, PyIdentifier filename, unsigned line,
|
.def(
|
||||||
unsigned column) {
|
"set_file_line_col",
|
||||||
Location loc = FileLineColLoc::get(filename.identifier, line,
|
[](PyOpBuilder &self, PyIdentifier filename, unsigned line,
|
||||||
column, self.getContext());
|
unsigned column) {
|
||||||
self.setCurrentLoc(loc);
|
Location loc = FileLineColLoc::get(filename.identifier, line,
|
||||||
},
|
column, self.getContext());
|
||||||
py::arg("filename"), py::arg("line"), py::arg("column"),
|
self.setCurrentLoc(loc);
|
||||||
"Shortcut to set a FileLineCol current location")
|
},
|
||||||
|
py::arg("filename"), py::arg("line"), py::arg("column"),
|
||||||
|
"Shortcut to set a FileLineCol current location")
|
||||||
.def("clear_insertion_point",
|
.def("clear_insertion_point",
|
||||||
[](PyOpBuilder &self) { self.builder.clearInsertionPoint(); })
|
[](PyOpBuilder &self) { self.builder.clearInsertionPoint(); })
|
||||||
.def("insert_op_before",
|
.def(
|
||||||
[](PyOpBuilder &self, PyBaseOperation &pyOp) {
|
"insert_op_before",
|
||||||
Operation *op = pyOp.getOperation();
|
[](PyOpBuilder &self, PyBaseOperation &pyOp) {
|
||||||
self.builder.setInsertionPoint(op);
|
Operation *op = pyOp.getOperation();
|
||||||
},
|
self.builder.setInsertionPoint(op);
|
||||||
"Sets the insertion point to just before the specified op.")
|
},
|
||||||
.def("insert_op_after",
|
"Sets the insertion point to just before the specified op.")
|
||||||
[](PyOpBuilder &self, PyBaseOperation &pyOp) {
|
.def(
|
||||||
Operation *op = pyOp.getOperation();
|
"insert_op_after",
|
||||||
self.builder.setInsertionPointAfter(op);
|
[](PyOpBuilder &self, PyBaseOperation &pyOp) {
|
||||||
},
|
Operation *op = pyOp.getOperation();
|
||||||
"Sets the insertion point to just after the specified op.")
|
self.builder.setInsertionPointAfter(op);
|
||||||
.def("insert_block_start",
|
},
|
||||||
[](PyOpBuilder &self, PyBlockRef block) {
|
"Sets the insertion point to just after the specified op.")
|
||||||
self.builder.setInsertionPointToStart(&block.block);
|
.def(
|
||||||
},
|
"insert_block_start",
|
||||||
"Sets the insertion point to the start of the block.")
|
[](PyOpBuilder &self, PyBlockRef block) {
|
||||||
.def("insert_block_end",
|
self.builder.setInsertionPointToStart(&block.block);
|
||||||
[](PyOpBuilder &self, PyBlockRef block) {
|
},
|
||||||
self.builder.setInsertionPointToEnd(&block.block);
|
"Sets the insertion point to the start of the block.")
|
||||||
},
|
.def(
|
||||||
"Sets the insertion point to the end of the block.")
|
"insert_block_end",
|
||||||
.def("insert_before_terminator",
|
[](PyOpBuilder &self, PyBlockRef block) {
|
||||||
[](PyOpBuilder &self, PyBlockRef block) {
|
self.builder.setInsertionPointToEnd(&block.block);
|
||||||
auto *terminator = block.block.getTerminator();
|
},
|
||||||
if (!terminator) {
|
"Sets the insertion point to the end of the block.")
|
||||||
throw py::raiseValueError("Block has no terminator");
|
.def(
|
||||||
}
|
"insert_before_terminator",
|
||||||
self.builder.setInsertionPoint(terminator);
|
[](PyOpBuilder &self, PyBlockRef block) {
|
||||||
},
|
auto *terminator = block.block.getTerminator();
|
||||||
"Sets the insertion point to just before the block terminator.");
|
if (!terminator) {
|
||||||
|
throw py::raiseValueError("Block has no terminator");
|
||||||
|
}
|
||||||
|
self.builder.setInsertionPoint(terminator);
|
||||||
|
},
|
||||||
|
"Sets the insertion point to just before the block terminator.");
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -32,13 +32,14 @@ void PyPassManager::bind(py::module m) {
|
||||||
py::class_<PyPassManager>(m, "PassManager")
|
py::class_<PyPassManager>(m, "PassManager")
|
||||||
.def(py::init<std::shared_ptr<PyContext>, bool>(), py::arg("context"),
|
.def(py::init<std::shared_ptr<PyContext>, bool>(), py::arg("context"),
|
||||||
py::arg("verifyModules") = true)
|
py::arg("verifyModules") = true)
|
||||||
.def("enableCrashReproducerGeneration",
|
.def(
|
||||||
[](PyPassManager &self, std::string outputFile,
|
"enableCrashReproducerGeneration",
|
||||||
bool genLocalReproducer) {
|
[](PyPassManager &self, std::string outputFile,
|
||||||
self.passManager.enableCrashReproducerGeneration(
|
bool genLocalReproducer) {
|
||||||
outputFile, genLocalReproducer);
|
self.passManager.enableCrashReproducerGeneration(
|
||||||
},
|
outputFile, genLocalReproducer);
|
||||||
py::arg("outputFile"), py::arg("genLocalReproducer") = false)
|
},
|
||||||
|
py::arg("outputFile"), py::arg("genLocalReproducer") = false)
|
||||||
.def("__len__",
|
.def("__len__",
|
||||||
[](PyPassManager &self) { return self.passManager.size(); })
|
[](PyPassManager &self) { return self.passManager.size(); })
|
||||||
.def("__str__",
|
.def("__str__",
|
||||||
|
|
|
@ -48,18 +48,19 @@ public:
|
||||||
return Basicpy::NoneType::get(
|
return Basicpy::NoneType::get(
|
||||||
self.getContext());
|
self.getContext());
|
||||||
})
|
})
|
||||||
.def("basicpy_SlotObject_type",
|
.def(
|
||||||
[](BasicpyDialectHelper &self, std::string className,
|
"basicpy_SlotObject_type",
|
||||||
py::args pySlotTypes) -> PyType {
|
[](BasicpyDialectHelper &self, std::string className,
|
||||||
SmallVector<Type, 4> slotTypes;
|
py::args pySlotTypes) -> PyType {
|
||||||
for (auto pySlotType : pySlotTypes) {
|
SmallVector<Type, 4> slotTypes;
|
||||||
slotTypes.push_back(pySlotType.cast<PyType>());
|
for (auto pySlotType : pySlotTypes) {
|
||||||
}
|
slotTypes.push_back(pySlotType.cast<PyType>());
|
||||||
auto classNameAttr =
|
}
|
||||||
StringAttr::get(className, self.getContext());
|
auto classNameAttr =
|
||||||
return Basicpy::SlotObjectType::get(classNameAttr, slotTypes);
|
StringAttr::get(className, self.getContext());
|
||||||
},
|
return Basicpy::SlotObjectType::get(classNameAttr, slotTypes);
|
||||||
py::arg("className"))
|
},
|
||||||
|
py::arg("className"))
|
||||||
.def_property_readonly("basicpy_StrType",
|
.def_property_readonly("basicpy_StrType",
|
||||||
[](BasicpyDialectHelper &self) -> PyType {
|
[](BasicpyDialectHelper &self) -> PyType {
|
||||||
return Basicpy::StrType::get(
|
return Basicpy::StrType::get(
|
||||||
|
|
Loading…
Reference in New Issue