Remove op_builder.h/cpp

pull/309/head
Sean Silva 2021-09-17 03:56:26 +00:00
parent 900f0e04aa
commit 6f710bbc47
4 changed files with 8 additions and 74 deletions

View File

@ -17,7 +17,6 @@ add_library(TorchMLIRTorchPlugin SHARED
builder/function_importer.cpp
builder/module_builder.cpp
builder/node_importer.cpp
builder/op_builder.cpp
builder/ivalue_importer.cpp
builder/python_bindings.cpp
builder/torch_to_mlir_utils.cpp

View File

@ -6,11 +6,11 @@
//===----------------------------------------------------------------------===//
#include "node_importer.h"
#include "torch_to_mlir_utils.h"
#include <unordered_map>
#include "mlir_utils.h"
#include "op_builder.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
@ -130,12 +130,15 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
if (kind == c10::prim::Constant) {
auto output = node->output();
MlirOperation op;
OpBuilder builder(context);
if (output->type()->cast<c10::NoneType>()) {
op = builder.createNoneConstant(loc);
op = createMlirOperation("torch.constant.none", loc,
torchMlirTorchNoneTypeGet(context));
} else if (output->type()->cast<c10::BoolType>()) {
op = builder.createBoolConstant(
loc, static_cast<bool>(node->i(c10::attr::value)));
op = createMlirOperation(
"torch.constant.bool", loc, torchMlirTorchBoolTypeGet(context),
toMlirNamedAttribute(
"value", mlirBoolAttrGet(context, static_cast<bool>(node->i(
c10::attr::value)))));
} else if (output->type()->cast<c10::IntType>()) {
op = createMlirOperation(
"torch.constant.int", loc,

View File

@ -1,28 +0,0 @@
//===- op_builder.cpp -----------------------------------------------------===//
//
// This file is licensed under a pytorch-style license
// See LICENSE for license information.
//
//===----------------------------------------------------------------------===//
#include "op_builder.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
#include "torch-mlir-c/TorchTypes.h"
using namespace torch_mlir;
OpBuilder::OpBuilder(MlirContext context) : context(context) {}
MlirOperation OpBuilder::createNoneConstant(MlirLocation loc) {
return createMlirOperation("torch.constant.none", loc,
torchMlirTorchNoneTypeGet(context));
}
MlirOperation OpBuilder::createBoolConstant(MlirLocation loc, bool value) {
return createMlirOperation(
"torch.constant.bool", loc, torchMlirTorchBoolTypeGet(context),
toMlirNamedAttribute("value", mlirBoolAttrGet(context, value)));
}

View File

@ -1,40 +0,0 @@
//===- op_builder.h ---------------------------------------------*- C++ -*-===//
//
// This file is licensed under a pytorch-style license
// See LICENSE for license information.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIRPLUGIN_CSRC_BUILDER_OP_BUILDER_H
#define TORCHMLIRPLUGIN_CSRC_BUILDER_OP_BUILDER_H
#include "mlir_utils.h"
#include "torch_to_mlir_utils.h"
#include "mlir-c/IR.h"
#include <ATen/Tensor.h>
#include <ATen/core/function_schema.h>
namespace torch_mlir {
/// Convenience class for centralizing creation of some operations.
///
/// For many operations, they are created in only one place in the code, and so
/// the functionality in mlir_utils.h is enough.
///
/// TODO: Add insertion point like the normal MLIR builder?
class OpBuilder {
public:
OpBuilder(MlirContext context);
MlirOperation createNoneConstant(MlirLocation loc);
MlirOperation createBoolConstant(MlirLocation loc, bool value);
MlirOperation createStdConstant(MlirLocation loc, MlirAttribute value);
private:
MlirContext context;
};
} // namespace torch_mlir
#endif // TORCHMLIRPLUGIN_CSRC_BUILDER_OP_BUILDER_H