mirror of https://github.com/llvm/torch-mlir
Remove op_builder.h/cpp
parent
900f0e04aa
commit
6f710bbc47
|
@ -17,7 +17,6 @@ add_library(TorchMLIRTorchPlugin SHARED
|
||||||
builder/function_importer.cpp
|
builder/function_importer.cpp
|
||||||
builder/module_builder.cpp
|
builder/module_builder.cpp
|
||||||
builder/node_importer.cpp
|
builder/node_importer.cpp
|
||||||
builder/op_builder.cpp
|
|
||||||
builder/ivalue_importer.cpp
|
builder/ivalue_importer.cpp
|
||||||
builder/python_bindings.cpp
|
builder/python_bindings.cpp
|
||||||
builder/torch_to_mlir_utils.cpp
|
builder/torch_to_mlir_utils.cpp
|
||||||
|
|
|
@ -6,11 +6,11 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "node_importer.h"
|
#include "node_importer.h"
|
||||||
|
#include "torch_to_mlir_utils.h"
|
||||||
|
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
#include "mlir_utils.h"
|
#include "mlir_utils.h"
|
||||||
#include "op_builder.h"
|
|
||||||
|
|
||||||
#include "mlir-c/BuiltinAttributes.h"
|
#include "mlir-c/BuiltinAttributes.h"
|
||||||
#include "mlir-c/BuiltinTypes.h"
|
#include "mlir-c/BuiltinTypes.h"
|
||||||
|
@ -130,12 +130,15 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
||||||
if (kind == c10::prim::Constant) {
|
if (kind == c10::prim::Constant) {
|
||||||
auto output = node->output();
|
auto output = node->output();
|
||||||
MlirOperation op;
|
MlirOperation op;
|
||||||
OpBuilder builder(context);
|
|
||||||
if (output->type()->cast<c10::NoneType>()) {
|
if (output->type()->cast<c10::NoneType>()) {
|
||||||
op = builder.createNoneConstant(loc);
|
op = createMlirOperation("torch.constant.none", loc,
|
||||||
|
torchMlirTorchNoneTypeGet(context));
|
||||||
} else if (output->type()->cast<c10::BoolType>()) {
|
} else if (output->type()->cast<c10::BoolType>()) {
|
||||||
op = builder.createBoolConstant(
|
op = createMlirOperation(
|
||||||
loc, static_cast<bool>(node->i(c10::attr::value)));
|
"torch.constant.bool", loc, torchMlirTorchBoolTypeGet(context),
|
||||||
|
toMlirNamedAttribute(
|
||||||
|
"value", mlirBoolAttrGet(context, static_cast<bool>(node->i(
|
||||||
|
c10::attr::value)))));
|
||||||
} else if (output->type()->cast<c10::IntType>()) {
|
} else if (output->type()->cast<c10::IntType>()) {
|
||||||
op = createMlirOperation(
|
op = createMlirOperation(
|
||||||
"torch.constant.int", loc,
|
"torch.constant.int", loc,
|
||||||
|
|
|
@ -1,28 +0,0 @@
|
||||||
//===- op_builder.cpp -----------------------------------------------------===//
|
|
||||||
//
|
|
||||||
// This file is licensed under a pytorch-style license
|
|
||||||
// See LICENSE for license information.
|
|
||||||
//
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
#include "op_builder.h"
|
|
||||||
|
|
||||||
#include "mlir-c/BuiltinAttributes.h"
|
|
||||||
#include "mlir-c/BuiltinTypes.h"
|
|
||||||
#include "mlir-c/Diagnostics.h"
|
|
||||||
#include "torch-mlir-c/TorchTypes.h"
|
|
||||||
|
|
||||||
using namespace torch_mlir;
|
|
||||||
|
|
||||||
OpBuilder::OpBuilder(MlirContext context) : context(context) {}
|
|
||||||
|
|
||||||
MlirOperation OpBuilder::createNoneConstant(MlirLocation loc) {
|
|
||||||
return createMlirOperation("torch.constant.none", loc,
|
|
||||||
torchMlirTorchNoneTypeGet(context));
|
|
||||||
}
|
|
||||||
|
|
||||||
MlirOperation OpBuilder::createBoolConstant(MlirLocation loc, bool value) {
|
|
||||||
return createMlirOperation(
|
|
||||||
"torch.constant.bool", loc, torchMlirTorchBoolTypeGet(context),
|
|
||||||
toMlirNamedAttribute("value", mlirBoolAttrGet(context, value)));
|
|
||||||
}
|
|
|
@ -1,40 +0,0 @@
|
||||||
//===- op_builder.h ---------------------------------------------*- C++ -*-===//
|
|
||||||
//
|
|
||||||
// This file is licensed under a pytorch-style license
|
|
||||||
// See LICENSE for license information.
|
|
||||||
//
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
#ifndef TORCHMLIRPLUGIN_CSRC_BUILDER_OP_BUILDER_H
|
|
||||||
#define TORCHMLIRPLUGIN_CSRC_BUILDER_OP_BUILDER_H
|
|
||||||
|
|
||||||
#include "mlir_utils.h"
|
|
||||||
#include "torch_to_mlir_utils.h"
|
|
||||||
|
|
||||||
#include "mlir-c/IR.h"
|
|
||||||
|
|
||||||
#include <ATen/Tensor.h>
|
|
||||||
#include <ATen/core/function_schema.h>
|
|
||||||
|
|
||||||
namespace torch_mlir {
|
|
||||||
|
|
||||||
/// Convenience class for centralizing creation of some operations.
|
|
||||||
///
|
|
||||||
/// For many operations, they are created in only one place in the code, and so
|
|
||||||
/// the functionality in mlir_utils.h is enough.
|
|
||||||
///
|
|
||||||
/// TODO: Add insertion point like the normal MLIR builder?
|
|
||||||
class OpBuilder {
|
|
||||||
public:
|
|
||||||
OpBuilder(MlirContext context);
|
|
||||||
MlirOperation createNoneConstant(MlirLocation loc);
|
|
||||||
MlirOperation createBoolConstant(MlirLocation loc, bool value);
|
|
||||||
MlirOperation createStdConstant(MlirLocation loc, MlirAttribute value);
|
|
||||||
|
|
||||||
private:
|
|
||||||
MlirContext context;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace torch_mlir
|
|
||||||
|
|
||||||
#endif // TORCHMLIRPLUGIN_CSRC_BUILDER_OP_BUILDER_H
|
|
Loading…
Reference in New Issue