Add bazel configs for some C APIs and pybinds

pull/3879/head
tkocmathla 2024-11-15 14:31:40 -07:00
parent 0a607a410d
commit f195e25b17
1 changed files with 95 additions and 1 deletions

View File

@ -2,7 +2,8 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
load("@llvm-project//mlir:build_defs.bzl", "mlir_c_api_cc_library")
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "gentbl_filegroup", "td_library")
load("@com_github_bazelbuild_buildtools//buildifier:def.bzl", "buildifier")
package(
@ -923,3 +924,96 @@ cc_binary(
"@llvm-project//mlir:MlirOptLib",
],
)
# C API bindings
mlir_c_api_cc_library(
name = "CAPITorchRegisterEverything",
srcs = ["lib/CAPI/Registration.cpp"],
hdrs = ["include/torch-mlir-c/Registration.h"],
capi_deps = ["@llvm-project//mlir:CAPIIR"],
deps = [
":TorchMLIRInitAll",
"@llvm-project//mlir:AllPassesAndDialects",
],
)
mlir_c_api_cc_library(
name = "CAPITorch",
srcs = ["lib/CAPI/Dialects.cpp"],
hdrs = ["include/torch-mlir-c/Dialects.h"],
capi_deps = ["@llvm-project//mlir:CAPIIR"],
deps = [":TorchMLIRTorchDialect"],
)
# These flags are needed for pybind11 to work.
PYBIND11_COPTS = [
"-fexceptions",
"-frtti",
]
PYBIND11_FEATURES = [
# Cannot use header_modules (parse_headers feature fails).
"-use_header_modules",
]
# pybind11 extension module
cc_binary(
name = "_torchMlir.so",
srcs = [
"include/torch-mlir-c/Registration.h",
"python/TorchMLIRModule.cpp",
],
copts = PYBIND11_COPTS,
features = PYBIND11_FEATURES,
linkshared = 1,
linkstatic = 0,
deps = [
":CAPITorch",
":CAPITorchRegisterEverything",
"@llvm-project//mlir:CAPIIR",
"@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps",
],
)
# python files
td_library(
name = "TorchOpsPyTdFiles",
srcs = [":MLIRTorchOpsIncGenTdFiles"],
includes = ["include"],
deps = [
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
],
)
gentbl_filegroup(
name = "TorchOpsPyGen",
includes = ["include"],
tbl_outs = [
(
[
"-gen-python-op-bindings",
"-bind-dialect=torch",
],
"python/torch_mlir/dialects/_torch_ops_gen.py",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "python/torch_mlir/dialects/TorchBinding.td",
deps = [
":TorchOpsPyTdFiles",
"@llvm-project//mlir:AttrTdFiles",
],
)
filegroup(
name = "TorchOpsPyFiles",
srcs = [
":TorchOpsPyGen",
],
)
filegroup(
name = "TorchPyFiles",
srcs = glob(["python/**/*.py"]),
)