mirror of https://github.com/llvm/torch-mlir
Remove mlir-hlo (replace with stablehlo). (#2460)
We just have to do this: I ran into an issue today where I needed to make a one line patch to stablehlo to work around a compiler issue, and it is completely unapparent how to do so given that the mlir-hlo repo is a read-only export and is at the tail end of a multi-week integration chain from the open-source stablehlo repo. We've discussed this often enough and gotten +1 from everyone that they are ok with taking the e2e testing hit if it becomes necessary: It is necessary as the current situation is unmanageable. Looking at it, I expect it wouldn't actually be very difficult to build a little runner binary out of the stablehlo interpreter and subprocess call that in order to get the testing coverage back. I leave that as an exercise to the users of this part of the stack and recommend following the breadcrumbs from the deleted python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py file and the main.py changes. Note that I am pointing us at a stablehlo fork for the moment until it is apparent that we don't need to carry any local patches to it. We can update this in a few days if everything is clear.pull/2461/head
parent
a00a0d4bfb
commit
078d1e1a1d
|
@ -1,6 +1,6 @@
|
||||||
[submodule "externals/llvm-project"]
|
[submodule "externals/llvm-project"]
|
||||||
path = externals/llvm-project
|
path = externals/llvm-project
|
||||||
url = https://github.com/llvm/llvm-project.git
|
url = https://github.com/llvm/llvm-project.git
|
||||||
[submodule "externals/mlir-hlo"]
|
[submodule "externals/stablehlo"]
|
||||||
path = externals/mlir-hlo
|
path = externals/stablehlo
|
||||||
url = https://github.com/tensorflow/mlir-hlo.git
|
url = https://github.com/openxla/stablehlo.git
|
||||||
|
|
|
@ -119,13 +119,10 @@ endif()
|
||||||
|
|
||||||
if (TORCH_MLIR_ENABLE_STABLEHLO)
|
if (TORCH_MLIR_ENABLE_STABLEHLO)
|
||||||
set(STABLEHLO_BUILD_EMBEDDED ON)
|
set(STABLEHLO_BUILD_EMBEDDED ON)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/externals/stablehlo
|
||||||
${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo
|
${CMAKE_CURRENT_BINARY_DIR}/stablehlo
|
||||||
EXCLUDE_FROM_ALL)
|
EXCLUDE_FROM_ALL)
|
||||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo/include)
|
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/externals/stablehlo)
|
||||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo)
|
|
||||||
include_directories(${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo/include)
|
|
||||||
include_directories(${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo)
|
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
set(TORCH_MLIR_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
|
set(TORCH_MLIR_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
|
||||||
|
|
|
@ -321,9 +321,6 @@ function test_in_tree() {
|
||||||
echo ":::: Run make_fx + TOSA e2e integration tests"
|
echo ":::: Run make_fx + TOSA e2e integration tests"
|
||||||
python -m e2e_testing.main --config=make_fx_tosa -v
|
python -m e2e_testing.main --config=make_fx_tosa -v
|
||||||
|
|
||||||
echo ":::: Run StableHLO e2e integration tests"
|
|
||||||
python -m e2e_testing.main --config=stablehlo -v
|
|
||||||
|
|
||||||
echo ":::: Run TOSA e2e integration tests"
|
echo ":::: Run TOSA e2e integration tests"
|
||||||
python -m e2e_testing.main --config=tosa -v
|
python -m e2e_testing.main --config=tosa -v
|
||||||
}
|
}
|
||||||
|
|
|
@ -408,13 +408,18 @@ Torch-MLIR by default builds with the latest nightly PyTorch version. This can b
|
||||||
# Updating the LLVM and MLIR-HLO submodules
|
# Updating the LLVM and MLIR-HLO submodules
|
||||||
|
|
||||||
Torch-MLIR depends on `llvm-project` (which contains, among other things,
|
Torch-MLIR depends on `llvm-project` (which contains, among other things,
|
||||||
upstream MLIR) and `mlir-hlo`, both of which are submodules in the `externals/`
|
upstream MLIR) and `stablehlo`, both of which are submodules in the `externals/`
|
||||||
directory. We aim to update these at least weekly to bring in the latest
|
directory. We aim to update these at least weekly to bring in the latest
|
||||||
features and spread out over time the effort of updating our code for MLIR API
|
features and spread out over time the effort of updating our code for MLIR API
|
||||||
breakages.
|
breakages.
|
||||||
|
|
||||||
## Which LLVM commit should I pick?
|
## Which LLVM commit should I pick?
|
||||||
|
|
||||||
|
NOTE: This section is in flux. Specifically, the `mlir-hlo` dep has been
|
||||||
|
dropped and the project is running off of a `stablehlo` fork which can be
|
||||||
|
patched for certain OS combinations. As of 2023-09-12, stellaraccident@
|
||||||
|
is massaging this situation. Please reach out for advice updating.
|
||||||
|
|
||||||
Since downstream projects may want to build Torch-MLIR (and thus LLVM and
|
Since downstream projects may want to build Torch-MLIR (and thus LLVM and
|
||||||
MLIR-HLO) in various configurations (Release versus Debug builds; on Linux,
|
MLIR-HLO) in various configurations (Release versus Debug builds; on Linux,
|
||||||
Windows, or macOS; possibly with Clang, LLD, and LLDB enabled), it is crucial to
|
Windows, or macOS; possibly with Clang, LLD, and LLDB enabled), it is crucial to
|
||||||
|
|
|
@ -24,7 +24,6 @@ from torch_mlir_e2e_test.configs import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
||||||
from torch_mlir_e2e_test.stablehlo_backends.linalg_on_tensors import LinalgOnTensorsStablehloBackend
|
|
||||||
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend
|
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend
|
||||||
|
|
||||||
from .xfail_sets import (
|
from .xfail_sets import (
|
||||||
|
@ -44,7 +43,7 @@ from torch_mlir_e2e_test.test_suite import register_all_tests
|
||||||
register_all_tests()
|
register_all_tests()
|
||||||
|
|
||||||
def _get_argparse():
|
def _get_argparse():
|
||||||
config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "make_fx_tosa", "tosa", "lazy_tensor_core", "torchdynamo"]
|
config_choices = ["native_torch", "torchscript", "linalg", "make_fx_tosa", "tosa", "lazy_tensor_core", "torchdynamo"]
|
||||||
parser = argparse.ArgumentParser(description="Run torchscript e2e tests.")
|
parser = argparse.ArgumentParser(description="Run torchscript e2e tests.")
|
||||||
parser.add_argument("-c", "--config",
|
parser.add_argument("-c", "--config",
|
||||||
choices=config_choices,
|
choices=config_choices,
|
||||||
|
@ -52,7 +51,6 @@ def _get_argparse():
|
||||||
help=f"""
|
help=f"""
|
||||||
Meaning of options:
|
Meaning of options:
|
||||||
"linalg": run through torch-mlir"s default Linalg-on-Tensors backend.
|
"linalg": run through torch-mlir"s default Linalg-on-Tensors backend.
|
||||||
"stablehlo": run through torch-mlir"s default StableHLO backend.
|
|
||||||
"tosa": run through torch-mlir"s default TOSA backend.
|
"tosa": run through torch-mlir"s default TOSA backend.
|
||||||
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
|
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
|
||||||
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
|
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
|
||||||
|
@ -100,10 +98,6 @@ def main():
|
||||||
config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend(), use_make_fx=True)
|
config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend(), use_make_fx=True)
|
||||||
xfail_set = all_test_unique_names - MAKE_FX_TOSA_PASS_SET
|
xfail_set = all_test_unique_names - MAKE_FX_TOSA_PASS_SET
|
||||||
crashing_set = set()
|
crashing_set = set()
|
||||||
elif args.config == "stablehlo":
|
|
||||||
config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend())
|
|
||||||
xfail_set = all_test_unique_names - STABLEHLO_PASS_SET
|
|
||||||
crashing_set = STABLEHLO_CRASHING_SET
|
|
||||||
elif args.config == "native_torch":
|
elif args.config == "native_torch":
|
||||||
config = NativeTorchTestConfig()
|
config = NativeTorchTestConfig()
|
||||||
xfail_set = set()
|
xfail_set = set()
|
||||||
|
|
|
@ -1 +0,0 @@
|
||||||
Subproject commit 16886a108eff5197f816ca0f1950cc5ff1b078d9
|
|
|
@ -0,0 +1 @@
|
||||||
|
Subproject commit 77a59815a82b34f7b08ed2d42a711d9920682d0e
|
|
@ -23,14 +23,6 @@ set(LinkedLibs
|
||||||
TorchMLIRRefBackend
|
TorchMLIRRefBackend
|
||||||
)
|
)
|
||||||
|
|
||||||
if(TORCH_MLIR_ENABLE_STABLEHLO)
|
|
||||||
list(APPEND LinkedLibs
|
|
||||||
MhloPasses
|
|
||||||
MhloToLinalg
|
|
||||||
StablehloToMhlo
|
|
||||||
)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
add_mlir_library(TorchMLIRInitAll
|
add_mlir_library(TorchMLIRInitAll
|
||||||
InitAll.cpp
|
InitAll.cpp
|
||||||
|
|
||||||
|
|
|
@ -11,7 +11,6 @@
|
||||||
|
|
||||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
||||||
#include "transforms/passes.h"
|
|
||||||
#endif // TORCH_MLIR_ENABLE_STABLEHLO
|
#endif // TORCH_MLIR_ENABLE_STABLEHLO
|
||||||
|
|
||||||
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
#include "PopulatePatterns.h"
|
#include "PopulatePatterns.h"
|
||||||
|
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Complex/IR/Complex.h"
|
||||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "stablehlo/dialect/ChloOps.h"
|
#include "stablehlo/dialect/ChloOps.h"
|
||||||
|
@ -25,7 +26,6 @@
|
||||||
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
|
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
|
||||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
||||||
#include "utils/hlo_utils.h"
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
|
||||||
|
@ -34,6 +34,34 @@ using namespace mlir::torch;
|
||||||
using namespace mlir::torch::Torch;
|
using namespace mlir::torch::Torch;
|
||||||
using namespace mlir::torch::torch_to_stablehlo;
|
using namespace mlir::torch::torch_to_stablehlo;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static Value getConstantLike(OpBuilder &b, Location loc, T constant,
|
||||||
|
Value val) {
|
||||||
|
Type ty = getElementTypeOrSelf(val.getType());
|
||||||
|
auto getAttr = [&]() -> Attribute {
|
||||||
|
if (ty.isa<mlir::IntegerType>())
|
||||||
|
return b.getIntegerAttr(ty, constant);
|
||||||
|
if (ty.isa<mlir::FloatType>())
|
||||||
|
return b.getFloatAttr(ty, constant);
|
||||||
|
if (auto complexTy = ty.dyn_cast<mlir::ComplexType>())
|
||||||
|
return complex::NumberAttr::get(complexTy, constant, 0);
|
||||||
|
llvm_unreachable("unhandled element type");
|
||||||
|
};
|
||||||
|
return b.create<mlir::chlo::ConstantLikeOp>(loc, cast<TypedAttr>(getAttr()),
|
||||||
|
val);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value getConstantLike(OpBuilder &b, Location loc, const APFloat &constant,
|
||||||
|
Value val) {
|
||||||
|
Type ty = getElementTypeOrSelf(val.getType());
|
||||||
|
return b.create<mlir::chlo::ConstantLikeOp>(loc, b.getFloatAttr(ty, constant),
|
||||||
|
val);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
|
LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
|
||||||
mlir::Value &self, mlir::Value &other,
|
mlir::Value &self, mlir::Value &other,
|
||||||
size_t dimSizeIndexBits) {
|
size_t dimSizeIndexBits) {
|
||||||
|
@ -836,7 +864,7 @@ LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(
|
||||||
"for AtenReciprocalOp");
|
"for AtenReciprocalOp");
|
||||||
}
|
}
|
||||||
|
|
||||||
Value oneTensor = chlo::getConstantLike(rewriter, op->getLoc(), 1, input);
|
Value oneTensor = getConstantLike(rewriter, op->getLoc(), 1, input);
|
||||||
rewriter.replaceOpWithNewOp<stablehlo::DivOp>(op, outTy, oneTensor, input);
|
rewriter.replaceOpWithNewOp<stablehlo::DivOp>(op, outTy, oneTensor, input);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -945,7 +973,7 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
|
||||||
}
|
}
|
||||||
|
|
||||||
Value zeroTensor;
|
Value zeroTensor;
|
||||||
zeroTensor = chlo::getConstantLike(
|
zeroTensor = getConstantLike(
|
||||||
rewriter, op->getLoc(),
|
rewriter, op->getLoc(),
|
||||||
APFloat::getZero(lhsElemTy.cast<mlir::FloatType>().getFloatSemantics(),
|
APFloat::getZero(lhsElemTy.cast<mlir::FloatType>().getFloatSemantics(),
|
||||||
false),
|
false),
|
||||||
|
@ -967,9 +995,9 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
|
||||||
return op.emitError("only ranked tensor type is supported.");
|
return op.emitError("only ranked tensor type is supported.");
|
||||||
}
|
}
|
||||||
|
|
||||||
Value one = chlo::getConstantLike(rewriter, loc, 1.0, input);
|
Value one = getConstantLike(rewriter, loc, 1.0, input);
|
||||||
Value two = chlo::getConstantLike(rewriter, loc, 2.0, input);
|
Value two = getConstantLike(rewriter, loc, 2.0, input);
|
||||||
Value half = chlo::getConstantLike(rewriter, loc, 0.5, input);
|
Value half = getConstantLike(rewriter, loc, 0.5, input);
|
||||||
auto rsqrtTwo = rewriter.create<mlir::stablehlo::RsqrtOp>(loc, two);
|
auto rsqrtTwo = rewriter.create<mlir::stablehlo::RsqrtOp>(loc, two);
|
||||||
auto erfElement = rewriter.create<stablehlo::MulOp>(loc, input, rsqrtTwo);
|
auto erfElement = rewriter.create<stablehlo::MulOp>(loc, input, rsqrtTwo);
|
||||||
auto erf = rewriter.create<mlir::chlo::ErfOp>(loc, erfElement);
|
auto erf = rewriter.create<mlir::chlo::ErfOp>(loc, erfElement);
|
||||||
|
@ -1485,13 +1513,12 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
|
||||||
return rewriter.notifyMatchFailure(op, "Unsupported value of approximate");
|
return rewriter.notifyMatchFailure(op, "Unsupported value of approximate");
|
||||||
}
|
}
|
||||||
// Create constant value
|
// Create constant value
|
||||||
Value kAlpha =
|
Value kAlpha = getConstantLike(rewriter, loc, 0.70710678118654752440, input);
|
||||||
chlo::getConstantLike(rewriter, loc, 0.70710678118654752440, input);
|
|
||||||
Value cstAlpha0 =
|
Value cstAlpha0 =
|
||||||
chlo::getConstantLike(rewriter, loc, 1.12837916709551257390, input);
|
getConstantLike(rewriter, loc, 1.12837916709551257390, input);
|
||||||
Value half = chlo::getConstantLike(rewriter, loc, .5, input);
|
Value half = getConstantLike(rewriter, loc, .5, input);
|
||||||
Value one = chlo::getConstantLike(rewriter, loc, 1.0, input);
|
Value one = getConstantLike(rewriter, loc, 1.0, input);
|
||||||
Value negHalf = chlo::getConstantLike(rewriter, loc, -0.5, input);
|
Value negHalf = getConstantLike(rewriter, loc, -0.5, input);
|
||||||
|
|
||||||
// Compute
|
// Compute
|
||||||
Value kBeta0 =
|
Value kBeta0 =
|
||||||
|
|
|
@ -20,7 +20,8 @@ add_mlir_conversion_library(TorchMLIRTorchToStablehlo
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
MLIRIR
|
MLIRIR
|
||||||
MLIRPass
|
MLIRPass
|
||||||
MLIRBufferTransforms
|
MLIRComplexDialect
|
||||||
|
ChloOps
|
||||||
StablehloOps
|
StablehloOps
|
||||||
TorchMLIRTorchDialect
|
TorchMLIRTorchDialect
|
||||||
TorchMLIRConversionUtils
|
TorchMLIRConversionUtils
|
||||||
|
|
|
@ -18,7 +18,9 @@ set(LinkedLibs
|
||||||
)
|
)
|
||||||
|
|
||||||
if(TORCH_MLIR_ENABLE_STABLEHLO)
|
if(TORCH_MLIR_ENABLE_STABLEHLO)
|
||||||
list(APPEND LinkedLibs ChloPasses)
|
list(APPEND LinkedLibs
|
||||||
|
StablehloOps
|
||||||
|
)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
add_mlir_library(TorchMLIRTorchConversionPasses
|
add_mlir_library(TorchMLIRTorchConversionPasses
|
||||||
|
|
|
@ -21,10 +21,6 @@
|
||||||
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
|
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
|
||||||
#include "torch-mlir/RefBackend/Passes.h"
|
#include "torch-mlir/RefBackend/Passes.h"
|
||||||
|
|
||||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
|
||||||
#include "mhlo/transforms/passes.h"
|
|
||||||
#endif
|
|
||||||
|
|
||||||
void mlir::torch::registerAllDialects(mlir::DialectRegistry ®istry) {
|
void mlir::torch::registerAllDialects(mlir::DialectRegistry ®istry) {
|
||||||
registry.insert<mlir::func::FuncDialect>();
|
registry.insert<mlir::func::FuncDialect>();
|
||||||
registry.insert<mlir::torch::Torch::TorchDialect>();
|
registry.insert<mlir::torch::Torch::TorchDialect>();
|
||||||
|
@ -40,12 +36,4 @@ void mlir::torch::registerAllPasses() {
|
||||||
mlir::torch::registerConversionPasses();
|
mlir::torch::registerConversionPasses();
|
||||||
mlir::torch::RefBackend::registerRefBackendPasses();
|
mlir::torch::RefBackend::registerRefBackendPasses();
|
||||||
mlir::torch::TMTensor::registerPasses();
|
mlir::torch::TMTensor::registerPasses();
|
||||||
|
|
||||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
|
||||||
mlir::mhlo::registerSymbolicShapeOptimizationPass();
|
|
||||||
mlir::mhlo::registerStablehloLegalizeToHloPass();
|
|
||||||
mlir::mhlo::registerChloLegalizeToHloPass();
|
|
||||||
mlir::mhlo::registerHloLegalizeToLinalgPass();
|
|
||||||
mlir::mhlo::registerTestUnfuseBatchNormPass();
|
|
||||||
#endif // TORCH_MLIR_ENABLE_STABLEHLO
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,50 +0,0 @@
|
||||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
||||||
# See https://llvm.org/LICENSE.txt for license information.
|
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
||||||
# Also available under a BSD-style license. See LICENSE.
|
|
||||||
|
|
||||||
from torch_mlir.ir import *
|
|
||||||
from torch_mlir.passmanager import *
|
|
||||||
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
|
|
||||||
|
|
||||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
|
|
||||||
RefBackendLinalgOnTensorsBackend,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .abc import StablehloBackend
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"LinalgOnTensorsStablehloBackend",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class LinalgOnTensorsStablehloBackend(StablehloBackend):
|
|
||||||
"""Main entry-point for the linalg-on-tensors based StableHLO backend.
|
|
||||||
|
|
||||||
This currently uses the linalg-on-tensors RefBackend for actual execution.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.refbackend = RefBackendLinalgOnTensorsBackend()
|
|
||||||
|
|
||||||
def compile(self, imported_module: Module):
|
|
||||||
"""Compiles an imported module that satisfied the StableHLO backend contract.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
imported_module: The MLIR module consisting of funcs in the StableHLO
|
|
||||||
dialect.
|
|
||||||
Returns:
|
|
||||||
An opaque, backend specific compiled artifact object that can be
|
|
||||||
passed to `load`.
|
|
||||||
"""
|
|
||||||
run_pipeline_with_repro_report(
|
|
||||||
imported_module,
|
|
||||||
"builtin.module(func.func(chlo-legalize-to-hlo),stablehlo-legalize-to-hlo,func.func(canonicalize,cse,symbolic-shape-optimization,mhlo-test-unfuse-batch-norm,canonicalize,hlo-legalize-to-linalg,canonicalize))",
|
|
||||||
"Lowering StableHLO to Linalg-on-Tensors",
|
|
||||||
)
|
|
||||||
return self.refbackend.compile(imported_module)
|
|
||||||
|
|
||||||
def load(self, module):
|
|
||||||
"""Loads a compiled artifact into the runtime."""
|
|
||||||
return self.refbackend.load(module)
|
|
|
@ -14,8 +14,6 @@
|
||||||
#include "torch-mlir/InitAll.h"
|
#include "torch-mlir/InitAll.h"
|
||||||
|
|
||||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||||
#include "mhlo/IR/hlo_ops.h"
|
|
||||||
#include "mhlo/transforms/passes.h"
|
|
||||||
#include "stablehlo/dialect/Register.h"
|
#include "stablehlo/dialect/Register.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -32,12 +30,6 @@ int main(int argc, char **argv) {
|
||||||
|
|
||||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||||
mlir::stablehlo::registerAllDialects(registry);
|
mlir::stablehlo::registerAllDialects(registry);
|
||||||
registry.insert<mlir::mhlo::MhloDialect>();
|
|
||||||
mlir::mhlo::registerSymbolicShapeOptimizationPass();
|
|
||||||
mlir::mhlo::registerStablehloLegalizeToHloPass();
|
|
||||||
mlir::mhlo::registerChloLegalizeToHloPass();
|
|
||||||
mlir::mhlo::registerHloLegalizeToLinalgPass();
|
|
||||||
mlir::mhlo::registerTestUnfuseBatchNormPass();
|
|
||||||
#endif
|
#endif
|
||||||
return mlir::asMainReturnCode(mlir::MlirOptMain(
|
return mlir::asMainReturnCode(mlir::MlirOptMain(
|
||||||
argc, argv, "MLIR modular optimizer driver\n", registry));
|
argc, argv, "MLIR modular optimizer driver\n", registry));
|
||||||
|
|
Loading…
Reference in New Issue