2021-02-02 09:59:42 +08:00
|
|
|
//===- op_builder.cpp -----------------------------------------------------===//
|
|
|
|
//
|
|
|
|
// This file is licensed under a pytorch-style license
|
|
|
|
// See frontends/pytorch/LICENSE for license information.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include "op_builder.h"
|
|
|
|
|
|
|
|
#include "mlir-c/BuiltinAttributes.h"
|
|
|
|
#include "mlir-c/BuiltinTypes.h"
|
|
|
|
#include "mlir-c/Diagnostics.h"
|
2021-06-15 05:13:59 +08:00
|
|
|
#include "npcomp-c/BasicpyTypes.h"
|
|
|
|
#include "npcomp-c/TorchTypes.h"
|
2021-02-02 09:59:42 +08:00
|
|
|
|
|
|
|
using namespace torch_mlir;
|
|
|
|
|
|
|
|
/// Constructs a builder that stores the given MLIR context handle in its
/// `context` member for use by the `create*` methods.
OpBuilder::OpBuilder(MlirContext context)
    : context(context) {
}
|
|
|
|
|
|
|
|
/// Creates a `torch.constant.none` operation at `loc`.
///
/// The result type passed to `createMlirOperation` is the Torch none type
/// obtained from this builder's context.
///
/// \param loc Location to attach to the created operation.
/// \returns The operation handle produced by `createMlirOperation`.
MlirOperation OpBuilder::createNoneConstant(MlirLocation loc) {
  auto noneType = npcompTorchNoneTypeGet(context);
  return createMlirOperation("torch.constant.none", loc, noneType);
}
|
|
|
|
|
|
|
|
/// Creates a `basicpy.bool_constant` operation at `loc`.
///
/// The operation is given the Basicpy bool type from this builder's context
/// and a named attribute `"value"` holding `value` as an MLIR bool attribute.
///
/// \param loc   Location to attach to the created operation.
/// \param value Boolean payload stored in the `"value"` attribute.
/// \returns The operation handle produced by `createMlirOperation`.
MlirOperation OpBuilder::createBoolConstant(MlirLocation loc, bool value) {
  auto boolType = npcompBasicpyBoolTypeGet(context);
  auto valueAttr =
      toMlirNamedAttribute("value", mlirBoolAttrGet(context, value));
  return createMlirOperation("basicpy.bool_constant", loc, boolType,
                             valueAttr);
}
|