torch-mlir/frontends/pytorch/csrc/jit.cpp

334 lines
9.4 KiB
C++
Raw Normal View History

Add pytorch interface to ATen Dialect (#30). This patch adds a PyTorch interface to npcomp. The interface is modeled after pytorch_xla and exposes the MLIR-based flow as a virtual device (similar to a GPU device or the XLA backend). Intended usage is something like: `dev = torch_mlir.mlir_device(); t0 = torch.randn((4,4), device=dev); t1 = torch.randn((4,4), device=dev); t2 = t0 + t1; t2_mlir = torch_mlir.get_mlir(t2); t2_cpu = t2.to('cpu')`. In this case t2_cpu contains the result of the computation, and t2_mlir contains the MLIR description of the computation. Note that this also properly returns backward paths synthesized by PyTorch. There are several parts to this: 1) a tensor type (implemented by tensor.* and tensor_impl.*); 2) the device modeling (aten_mlir_bridge.*, aten_mlir_device.*, aten_mlir_type*); 3) a temporary IR (implemented by ir.cpp). There is also a reference lowering directly from the ATen dialect to C function calls, consisting of two parts: 1) the driver, which uses the IR to generate MLIR, runs passes, and compiles the result using mlir::ExecutionEngine (implemented by jit.cpp and mlir_gen.cpp); 2) a runtime library implemented by lib/aten_ops.cpp. Most of the operations are implemented by callbacks into the torch C++ libraries. Some aspects of this are known to be less than optimal, in particular: 1) there are some function definitions that don't live in the file corresponding to their declaration; 2) more aspects of this (e.g. the IR) seem like they should be automatically generated; 3) it's unclear to me how much of the 'IR' is actually necessary, or whether MLIR could be created on the fly. Note that this code is licensed in a way similar to PyTorch, with the intention that eventually (when npcomp reaches some maturity) it should be pushed there (see frontends/pytorch/LICENSE). The code is also structured much closer to the PyTorch coding style than to the LLVM coding style.
2020-08-22 02:22:47 +08:00
//===- jit.cpp --------------------------------------------------*- C++ -*-===//
//
// This file is licensed under a pytorch-style license
// See frontends/pytorch/LICENSE for license information.
//
//===----------------------------------------------------------------------===//
// This file drives the generation and lowering of MLIR, followed by JIT
// compiling the resulting LLVM dialect.
#include "npcomp/Dialect/ATen/ATenDialect.h"
#include "npcomp/Dialect/ATen/ATenPasses.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/JitRunner.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Target/LLVMIR.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include <dlfcn.h>
#include <cstdlib>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
#include "ATen/ArrayRef.h"
namespace at {
template <typename T> using ArrayRef = c10::ArrayRef<T>;
}
#include "ATen/Tensor.h"
#include <ATen/CPUType.h>
#include "jit.h"
#include "mlir_gen.h"
#include "tensor.h"
#include "torch_util.h"
#define DEBUG_TYPE "torch_mlir"
using namespace mlir;
namespace torch_mlir {
namespace {
// Runs the ATen-level part of the lowering pipeline on `module`, in two
// pass-manager stages:
//   1. CSE, lower ATen ops to function calls, then run return elimination;
//   2. lower affine and SCF constructs to the standard CFG form, then CSE.
// Returns 0 on success and 1 if either stage fails; on failure a one-line
// diagnostic is written to stderr.
int LowerATenDialect(mlir::ModuleOp module) {
  PassManager pm0(module.getContext());
  pm0.addPass(mlir::createCSEPass());
  // Lower to function calls.
  pm0.addPass(mlir::NPCOMP::aten::createATenLoweringPass());
  pm0.addPass(mlir::NPCOMP::aten::createReturnEliminationPass());
  if (failed(pm0.run(module))) {
    // Terminate the line so later output does not run into this message.
    llvm::errs() << "aten to loops conversion failed\n";
    return 1;
  }

  PassManager pm1(module.getContext());
  pm1.addPass(mlir::createLowerAffinePass());
  pm1.addPass(mlir::createLowerToCFGPass());
  pm1.addPass(mlir::createCSEPass());
  if (failed(pm1.run(module))) {
    llvm::errs() << "loops to std conversion failed\n";
    return 1;
  }
  return 0;
}
// Lowers `module` from the standard dialect to the LLVM dialect, then runs
// CSE. C-interface wrappers are requested (`emitCWrappers`), which is what
// provides the `_mlir_ciface_*` entry point looked up by the JIT driver.
// Returns 0 on success and 1 on failure (a diagnostic is written to stderr).
// When LLVM debug output is enabled for DEBUG_TYPE, the module is dumped both
// before and after lowering.
int LowerStdDialect(mlir::ModuleOp module) {
  struct LowerToLLVMOptions options;
  options.emitCWrappers = true;

  PassManager pm(module.getContext());
  pm.addPass(mlir::createLowerToLLVMPass(options));
  pm.addPass(mlir::createCSEPass());

  // Input to the lowering.
  LLVM_DEBUG(module.print(llvm::outs()));
  if (failed(pm.run(module))) {
    llvm::errs() << "std to llvm conversion failed\n";
    return 1;
  }
  // Result of the lowering. Previously both dumps ran before pm.run and
  // printed the same, unlowered module.
  LLVM_DEBUG(module.print(llvm::outs()));
  return 0;
}
// Host-side, rank-N descriptor for a tensor with element type T. setupArg
// fills one in, and a pointer to it is passed to the JIT-compiled
// `_mlir_ciface_graph` entry point.
// NOTE(review): field order and widths must match the memref descriptor
// layout produced by the std-to-LLVM lowering. Using size_t assumes the
// lowered index type is pointer-width; confirm before changing either side.
template <typename T, int N> struct llvm_tensor_t {
// Base data pointer; setupArg sets both this and `aligned` to data_ptr().
T *d;
// Aligned data pointer (identical to `d` as set up by setupArg).
T *aligned;
// Offset from `aligned`; setupArg always sets it to 0.
size_t offset;
// Per-dimension sizes, copied from at::Tensor::sizes().
size_t shape[N];
// Per-dimension strides, copied from at::Tensor::stride().
size_t stride[N];
};
// Wraps tensor `t` in a freshly heap-allocated rank-N llvm_tensor_t that
// aliases the tensor's storage (no element data is copied). It then
// heap-allocates a slot holding a pointer to that descriptor and returns the
// slot as void*. The caller owns both allocations and must delete the
// descriptor and then the slot, using the same <T, N> instantiation.
template <typename T, int N> void *setupArg(at::Tensor &t) {
  auto *descriptor = new llvm_tensor_t<T, N>;
  auto **slot = new llvm_tensor_t<T, N> *;
  *slot = descriptor;

  T *const data = static_cast<T *>(t.data_ptr());
  descriptor->d = data;
  descriptor->aligned = data;
  descriptor->offset = 0;

  // Callers choose N from the tensor's rank; double-check in debug builds.
  assert(t.dim() == N);
  for (int axis = 0; axis != N; ++axis) {
    descriptor->shape[axis] = t.sizes()[axis];
    descriptor->stride[axis] = t.stride(axis);
  }
  return static_cast<void *>(slot);
}
at::Tensor LowerAndRun(mlir::ModuleOp module,
std::vector<at::Tensor> &arguments, const ir::Value &v,
mlir::MLIRContext &context) {
LowerATenDialect(module);
LowerStdDialect(module);
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel =
llvm::CodeGenOpt::Level::Aggressive;
std::string libpath;
if (const char *path = std::getenv("TEST_BUILD_PATH")) {
libpath = path;
}
std::vector<std::string> sharedLibs{libpath +
"/frontends/pytorch/lib/libaten_ops.so"};
llvm::errs() << "Loading " << sharedLibs[0] << "\n";
llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr);
llvm::SmallVector<llvm::StringRef, 1> libs(sharedLibs.begin(),
sharedLibs.end());
auto expectedEngine = mlir::ExecutionEngine::create(
module, {}, jitCodeGenOptLevel, libs, false, false, false);
assert(expectedEngine && "no engine, cannot fly");
llvm::StringRef entryPoint("_mlir_ciface_graph");
auto engine = std::move(*expectedEngine);
auto expectedFPtr = engine->lookup(entryPoint);
assert(expectedFPtr && "entryPoint missing");
void (*fptr)(void **) = *expectedFPtr;
// this array holds pointers to the function arguments
void **args = (void **)malloc((arguments.size() + 1) * sizeof(void *));
// allocate and setup the function arguments
for (int i = 0, e = arguments.size(); i < e; i++) {
at::Tensor &t = arguments[i];
auto dtype = t.dtype();
int dim = t.dim();
if (dim == 4) {
if (dtype == at::kFloat)
args[i] = setupArg<float, 4>(t);
else if (dtype == at::kLong)
args[i] = setupArg<uint64_t, 4>(t);
else
assert(0);
} else if (dim == 3) {
if (dtype == at::kFloat)
args[i] = setupArg<float, 3>(t);
else if (dtype == at::kLong)
args[i] = setupArg<uint64_t, 3>(t);
else
assert(0);
} else if (dim == 2) {
if (dtype == at::kFloat)
args[i] = setupArg<float, 2>(t);
else if (dtype == at::kLong)
args[i] = setupArg<uint64_t, 2>(t);
else
assert(0);
} else if (dim == 1) {
if (dtype == at::kFloat)
args[i] = setupArg<float, 1>(t);
else if (dtype == at::kLong)
args[i] = setupArg<uint64_t, 1>(t);
else
assert(0);
} else {
assert(0 && "unhandled dim");
}
}
// allocate the result tensors
// TODO: num results > 1
at::Tensor result = util::Zeros(v.sizes(), at::kFloat);
if (result.dim() == 4) {
args[arguments.size()] = setupArg<float, 4>(result);
} else if (result.dim() == 3) {
args[arguments.size()] = setupArg<float, 3>(result);
} else if (result.dim() == 2) {
args[arguments.size()] = setupArg<float, 2>(result);
} else if (result.dim() == 1) {
args[arguments.size()] = setupArg<float, 1>(result);
} else {
assert(0 && "unhandled dim");
}
// call the JITed function
fptr(args);
// free pointers to the results
// TODO: num results > 1
if (result.dim() == 4) {
auto arg_storage =
static_cast<llvm_tensor_t<float, 4> **>(args[arguments.size()]);
auto arg = *arg_storage;
delete arg;
delete arg_storage;
} else if (result.dim() == 3) {
auto arg_storage =
static_cast<llvm_tensor_t<float, 3> **>(args[arguments.size()]);
auto arg = *arg_storage;
delete arg;
delete arg_storage;
} else if (result.dim() == 2) {
auto arg_storage =
static_cast<llvm_tensor_t<float, 2> **>(args[arguments.size()]);
auto arg = *arg_storage;
delete arg;
delete arg_storage;
} else if (result.dim() == 1) {
auto arg_storage =
static_cast<llvm_tensor_t<float, 1> **>(args[arguments.size()]);
auto arg = *arg_storage;
delete arg;
delete arg_storage;
} else {
assert(0 && "unhandled dim");
}
// free pointers to the arguments
for (int i = 0, e = arguments.size(); i < e; i++) {
at::Tensor &t = arguments[i];
int dim = t.dim();
if (dim == 4) {
auto arg_storage = static_cast<llvm_tensor_t<float, 4> **>(args[i]);
auto arg = *arg_storage;
delete arg;
delete arg_storage;
} else if (dim == 3) {
auto arg_storage = static_cast<llvm_tensor_t<float, 3> **>(args[i]);
auto arg = *arg_storage;
delete arg;
delete arg_storage;
} else if (dim == 2) {
auto arg_storage = static_cast<llvm_tensor_t<float, 2> **>(args[i]);
auto arg = *arg_storage;
delete arg;
delete arg_storage;
} else if (dim == 1) {
auto arg_storage = static_cast<llvm_tensor_t<float, 1> **>(args[i]);
auto arg = *arg_storage;
delete arg;
delete arg_storage;
} else {
assert(0 && "unhandled dim");
}
}
// free the array of void* ptrs
free(args);
return result;
}
// Generates an MLIR module computing `v` within `context`, then lowers,
// JIT-compiles and runs it through LowerAndRun, returning the result tensor.
at::Tensor JitAndRun(const ir::Value &v, mlir::MLIRContext &context) {
  // genModule returns the module together with the concrete tensors that
  // LowerAndRun passes as the generated function's inputs.
  std::vector<ir::Value> roots(1, v);
  auto generated = MLIRGen(context).genModule(roots);
  mlir::OwningModuleRef owned_module = std::move(std::get<0>(generated));
  std::vector<at::Tensor> inputs = std::move(std::get<1>(generated));

  // owned_module keeps the module alive until LowerAndRun returns.
  return LowerAndRun(owned_module.get(), inputs, v, context);
}
// Convenience overload: compiles and runs `v` using a private MLIRContext
// that lives only for the duration of this call.
at::Tensor JitAndRun(const ir::Value &v) {
  mlir::MLIRContext scratch_context;
  return JitAndRun(v, scratch_context);
}
// Placeholder for an interpreter-based execution path; not implemented.
// Always fails: it asserts in debug builds and throws std::runtime_error
// otherwise. Without the throw, control would fall off the end of a
// value-returning function once asserts are compiled out, which is UB.
at::Tensor Interpret(const ir::Value &v) {
  assert(0 && "unsupported");
  throw std::runtime_error("Interpret is unsupported");
}
} // anonymous namespace
// FIXME: Why is this code here and not in tensor.cpp?
// Returns the textual MLIR for this tensor's pending IR value. If the tensor
// has no IR value, returns the literal "<tensor>".
std::string MLIRTensor::GetMLIR() const {
  // Check for an IR value first, so no MLIRContext is constructed when there
  // is nothing to print.
  ir::Value ir_value = CurrentIrValue();
  if (!ir_value)
    return "<tensor>";

  // Declared before the module so that it outlives it.
  mlir::MLIRContext context;
  std::vector<ir::Value> vs{ir_value};
  auto mlir_gen = MLIRGen(context).genModule(vs);
  mlir::OwningModuleRef module = std::move(std::get<0>(mlir_gen));

  std::string aten;
  llvm::raw_string_ostream ss(aten);
  module->print(ss);
  return ss.str();
}
// JIT-compiles and executes the IR recorded for this tensor, returning the
// computed result tensor.
at::Tensor MLIRTensor::CompileAndRun() const {
  const ir::Value &pending = CurrentIrValue();
  return JitAndRun(pending);
}
} // namespace torch_mlir