torch-mlir/frontends/pytorch/csrc/builder/module_builder.h

74 lines
2.3 KiB
C
Raw Normal View History

//===- module_builder.h -----------------------------------------*- C++ -*-===//
//
// This file is licensed under a pytorch-style license
// See frontends/pytorch/LICENSE for license information.
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_FRONTENDS_PYTORCH_CSRC_BUILDER_H
#define NPCOMP_FRONTENDS_PYTORCH_CSRC_BUILDER_H
#include "../pybind.h"
#include "acap_dispatch.h"
#include "class_annotator.h"
#include "mlir-c/IR.h"
#include <ATen/Tensor.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>
namespace torch_mlir {

/// Main entry-point for constructing an MLIR module from some combination
/// of PyTorch programs/execution.
///
/// The MLIR context and module are owned by Python objects. This class keeps
/// references to those Python objects (`contextObj`, `moduleObj`) alongside
/// the C-API handles extracted from them.
class ModuleBuilder {
public:
  /// Constructs a builder from the Python object owning the MLIR context.
  /// `explicit` so an arbitrary Python object is never implicitly converted
  /// into a ModuleBuilder.
  explicit ModuleBuilder(pybind11::object contextObj);

  /// Creates Python bindings for the class.
  static void bind(pybind11::module &m);

  /// Returns the Python-owned context object this builder was created with.
  pybind11::object getContextObj() const { return contextObj; }
  /// Returns the Python-owned module object being populated by this builder.
  pybind11::object getModuleObj() const { return moduleObj; }

  /// Starts a device-capture based function.
  /// NOTE(review): `name` is taken by non-const reference. If the definition
  /// only reads it, it could become `const std::string &`. The out-of-line
  /// definition must be changed in the same commit.
  std::shared_ptr<AcapController>
  startCaptureFunction(std::string &name, std::vector<at::Tensor> args);

  /// Imports a traced function. Note that the Python type
  /// torch.jit.ScriptFunction is the C++ type torch::jit::StrongFunctionPtr
  /// (just a bit of naming cruft).
  /// Returns the same function, making it suitable as a nested decorator.
  torch::jit::StrongFunctionPtr
  importFunction(torch::jit::StrongFunctionPtr function);

  /// Imports a torch::jit::Module into the current module.
  /// `maybeClassAnnotator` is either None or a ClassAnnotator. When it is not
  /// None, its annotations are used during the import.
  void importModule(torch::jit::Module jitModule,
                    pybind11::object maybeClassAnnotator);

private:
  FuncBuilder::Inserter createInserter();
  MlirBlock getBodyBlock();

  // Capture references to the Python-owned context and module. Ownership
  // is delegated to Python for these, and the C-API types are extracted via
  // the capsule API.
  //
  // Do not reorder these members: constructor initialization follows
  // declaration order. The C-API handles are zero-initialized so they hold
  // null rather than indeterminate values until the constructor sets them.
  pybind11::object contextObj;
  MlirContext context{};
  MlirModule module{};
  pybind11::object moduleObj;
  MlirOperation terminator{};
  MlirLocation unknownLoc{};
  TypeMapper typeMapper;
};

} // namespace torch_mlir
#endif // NPCOMP_FRONTENDS_PYTORCH_CSRC_BUILDER_H