diff --git a/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py b/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py index 9143ae5ea..7dee2041c 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py +++ b/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py @@ -25,7 +25,7 @@ STABLEHLO_TO_LINALG_FUNC_PIPELINE = ",".join([ class LinalgOnTensorsStablehloBackend(StablehloBackend): - """Main entry-point for the linalg-on-tensors based TOSA backend. + """Main entry-point for the linalg-on-tensors based Stablehlo backend. This currently uses the linalg-on-tensors RefBackend for actual execution. """ @@ -35,11 +35,10 @@ class LinalgOnTensorsStablehloBackend(StablehloBackend): self.refbackend = RefBackendLinalgOnTensorsBackend() def compile(self, imported_module: Module): - """Compiles an imported module that satisfied the TOSA backend contract. + """Compiles an imported module that satisfied the Stablehlo backend contract. Args: - imported_module: The MLIR module consisting of funcs in the TOSA - dialect. + imported_module: The MLIR module consisting of funcs in the Stablehlo dialect. Returns: An opaque, backend specific compiled artifact object that can be passed to `load`. diff --git a/utils/bazel/torch-mlir-overlay/BUILD.bazel b/utils/bazel/torch-mlir-overlay/BUILD.bazel index c37023c5e..11fcb5714 100644 --- a/utils/bazel/torch-mlir-overlay/BUILD.bazel +++ b/utils/bazel/torch-mlir-overlay/BUILD.bazel @@ -867,7 +867,10 @@ cc_library( hdrs = [ "include/torch-mlir/InitAll.h", ], - copts = ["-DTORCH_MLIR_ENABLE_REFBACKEND"], + copts = [ + "-DTORCH_MLIR_ENABLE_REFBACKEND", + "-DTORCH_MLIR_ENABLE_STABLEHLO", + ], strip_include_prefix = "include", deps = [ ":TorchMLIRConversionPasses", @@ -882,6 +885,8 @@ cc_library( "@llvm-project//mlir:Dialect", "@llvm-project//mlir:DialectUtils", "@llvm-project//mlir:IR", + "@stablehlo//:stablehlo_passes", + "@stablehlo//:linalg_passes", ], )