mirror of https://github.com/llvm/torch-mlir
Remove op_builder.h/cpp
parent
900f0e04aa
commit
6f710bbc47
|
@ -17,7 +17,6 @@ add_library(TorchMLIRTorchPlugin SHARED
|
|||
builder/function_importer.cpp
|
||||
builder/module_builder.cpp
|
||||
builder/node_importer.cpp
|
||||
builder/op_builder.cpp
|
||||
builder/ivalue_importer.cpp
|
||||
builder/python_bindings.cpp
|
||||
builder/torch_to_mlir_utils.cpp
|
||||
|
|
|
@ -6,11 +6,11 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "node_importer.h"
|
||||
#include "torch_to_mlir_utils.h"
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "mlir_utils.h"
|
||||
#include "op_builder.h"
|
||||
|
||||
#include "mlir-c/BuiltinAttributes.h"
|
||||
#include "mlir-c/BuiltinTypes.h"
|
||||
|
@ -130,12 +130,15 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
|||
if (kind == c10::prim::Constant) {
|
||||
auto output = node->output();
|
||||
MlirOperation op;
|
||||
OpBuilder builder(context);
|
||||
if (output->type()->cast<c10::NoneType>()) {
|
||||
op = builder.createNoneConstant(loc);
|
||||
op = createMlirOperation("torch.constant.none", loc,
|
||||
torchMlirTorchNoneTypeGet(context));
|
||||
} else if (output->type()->cast<c10::BoolType>()) {
|
||||
op = builder.createBoolConstant(
|
||||
loc, static_cast<bool>(node->i(c10::attr::value)));
|
||||
op = createMlirOperation(
|
||||
"torch.constant.bool", loc, torchMlirTorchBoolTypeGet(context),
|
||||
toMlirNamedAttribute(
|
||||
"value", mlirBoolAttrGet(context, static_cast<bool>(node->i(
|
||||
c10::attr::value)))));
|
||||
} else if (output->type()->cast<c10::IntType>()) {
|
||||
op = createMlirOperation(
|
||||
"torch.constant.int", loc,
|
||||
|
|
|
@ -1,28 +0,0 @@
|
|||
//===- op_builder.cpp -----------------------------------------------------===//
|
||||
//
|
||||
// This file is licensed under a pytorch-style license
|
||||
// See LICENSE for license information.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "op_builder.h"
|
||||
|
||||
#include "mlir-c/BuiltinAttributes.h"
|
||||
#include "mlir-c/BuiltinTypes.h"
|
||||
#include "mlir-c/Diagnostics.h"
|
||||
#include "torch-mlir-c/TorchTypes.h"
|
||||
|
||||
using namespace torch_mlir;
|
||||
|
||||
OpBuilder::OpBuilder(MlirContext context) : context(context) {}
|
||||
|
||||
MlirOperation OpBuilder::createNoneConstant(MlirLocation loc) {
|
||||
return createMlirOperation("torch.constant.none", loc,
|
||||
torchMlirTorchNoneTypeGet(context));
|
||||
}
|
||||
|
||||
MlirOperation OpBuilder::createBoolConstant(MlirLocation loc, bool value) {
|
||||
return createMlirOperation(
|
||||
"torch.constant.bool", loc, torchMlirTorchBoolTypeGet(context),
|
||||
toMlirNamedAttribute("value", mlirBoolAttrGet(context, value)));
|
||||
}
|
|
@ -1,40 +0,0 @@
|
|||
//===- op_builder.h ---------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// This file is licensed under a pytorch-style license
|
||||
// See LICENSE for license information.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef TORCHMLIRPLUGIN_CSRC_BUILDER_OP_BUILDER_H
|
||||
#define TORCHMLIRPLUGIN_CSRC_BUILDER_OP_BUILDER_H
|
||||
|
||||
#include "mlir_utils.h"
|
||||
#include "torch_to_mlir_utils.h"
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
|
||||
#include <ATen/Tensor.h>
|
||||
#include <ATen/core/function_schema.h>
|
||||
|
||||
namespace torch_mlir {
|
||||
|
||||
/// Convenience class for centralizing creation of some operations.
|
||||
///
|
||||
/// For many operations, they are created in only one place in the code, and so
|
||||
/// the functionality in mlir_utils.h is enough.
|
||||
///
|
||||
/// TODO: Add insertion point like the normal MLIR builder?
|
||||
class OpBuilder {
|
||||
public:
|
||||
OpBuilder(MlirContext context);
|
||||
MlirOperation createNoneConstant(MlirLocation loc);
|
||||
MlirOperation createBoolConstant(MlirLocation loc, bool value);
|
||||
MlirOperation createStdConstant(MlirLocation loc, MlirAttribute value);
|
||||
|
||||
private:
|
||||
MlirContext context;
|
||||
};
|
||||
|
||||
} // namespace torch_mlir
|
||||
|
||||
#endif // TORCHMLIRPLUGIN_CSRC_BUILDER_OP_BUILDER_H
|
Loading…
Reference in New Issue