mirror of https://github.com/llvm/torch-mlir
mhlo: migrate conversion to stablehlo (#1840)
This patch replaces all MHLO operations with their StableHLO counterparts and adds a validation pass to ensure that no MHLO operations remain before translating all Stablehlo operations to the MHLO dialect for further lowering to the Linalg dialect. This patch also updates all lit tests so that they refer to the `convert-torch-to-stablehlo` pass and so that they check for StableHLO operations.pull/1851/head
parent
ed9d8d1fb7
commit
711646d095
|
@ -113,7 +113,7 @@ jobs:
|
|||
-DLLVM_USE_HOST_TOOLS=ON \
|
||||
-DLLVM_ENABLE_ZSTD=OFF \
|
||||
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
|
||||
-DTORCH_MLIR_ENABLE_MHLO=OFF \
|
||||
-DTORCH_MLIR_ENABLE_STABLEHLO=OFF \
|
||||
-DTORCH_MLIR_ENABLE_LTC=OFF \
|
||||
-DTORCH_MLIR_USE_INSTALLED_PYTORCH="${{ matrix.torch-binary }}" \
|
||||
-DMACOSX_DEPLOYMENT_TARGET=12.0 \
|
||||
|
|
|
@ -36,9 +36,9 @@ macro(torch_mlir_add_llvm_external_project name identifier location)
|
|||
set(LLVM_EXTERNAL_PROJECTS ${LLVM_EXTERNAL_PROJECTS} CACHE STRING "" FORCE)
|
||||
endmacro()
|
||||
|
||||
option(TORCH_MLIR_ENABLE_MHLO "Add mhlo dialect" ON)
|
||||
if(TORCH_MLIR_ENABLE_MHLO)
|
||||
add_definitions(-DTORCH_MLIR_ENABLE_MHLO)
|
||||
option(TORCH_MLIR_ENABLE_STABLEHLO "Add stablehlo dialect" ON)
|
||||
if(TORCH_MLIR_ENABLE_STABLEHLO)
|
||||
add_definitions(-DTORCH_MLIR_ENABLE_STABLEHLO)
|
||||
endif()
|
||||
|
||||
option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF)
|
||||
|
@ -128,8 +128,8 @@ else()
|
|||
set(MLIR_INCLUDE_DIRS "${MLIR_INCLUDE_DIR};${MLIR_GENERATED_INCLUDE_DIR}")
|
||||
endif()
|
||||
|
||||
if (TORCH_MLIR_ENABLE_MHLO)
|
||||
set(MHLO_BUILD_EMBEDDED ON)
|
||||
if (TORCH_MLIR_ENABLE_STABLEHLO)
|
||||
set(STABLEHLO_BUILD_EMBEDDED ON)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo
|
||||
${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo
|
||||
EXCLUDE_FROM_ALL)
|
||||
|
|
|
@ -267,8 +267,8 @@ function test_in_tree() {
|
|||
echo ":::: Run Linalg e2e integration tests"
|
||||
python -m e2e_testing.main --config=linalg -v
|
||||
|
||||
echo ":::: Run MHLO e2e integration tests"
|
||||
python -m e2e_testing.main --config=mhlo -v
|
||||
echo ":::: Run StableHLO e2e integration tests"
|
||||
python -m e2e_testing.main --config=stablehlo -v
|
||||
|
||||
echo ":::: Run TOSA e2e integration tests"
|
||||
python -m e2e_testing.main --config=tosa -v
|
||||
|
|
|
@ -30,14 +30,14 @@ it to various target dialects of interest to the MLIR ecosystem (various
|
|||
|
||||
- Linalg-on-Tensors (+ `arith`, `tensor`, etc.)
|
||||
- [TOSA](https://mlir.llvm.org/docs/Dialects/TOSA/)
|
||||
- [MHLO](https://github.com/tensorflow/mlir-hlo)
|
||||
- [StableHLO](https://github.com/openxla/stablehlo)
|
||||
|
||||
The terms "frontend" and "backend" are highly overloaded in any compiler
|
||||
project, but frequently in Torch-MLIR this is the meaning that they have.
|
||||
Sometimes "frontend" can mean something even further up the stack, such as
|
||||
something in PyTorch itself. When there is ambiguity we will refer to this as
|
||||
"at the PyTorch level". Similarly, "backend" can sometimes refer to something
|
||||
sitting below Linalg-on-Tensors, TOSA, or MHLO.
|
||||
sitting below Linalg-on-Tensors, TOSA, or StableHLO.
|
||||
|
||||
## The `torch` dialect
|
||||
|
||||
|
@ -118,8 +118,8 @@ See [satisfiesBackendContract](https://github.com/llvm/torch-mlir/blob/114f48e96
|
|||
|
||||
The backend contract is a normalized form of the `torch` dialect with a set of
|
||||
properties that make it easy to lower into various forms such as
|
||||
Linalg-on-Tensors, TOSA, MHLO, or other forms that we don't provide out of the
|
||||
box. The primary guarantees that we provide Torch-MLIR's backends are:
|
||||
Linalg-on-Tensors, TOSA, StableHLO, or other forms that we don't provide out of
|
||||
the box. The primary guarantees that we provide Torch-MLIR's backends are:
|
||||
|
||||
- All tensors have been converted to value semantics.
|
||||
- All tensors have at least a known number of dimensions (i.e. rank), and
|
||||
|
@ -270,7 +270,7 @@ lower it to the requirements of each backend. The 3 backends are:
|
|||
- [`linalg`](https://mlir.llvm.org/docs/Dialects/Linalg/) on tensors (+ `arith`,
|
||||
`tensor`, etc.)
|
||||
- [TOSA](https://mlir.llvm.org/docs/Dialects/TOSA/)
|
||||
- [MHLO](https://github.com/tensorflow/mlir-hlo)
|
||||
- [StableHLO](https://github.com/openxla/stablehlo)
|
||||
|
||||
### The Linalg Backend (Linalg-on-Tensors)
|
||||
|
||||
|
@ -297,15 +297,15 @@ many users (especially "hardware" or "hardware-adjacent" folks). Some of its cha
|
|||
- It is extremely solid with static shapes (and many of its users only care
|
||||
about static shapes, so that's fine).
|
||||
|
||||
### The MHLO Backend
|
||||
### The StableHLO Backend
|
||||
|
||||
Code: https://github.com/llvm/torch-mlir/tree/main/lib/Conversion/TorchToMhlo
|
||||
Code: https://github.com/llvm/torch-mlir/tree/main/lib/Conversion/TorchToStablehlo
|
||||
|
||||
The MHLO backend was the third backend that we added, and it offers a reasonable
|
||||
blend of the benefits of the other two.
|
||||
The StableHLO backend was the third backend that we added, and it offers a
|
||||
reasonable blend of the benefits of the other two.
|
||||
- It is a coarse-grained named-op approach.
|
||||
- It has a pretty clear spec for most of the ops (with a bit of mental
|
||||
translation and hoping that MHLO is the same as HLO):
|
||||
translation and hoping that StableHLO is the same as HLO):
|
||||
https://www.tensorflow.org/xla/operation_semantics
|
||||
- It functionally supports dynamic shapes (though not as coherent and consistent
|
||||
as Linalg-on-Tensors, and the dynamic shape support falls outside the
|
||||
|
@ -317,7 +317,7 @@ blend of the benefits of the other two.
|
|||
example, TOSA limits (for highly considered reasons) the number of dimensions
|
||||
that certain operators can handle to 1D-4D, when from a purely algebraic
|
||||
perspective there isn't a good reason to not be more general. Similarly, more
|
||||
general forms of reduction and scatter also fall into MHLO nicely while
|
||||
general forms of reduction and scatter also fall into StableHLO nicely while
|
||||
TOSA's principles tend to bias it away from that.
|
||||
|
||||
### Backend Implementation
|
||||
|
@ -433,8 +433,9 @@ filling in some corners missing upstream and
|
|||
to pull together upstream functionality into a working system.
|
||||
|
||||
The RefBackend accepts Linalg-on-Tensors as input. It mainly just bufferizes the
|
||||
ops and lowers them to loops. Note that TOSA and MHLO support lowering to
|
||||
Linalg-on-Tensors, so all our end-to-end testing bottoms out on RefBackend.
|
||||
ops and lowers them to loops. Note that TOSA and StableHLO (via MHLO) support
|
||||
lowering to Linalg-on-Tensors, so all our end-to-end testing bottoms out on
|
||||
RefBackend.
|
||||
|
||||
The RefBackend is absolutely not suitable for any production use case. It leaks
|
||||
memory, doesn't support any error handling, performs no optimizations, and
|
||||
|
|
|
@ -34,7 +34,7 @@ and Clang's
|
|||
- Eric Kunze (@eric-k256)
|
||||
- Suraj Sudhir (@sjarus)
|
||||
|
||||
### TorchToMHLO
|
||||
### TorchToStablehlo
|
||||
|
||||
- Tianyo Kwok (@tanyokwok)
|
||||
- Ziheng Jiang (@ZihengJiang)
|
||||
|
|
|
@ -139,7 +139,7 @@ Ex:
|
|||
module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
|
||||
```
|
||||
|
||||
Currently, `output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `MHLO`.
|
||||
Currently, `output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `STABLEHLO`.
|
||||
|
||||
## Jupyter
|
||||
|
||||
|
|
|
@ -46,7 +46,7 @@ the ecosystem are:
|
|||
|
||||
- The frontend work required to lower TorchScript to the backend contract.
|
||||
- The irregular support surface area of the large number of PyTorch ops across
|
||||
the Linalg, TOSA, and MHLO backends.
|
||||
the Linalg, TOSA, and StableHLO backends.
|
||||
|
||||
Most of this document describes long-term ecosystem changes that will address
|
||||
these, drastically improving Torch-MLIR's ability to meet its goals.
|
||||
|
@ -108,7 +108,7 @@ more advanced).
|
|||
### Refactoring the backend
|
||||
|
||||
Today in Torch-MLIR, we support 3 backends out of the box: Linalg-on-Tensors,
|
||||
TOSA, and MHLO. These backends take IR in the backend contract form (see
|
||||
TOSA, and StableHLO. These backends take IR in the backend contract form (see
|
||||
[architecture.md](architecture.md)) and lowers them to the respective dialects.
|
||||
Today, each backend is implemented completely independently. This leads to
|
||||
duplication and irregularity across the backends.
|
||||
|
@ -120,12 +120,10 @@ lowering of so many ops across backends. Additionally, there are 3
|
|||
forward-looking efforts that intersect with this effort:
|
||||
|
||||
- [StableHLO](https://github.com/openxla/stablehlo) - this is a dialect
|
||||
initially forked from MHLO which intends to create a stable support surface
|
||||
area for what today is our "at head" dependency on MHLO. MHLO is a fairly
|
||||
complete op set, so it is very attractive to have "almost all" models
|
||||
bottleneck through a stable interface like StableHLO. StableHLO is currently
|
||||
under relatively early development, but already delivers on many of the goals
|
||||
of stability.
|
||||
initially forked from MHLO. MHLO is a fairly complete op set, so it is very
|
||||
attractive to have "almost all" models bottleneck through a stable interface
|
||||
like StableHLO. StableHLO is currently under relatively early development,
|
||||
but already delivers on many of the goals of stability.
|
||||
- [TCP](https://github.com/llvm/torch-mlir/issues/1366) - this is a dialect
|
||||
which could serve a role very similar to MHLO, while providing community
|
||||
ownership. TCP is still in early planning phases, but there is strong
|
||||
|
|
|
@ -16,7 +16,7 @@ from torch_mlir_e2e_test.registry import GLOBAL_TEST_REGISTRY
|
|||
from torch_mlir_e2e_test.configs import (
|
||||
LazyTensorCoreTestConfig,
|
||||
LinalgOnTensorsBackendTestConfig,
|
||||
MhloBackendTestConfig,
|
||||
StablehloBackendTestConfig,
|
||||
NativeTorchTestConfig,
|
||||
TorchScriptTestConfig,
|
||||
TosaBackendTestConfig,
|
||||
|
@ -24,17 +24,17 @@ from torch_mlir_e2e_test.configs import (
|
|||
)
|
||||
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
||||
from torch_mlir_e2e_test.mhlo_backends.linalg_on_tensors import LinalgOnTensorsMhloBackend
|
||||
from torch_mlir_e2e_test.stablehlo_backends.linalg_on_tensors import LinalgOnTensorsStablehloBackend
|
||||
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend
|
||||
|
||||
from .xfail_sets import LINALG_XFAIL_SET, MHLO_PASS_SET, TOSA_PASS_SET, LTC_XFAIL_SET, TORCHDYNAMO_XFAIL_SET
|
||||
from .xfail_sets import LINALG_XFAIL_SET, STABLEHLO_PASS_SET, TOSA_PASS_SET, LTC_XFAIL_SET, TORCHDYNAMO_XFAIL_SET
|
||||
|
||||
# Import tests to register them in the global registry.
|
||||
from torch_mlir_e2e_test.test_suite import register_all_tests
|
||||
register_all_tests()
|
||||
|
||||
def _get_argparse():
|
||||
config_choices = ["native_torch", "torchscript", "linalg", "mhlo", "tosa", "lazy_tensor_core", "torchdynamo"]
|
||||
config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "tosa", "lazy_tensor_core", "torchdynamo"]
|
||||
parser = argparse.ArgumentParser(description="Run torchscript e2e tests.")
|
||||
parser.add_argument("-c", "--config",
|
||||
choices=config_choices,
|
||||
|
@ -42,7 +42,7 @@ def _get_argparse():
|
|||
help=f"""
|
||||
Meaning of options:
|
||||
"linalg": run through torch-mlir"s default Linalg-on-Tensors backend.
|
||||
"mhlo": run through torch-mlir"s default MHLO backend.
|
||||
"stablehlo": run through torch-mlir"s default StableHLO backend.
|
||||
"tosa": run through torch-mlir"s default TOSA backend.
|
||||
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
|
||||
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
|
||||
|
@ -80,9 +80,9 @@ def main():
|
|||
if args.config == "tosa":
|
||||
config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend())
|
||||
xfail_set = all_test_unique_names - TOSA_PASS_SET
|
||||
if args.config == "mhlo":
|
||||
config = MhloBackendTestConfig(LinalgOnTensorsMhloBackend())
|
||||
xfail_set = all_test_unique_names - MHLO_PASS_SET
|
||||
if args.config == "stablehlo":
|
||||
config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend())
|
||||
xfail_set = all_test_unique_names - STABLEHLO_PASS_SET
|
||||
elif args.config == "native_torch":
|
||||
config = NativeTorchTestConfig()
|
||||
xfail_set = {}
|
||||
|
|
|
@ -87,8 +87,10 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
"StdCorrectionKeepDimModule_basic",
|
||||
}
|
||||
|
||||
MHLO_PASS_SET = {
|
||||
STABLEHLO_PASS_SET = {
|
||||
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
|
||||
"AddSizeIntModule_basic",
|
||||
"AddSizeIntNegDimModule_basic",
|
||||
"ArangeDtypeFloatModule_basic",
|
||||
"ArangeDtypeIntModule_basic",
|
||||
"ArangeFalsePinMemoryModule_basic",
|
||||
|
@ -103,6 +105,7 @@ MHLO_PASS_SET = {
|
|||
"ArangeStartStepFloatModule_basic",
|
||||
"ArangeStartStepIntModule_basic",
|
||||
"ArangeZeroElementOutputModule_basic",
|
||||
"BatchMlpLayerModule_basic",
|
||||
"BmmModule_basic",
|
||||
"BroadcastToModule_basic",
|
||||
"BroadcastToSameRankStaticModule_basic",
|
||||
|
@ -124,12 +127,15 @@ MHLO_PASS_SET = {
|
|||
"ElementwiseClampMinModule_basic",
|
||||
"ElementwiseClampMaxModule_basic",
|
||||
"ElementwiseExpModule_basic",
|
||||
"ElementwiseFlattenBroadcastModule_basic",
|
||||
"ElementwiseLeakyReluModule_basic",
|
||||
"ElementwiseLogModule_basic",
|
||||
"ElementwiseNegModule_basic",
|
||||
"ElementwiseRsqrtModule_basic",
|
||||
"ElementwiseSigmoidModule_basic",
|
||||
"ElementwiseSqrtModule_basic",
|
||||
"ElementwiseUnaryModule_basic",
|
||||
"ElementwiseUnsqueezeBroadcastModule_basic",
|
||||
"ElementwiseUnsqueezeNegDimsModule_basic",
|
||||
"ElementwiseToDtypeF32ToI64Module_basic",
|
||||
"ElementwiseAddModule_basic",
|
||||
|
@ -198,6 +204,8 @@ MHLO_PASS_SET = {
|
|||
"Gather2DInputModdule_basic",
|
||||
"GatherRandomIndexModule_basic",
|
||||
"GeluBackwardModule_basic",
|
||||
"HardswishModule_basic",
|
||||
"HardswishRandomModule_basic",
|
||||
"HardTanhIntModule_basic",
|
||||
"HardTanhModule_basic",
|
||||
"HardsigmoidModule_basic",
|
||||
|
@ -220,6 +228,8 @@ MHLO_PASS_SET = {
|
|||
"MeanDynamicSizesModule_basic",
|
||||
"MeanLargeInputModule_basic",
|
||||
"MeanModule_basic",
|
||||
"Mlp1LayerModule_basic",
|
||||
"Mlp2LayerModule_basic",
|
||||
"MmTanhModule_basic",
|
||||
"Mv_basic",
|
||||
"NativeLayerNormModule4D_basic",
|
||||
|
@ -251,6 +261,8 @@ MHLO_PASS_SET = {
|
|||
"LiftFreshCopyModule_basic",
|
||||
"Mlp2LayerModuleNoBias_basic",
|
||||
"NumelModule_basic",
|
||||
"SiluModule_basic",
|
||||
"SquareModule_basic",
|
||||
"SqueezeModule_allUnitDim",
|
||||
"SqueezeDimModule_unitDim",
|
||||
"ViewCollapseOnesMiddleModule_basic",
|
||||
|
@ -420,6 +432,7 @@ MHLO_PASS_SET = {
|
|||
"UnsafeViewDynamicExpandModule_basic",
|
||||
"AtenRoundIntModule_basic",
|
||||
"TestF16Return_basic",
|
||||
"_LogSoftmaxModuleStable_basic",
|
||||
}
|
||||
|
||||
# Write the TOSA set as a "passing" set as it is very early in development
|
||||
|
|
|
@ -1,14 +0,0 @@
|
|||
import torch
|
||||
import torchvision.models as models
|
||||
import torch_mlir
|
||||
|
||||
model = models.resnet18(pretrained=True)
|
||||
model.eval()
|
||||
data = torch.randn(2,3,200,200)
|
||||
out_mhlo_mlir_path = "./resnet18_mhlo.mlir"
|
||||
|
||||
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=False)
|
||||
with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf:
|
||||
outf.write(str(module))
|
||||
|
||||
print(f"MHLO IR of resent18 successfully written into {out_mhlo_mlir_path}")
|
|
@ -0,0 +1,14 @@
|
|||
import torch
|
||||
import torchvision.models as models
|
||||
import torch_mlir
|
||||
|
||||
model = models.resnet18(pretrained=True)
|
||||
model.eval()
|
||||
data = torch.randn(2,3,200,200)
|
||||
out_stablehlo_mlir_path = "./resnet18_stablehlo.mlir"
|
||||
|
||||
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=False)
|
||||
with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
|
||||
outf.write(str(module))
|
||||
|
||||
print(f"StableHLO IR of resent18 successfully written into {out_stablehlo_mlir_path}")
|
|
@ -15,10 +15,10 @@ class BertTinyWrapper(torch.nn.Module):
|
|||
model = BertTinyWrapper()
|
||||
model.eval()
|
||||
data = torch.randint(30522, (2, 128))
|
||||
out_mhlo_mlir_path = "./bert_tiny_mhlo.mlir"
|
||||
out_stablehlo_mlir_path = "./bert_tiny_stablehlo.mlir"
|
||||
|
||||
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=True)
|
||||
with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf:
|
||||
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=True)
|
||||
with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
|
||||
outf.write(str(module))
|
||||
|
||||
print(f"MHLO IR of tiny bert successfully written into {out_mhlo_mlir_path}")
|
||||
print(f"StableHLO IR of tiny bert successfully written into {out_stablehlo_mlir_path}")
|
|
@ -1,6 +1,6 @@
|
|||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
if(TORCH_MLIR_ENABLE_MHLO)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_MHLO)
|
||||
if(TORCH_MLIR_ENABLE_STABLEHLO)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO)
|
||||
else()
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls)
|
||||
endif()
|
||||
|
|
|
@ -133,13 +133,13 @@ def ConvertTorchConversionToMLProgram : Pass<"convert-torch-conversion-to-mlprog
|
|||
let constructor = "mlir::torch::createConvertTorchConversionToMLProgramPass()";
|
||||
}
|
||||
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO
|
||||
def ConvertTorchToMhlo : Pass<"convert-torch-to-mhlo", "func::FuncOp"> {
|
||||
let summary = "Convert Torch ops to MHLO ops";
|
||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||
def ConvertTorchToStablehlo : Pass<"convert-torch-to-stablehlo", "func::FuncOp"> {
|
||||
let summary = "Convert Torch ops to Stablehlo ops";
|
||||
let description = [{
|
||||
Convert Torch ops to mhlo ops.
|
||||
Convert Torch ops to Stablehlo ops.
|
||||
}];
|
||||
let constructor = "mlir::torch::createConvertTorchToMhloPass()";
|
||||
let constructor = "mlir::torch::createConvertTorchToStablehloPass()";
|
||||
|
||||
// Specify any options.
|
||||
let options = [
|
||||
|
|
|
@ -7,8 +7,8 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H
|
||||
#define TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H
|
||||
#ifndef TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H
|
||||
#define TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
@ -16,10 +16,11 @@
|
|||
|
||||
namespace mlir {
|
||||
namespace torch {
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToMhloPass();
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
createConvertTorchToMhloPass(bool enableStaticShape, bool enableI32Index);
|
||||
createConvertTorchToStablehloPass();
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
createConvertTorchToStablehloPass(bool enableStaticShape, bool enableI32Index);
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H
|
||||
#endif // TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H
|
|
@ -1,6 +1,6 @@
|
|||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
if(TORCH_MLIR_ENABLE_MHLO)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_MHLO)
|
||||
if(TORCH_MLIR_ENABLE_STABLEHLO)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO)
|
||||
else()
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls)
|
||||
endif()
|
||||
|
|
|
@ -30,10 +30,10 @@ void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm);
|
|||
/// TOSA backend contract.
|
||||
void createTorchBackendToTosaBackendPipeline(OpPassManager &pm);
|
||||
|
||||
// Do not register the torch-to-mhlo pipeline if mhlo target is disabled
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO
|
||||
struct MhloBackendPipelineOptions
|
||||
: public PassPipelineOptions<MhloBackendPipelineOptions> {
|
||||
// Do not register the stablehlo options if the stablehlo target is disabled
|
||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||
struct StablehloBackendPipelineOptions
|
||||
: public PassPipelineOptions<StablehloBackendPipelineOptions> {
|
||||
Option<bool> enableStaticShape{
|
||||
*this, "enable-static-shape",
|
||||
llvm::cl::desc("Enable static shape conversion."), llvm::cl::init(false)};
|
||||
|
@ -46,9 +46,10 @@ struct MhloBackendPipelineOptions
|
|||
llvm::cl::init(false)};
|
||||
};
|
||||
|
||||
void createTorchBackendToMhloBackendPipeline(
|
||||
OpPassManager &pm, const MhloBackendPipelineOptions &options);
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createVerifyMhloBackendContractPass();
|
||||
void createTorchBackendToStablehloBackendPipeline(
|
||||
OpPassManager &pm, const StablehloBackendPipelineOptions &options);
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createVerifyStablehloBackendContractPass();
|
||||
#endif
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass();
|
||||
|
|
|
@ -42,10 +42,10 @@ def VerifyTosaBackendContract : Pass<"torch-verify-tosa-backend-contract", "Modu
|
|||
let constructor = "mlir::torch::TorchConversion::createVerifyTosaBackendContractPass()";
|
||||
}
|
||||
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO
|
||||
def VerifyMhloBackendContract : Pass<"torch-verify-mhlo-backend-contract", "ModuleOp"> {
|
||||
let summary = "Verifies conformity to the mhlo backend contract";
|
||||
let constructor = "mlir::torch::TorchConversion::createVerifyMhloBackendContractPass()";
|
||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||
def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contract", "ModuleOp"> {
|
||||
let summary = "Verifies conformity to the stablehlo backend contract";
|
||||
let constructor = "mlir::torch::TorchConversion::createVerifyStablehloBackendContractPass()";
|
||||
}
|
||||
#endif // TORCH_MLIR_ENABLE_MHLO
|
||||
#endif // TORCH_MLIR_ENABLE_STABLEHLO
|
||||
#endif // TORCHMLIR_TORCHCONVERSION_PASSES
|
||||
|
|
|
@ -3,13 +3,7 @@ add_subdirectory(Conversion)
|
|||
add_subdirectory(Dialect)
|
||||
add_subdirectory(RefBackend)
|
||||
|
||||
add_mlir_library(TorchMLIRInitAll
|
||||
InitAll.cpp
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
set(LinkedLibs
|
||||
MLIRFuncDialect
|
||||
MLIRIR
|
||||
MLIRSupport
|
||||
|
@ -27,4 +21,22 @@ add_mlir_library(TorchMLIRInitAll
|
|||
TorchMLIRRefBackend
|
||||
)
|
||||
|
||||
if(TORCH_MLIR_ENABLE_STABLEHLO)
|
||||
list(APPEND LinkedLibs
|
||||
MhloPasses
|
||||
MhloToLinalg
|
||||
StablehloToMhlo
|
||||
)
|
||||
endif()
|
||||
|
||||
add_mlir_library(TorchMLIRInitAll
|
||||
InitAll.cpp
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
${LinkedLibs}
|
||||
)
|
||||
|
||||
torch_mlir_target_includes(TorchMLIRInitAll)
|
||||
|
|
|
@ -2,8 +2,8 @@ add_subdirectory(TorchToLinalg)
|
|||
add_subdirectory(TorchToSCF)
|
||||
add_subdirectory(TorchToArith)
|
||||
add_subdirectory(TorchToTosa)
|
||||
if(TORCH_MLIR_ENABLE_MHLO)
|
||||
add_subdirectory(TorchToMhlo)
|
||||
if(TORCH_MLIR_ENABLE_STABLEHLO)
|
||||
add_subdirectory(TorchToStablehlo)
|
||||
endif()
|
||||
add_subdirectory(TorchToTMTensor)
|
||||
add_subdirectory(TorchConversionToMLProgram)
|
||||
|
@ -17,10 +17,8 @@ set(linked_libs TorchMLIRTorchToLinalg
|
|||
TorchMLIRTorchToTMTensor
|
||||
TorchMLIRTorchConversionToMLProgram
|
||||
TorchMLIRConversionUtils)
|
||||
if(TORCH_MLIR_ENABLE_MHLO)
|
||||
list(APPEND linked_libs
|
||||
MhloPasses
|
||||
TorchMLIRTorchToMhlo)
|
||||
if(TORCH_MLIR_ENABLE_STABLEHLO)
|
||||
list(APPEND linked_libs TorchMLIRTorchToStablehlo)
|
||||
endif()
|
||||
|
||||
add_mlir_library(TorchMLIRConversionPasses
|
||||
|
|
|
@ -9,15 +9,15 @@
|
|||
|
||||
#include "torch-mlir/Conversion/Passes.h"
|
||||
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO
|
||||
#include "mhlo/transforms/passes.h"
|
||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
||||
#include "transforms/passes.h"
|
||||
#endif // TORCH_MLIR_ENABLE_MHLO
|
||||
#endif // TORCH_MLIR_ENABLE_STABLEHLO
|
||||
|
||||
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
||||
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
|
||||
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
|
||||
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
|
||||
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
||||
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
|
||||
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
|
||||
|
||||
|
@ -32,12 +32,4 @@ namespace {
|
|||
|
||||
void mlir::torch::registerConversionPasses() {
|
||||
::registerPasses();
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO
|
||||
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
|
||||
return mlir::mhlo::createLegalizeHloToLinalgPass();
|
||||
});
|
||||
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
|
||||
return mlir::mhlo::createSymbolicShapeOptimizationPass();
|
||||
});
|
||||
#endif // TORCH_MLIR_ENABLE_MHLO
|
||||
}
|
||||
|
|
|
@ -1,35 +0,0 @@
|
|||
add_mlir_conversion_library(TorchMLIRTorchToMhlo
|
||||
TorchToMhlo.cpp
|
||||
MhloLegalizeUtils.cpp
|
||||
Basic.cpp
|
||||
Gather.cpp
|
||||
Linear.cpp
|
||||
ViewLike.cpp
|
||||
Reduction.cpp
|
||||
Pooling.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToMhlo
|
||||
|
||||
DEPENDS
|
||||
MhloDialect
|
||||
MhloToLinalg
|
||||
MLIRMhloPassIncGen
|
||||
LMHLOTransformsPassIncGen
|
||||
TorchMLIRConversionPassIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRPass
|
||||
MhloDialect
|
||||
MhloToLinalg
|
||||
MLIRBufferTransforms
|
||||
StablehloOps
|
||||
TorchMLIRTorchDialect
|
||||
TorchMLIRConversionUtils
|
||||
)
|
||||
|
||||
torch_mlir_target_includes(TorchMLIRTorchToMhlo)
|
|
@ -1,74 +0,0 @@
|
|||
//===------------------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H
|
||||
#define TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H
|
||||
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace torch {
|
||||
namespace torch_to_mhlo {
|
||||
|
||||
struct TorchToMhloOptions {
|
||||
bool enableStaticShape = false;
|
||||
size_t dimSizeIndexBits = 64;
|
||||
};
|
||||
|
||||
template <typename AtenOpT>
|
||||
class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||
ConvertAtenOp(TypeConverter &typeConverter, MLIRContext *context,
|
||||
const TorchToMhloOptions &options)
|
||||
: OpConversionPattern<AtenOpT>(typeConverter, context) {
|
||||
this->options = options;
|
||||
}
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
return rewriter.notifyMatchFailure(op, "haven't been implemented");
|
||||
}
|
||||
const TorchToMhloOptions &getOptions() const { return options; }
|
||||
|
||||
private:
|
||||
TorchToMhloOptions options;
|
||||
};
|
||||
|
||||
void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
ConversionTarget &target,
|
||||
const TorchToMhloOptions &options);
|
||||
void populateViewLikeOpPatternsAndLegality(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
ConversionTarget &target,
|
||||
const TorchToMhloOptions &options);
|
||||
void populateGatherOpPatternsAndLegality(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
ConversionTarget &target,
|
||||
const TorchToMhloOptions &options);
|
||||
void populateReductionOpPatternsAndLegality(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
ConversionTarget &target,
|
||||
const TorchToMhloOptions &options);
|
||||
void populateLinearOpPatternsAndLegality(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
ConversionTarget &target,
|
||||
const TorchToMhloOptions &options);
|
||||
|
||||
void populatePoolingOpPatternsAndLegality(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
ConversionTarget &target,
|
||||
const TorchToMhloOptions &options);
|
||||
|
||||
} // namespace torch_to_mhlo
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H
|
|
@ -7,15 +7,16 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "./MhloLegalizeUtils.h"
|
||||
#include "./PopulatePatterns.h"
|
||||
#include "mhlo/IR/hlo_ops.h"
|
||||
#include "PopulatePatterns.h"
|
||||
#include "StablehloLegalizeUtils.h"
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "stablehlo/dialect/ChloOps.h"
|
||||
#include "stablehlo/dialect/StablehloOps.h"
|
||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
|
@ -29,7 +30,7 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
using namespace mlir::torch::torch_to_mhlo;
|
||||
using namespace mlir::torch::torch_to_stablehlo;
|
||||
|
||||
LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
|
||||
mlir::Value &self, mlir::Value &other,
|
||||
|
@ -43,16 +44,16 @@ LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
|
|||
if (selfRank > otherRank) {
|
||||
auto unsqueezeDims =
|
||||
llvm::to_vector<4>(llvm::seq<int64_t>(0, selfRank - otherRank));
|
||||
auto unsqueezeInfo = mhlo::unsqueezeTensor(rewriter, op, other,
|
||||
unsqueezeDims, dimSizeIndexBits);
|
||||
auto unsqueezeInfo = hlo::unsqueezeTensor(rewriter, op, other,
|
||||
unsqueezeDims, dimSizeIndexBits);
|
||||
if (failed(unsqueezeInfo))
|
||||
return failure();
|
||||
other = *unsqueezeInfo;
|
||||
} else if (otherRank > selfRank) {
|
||||
auto unsqueezeDims =
|
||||
llvm::to_vector<4>(llvm::seq<int64_t>(0, otherRank - selfRank));
|
||||
auto unsqueezeInfo = mhlo::unsqueezeTensor(rewriter, op, self,
|
||||
unsqueezeDims, dimSizeIndexBits);
|
||||
auto unsqueezeInfo = hlo::unsqueezeTensor(rewriter, op, self, unsqueezeDims,
|
||||
dimSizeIndexBits);
|
||||
if (failed(unsqueezeInfo))
|
||||
return failure();
|
||||
self = *unsqueezeInfo;
|
||||
|
@ -78,7 +79,8 @@ static FailureOr<Value> getMaxValueOfDtype(Operation *op, Type elementType,
|
|||
constType,
|
||||
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
|
||||
/*negative=*/false));
|
||||
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
|
||||
return rewriter
|
||||
.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
|
||||
.getResult();
|
||||
}
|
||||
if (elementType.isa<mlir::IntegerType>()) {
|
||||
|
@ -91,7 +93,8 @@ static FailureOr<Value> getMaxValueOfDtype(Operation *op, Type elementType,
|
|||
constAttr = SplatElementsAttr::get(
|
||||
constType, APInt::getSignedMaxValue(integerType.getWidth()));
|
||||
}
|
||||
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
|
||||
return rewriter
|
||||
.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
|
||||
.getResult();
|
||||
}
|
||||
return failure();
|
||||
|
@ -105,7 +108,8 @@ static FailureOr<Value> getMinValueOfDtype(Operation *op, Type elementType,
|
|||
constType,
|
||||
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
|
||||
/*negative=*/true));
|
||||
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
|
||||
return rewriter
|
||||
.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
|
||||
.getResult();
|
||||
}
|
||||
if (elementType.isa<mlir::IntegerType>()) {
|
||||
|
@ -118,7 +122,8 @@ static FailureOr<Value> getMinValueOfDtype(Operation *op, Type elementType,
|
|||
constAttr = SplatElementsAttr::get(
|
||||
constType, APInt::getSignedMinValue(integerType.getWidth()));
|
||||
}
|
||||
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
|
||||
return rewriter
|
||||
.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
|
||||
.getResult();
|
||||
}
|
||||
return failure();
|
||||
|
@ -126,7 +131,7 @@ static FailureOr<Value> getMinValueOfDtype(Operation *op, Type elementType,
|
|||
|
||||
// These legalizations are for unary ops.
|
||||
namespace {
|
||||
template <typename AtenOpT, typename MhloOpT>
|
||||
template <typename AtenOpT, typename StablehloOpT>
|
||||
class ConvertAtenUnaryOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||
|
@ -137,13 +142,13 @@ public:
|
|||
Value self = adaptor.getSelf();
|
||||
auto selfType = self.getType().cast<TensorType>();
|
||||
if (!selfType) {
|
||||
return op.emitError("only Tensor types supported in MHLO");
|
||||
return op.emitError("only Tensor types supported in StableHLO");
|
||||
}
|
||||
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<TensorType>();
|
||||
self = mhlo::promoteType(rewriter, self, outType);
|
||||
rewriter.replaceOpWithNewOp<MhloOpT>(op, outType, self);
|
||||
self = hlo::promoteType(rewriter, self, outType);
|
||||
rewriter.replaceOpWithNewOp<StablehloOpT>(op, outType, self);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -152,7 +157,7 @@ public:
|
|||
// These legalizations are for unary ops with only for floating point datatypes.
|
||||
// There is no supported quantized integer mode for these.
|
||||
namespace {
|
||||
template <typename AtenOpT, typename MhloOpT>
|
||||
template <typename AtenOpT, typename StablehloOpT>
|
||||
class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||
|
@ -164,10 +169,10 @@ public:
|
|||
auto selfTy = self.getType().cast<TensorType>();
|
||||
|
||||
if (!selfTy)
|
||||
return op.emitError("only Tensor types supported in MHLO");
|
||||
return op.emitError("only Tensor types supported in StableHLO");
|
||||
|
||||
if (selfTy.getElementType().isa<mlir::FloatType>()) {
|
||||
rewriter.replaceOpWithNewOp<MhloOpT>(
|
||||
rewriter.replaceOpWithNewOp<StablehloOpT>(
|
||||
op,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()),
|
||||
|
@ -198,7 +203,7 @@ public:
|
|||
.template dyn_cast<TensorType>();
|
||||
|
||||
if (!outType)
|
||||
return op.emitError("only Tensor types supported in MHLO");
|
||||
return op.emitError("only Tensor types supported in StableHLO");
|
||||
|
||||
Type outElemTy = outType.getElementType();
|
||||
if (!outElemTy.isIntOrFloat())
|
||||
|
@ -216,9 +221,9 @@ public:
|
|||
|
||||
SmallVector<int32_t> values(size, fillVal);
|
||||
auto constOp =
|
||||
mhlo::getConstTensor<int32_t>(rewriter, op, values, shape).value();
|
||||
hlo::getConstTensor<int32_t>(rewriter, op, values, shape).value();
|
||||
|
||||
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(op, outType, constOp);
|
||||
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(op, outType, constOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -247,8 +252,8 @@ public:
|
|||
->convertType(op.getType())
|
||||
.template cast<TensorType>();
|
||||
|
||||
lhs = mhlo::promoteType(rewriter, lhs, outTy);
|
||||
rhs = mhlo::promoteType(rewriter, rhs, outTy);
|
||||
lhs = hlo::promoteType(rewriter, lhs, outTy);
|
||||
rhs = hlo::promoteType(rewriter, rhs, outTy);
|
||||
|
||||
rewriter.replaceOpWithNewOp<ChloOpT>(op, outTy, lhs, rhs,
|
||||
/*broadcast_attr*/ nullptr);
|
||||
|
@ -274,7 +279,7 @@ public:
|
|||
RankedTensorType rhsType = rhs.getType().dyn_cast<RankedTensorType>();
|
||||
|
||||
if (!lhsType)
|
||||
return op.emitError("only Tensor types supported in MHLO");
|
||||
return op.emitError("only Tensor types supported in StableHLO");
|
||||
|
||||
TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
|
@ -287,18 +292,19 @@ public:
|
|||
}
|
||||
|
||||
if (!rhsType) {
|
||||
rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getOther(), outElemTy);
|
||||
rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
|
||||
outElemTy);
|
||||
if (isa<AtenRsubScalarOp>(op)) {
|
||||
std::swap(lhs, rhs);
|
||||
}
|
||||
}
|
||||
|
||||
lhs = mhlo::promoteType(rewriter, lhs, outType);
|
||||
rhs = mhlo::promoteType(rewriter, rhs, outType);
|
||||
lhs = hlo::promoteType(rewriter, lhs, outType);
|
||||
rhs = hlo::promoteType(rewriter, rhs, outType);
|
||||
|
||||
if (!skipMultiplyAlpha(op.getAlpha())) {
|
||||
Value alpha =
|
||||
mhlo::scalarToMhloTensor(rewriter, op, adaptor.getAlpha(), outElemTy);
|
||||
Value alpha = hlo::scalarToStablehloTensor(rewriter, op,
|
||||
adaptor.getAlpha(), outElemTy);
|
||||
DenseIntElementsAttr bcastDimensions;
|
||||
rhs = rewriter.create<chlo::BroadcastMulOp>(op->getLoc(), rhs, alpha,
|
||||
bcastDimensions);
|
||||
|
@ -328,7 +334,7 @@ public:
|
|||
TensorType rhsType = rhs.getType().dyn_cast<TensorType>();
|
||||
|
||||
if (!lhsType)
|
||||
return op.emitError("only Tensor types supported in MHLO");
|
||||
return op.emitError("only Tensor types supported in StableHLO");
|
||||
|
||||
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
|
@ -343,11 +349,12 @@ public:
|
|||
if (std::is_same<AtenOpT, AtenSquareOp>()) {
|
||||
rhs = lhs;
|
||||
} else if (!rhsType) {
|
||||
rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getOther(), outElemTy);
|
||||
rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
|
||||
outElemTy);
|
||||
}
|
||||
DenseIntElementsAttr bcastDimensions;
|
||||
lhs = mhlo::promoteType(rewriter, lhs, outType);
|
||||
rhs = mhlo::promoteType(rewriter, rhs, outType);
|
||||
lhs = hlo::promoteType(rewriter, lhs, outType);
|
||||
rhs = hlo::promoteType(rewriter, rhs, outType);
|
||||
auto loc = op.getLoc();
|
||||
Value result =
|
||||
rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);
|
||||
|
@ -368,15 +375,15 @@ public:
|
|||
if (roundingMode == "trunc") {
|
||||
// "trunc" - rounds the results of the division towards zero. Equivalent
|
||||
// to C-style integer division.
|
||||
auto sign = rewriter.create<mhlo::SignOp>(loc, result);
|
||||
auto abs = rewriter.create<mhlo::AbsOp>(loc, result);
|
||||
auto floor = rewriter.create<mhlo::FloorOp>(loc, abs);
|
||||
result = rewriter.create<mhlo::MulOp>(loc, sign, floor).getResult();
|
||||
auto sign = rewriter.create<stablehlo::SignOp>(loc, result);
|
||||
auto abs = rewriter.create<stablehlo::AbsOp>(loc, result);
|
||||
auto floor = rewriter.create<stablehlo::FloorOp>(loc, abs);
|
||||
result = rewriter.create<stablehlo::MulOp>(loc, sign, floor).getResult();
|
||||
}
|
||||
if (roundingMode == "floor") {
|
||||
// "floor" - rounds the results of the division down. Equivalent to
|
||||
// floor division in Python (the // operator)
|
||||
result = rewriter.create<mhlo::FloorOp>(loc, result).getResult();
|
||||
result = rewriter.create<stablehlo::FloorOp>(loc, result).getResult();
|
||||
}
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
|
@ -401,7 +408,7 @@ public:
|
|||
RankedTensorType rhsTy = rhs.getType().dyn_cast<RankedTensorType>();
|
||||
|
||||
if (!lhsTy)
|
||||
return op.emitError("only Tensor types supported in MHLO");
|
||||
return op.emitError("only Tensor types supported in StableHLO");
|
||||
|
||||
RankedTensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
|
@ -414,11 +421,12 @@ public:
|
|||
}
|
||||
|
||||
if (!rhsTy) {
|
||||
rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getOther(), lhsElemTy);
|
||||
rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
|
||||
lhsElemTy);
|
||||
}
|
||||
|
||||
// TODO: what is the PyTorch default type promotion?
|
||||
rhs = mhlo::promoteType(rewriter, rhs, lhsTy);
|
||||
rhs = hlo::promoteType(rewriter, rhs, lhsTy);
|
||||
|
||||
chlo::ComparisonTypeAttr compareTypeAttr;
|
||||
chlo::ComparisonDirectionAttr compareDirectionAttr;
|
||||
|
@ -485,8 +493,8 @@ public:
|
|||
TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<TensorType>();
|
||||
Value lhs = mhlo::promoteType(rewriter, adaptor.getSelf(), outType);
|
||||
Value rhs = mhlo::promoteType(rewriter, adaptor.getOther(), outType);
|
||||
Value lhs = hlo::promoteType(rewriter, adaptor.getSelf(), outType);
|
||||
Value rhs = hlo::promoteType(rewriter, adaptor.getOther(), outType);
|
||||
|
||||
DenseIntElementsAttr bcastDimensions;
|
||||
rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
|
||||
|
@ -537,8 +545,8 @@ public:
|
|||
RankedTensorType::get({static_cast<long int>(permValues.size())},
|
||||
rewriter.getI64Type()),
|
||||
permValues);
|
||||
rewriter.replaceOpWithNewOp<mhlo::TransposeOp>(op, outType, self,
|
||||
permutation);
|
||||
rewriter.replaceOpWithNewOp<stablehlo::TransposeOp>(op, outType, self,
|
||||
permutation);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -552,7 +560,7 @@ LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
|
|||
Value self = adaptor.getSelf();
|
||||
auto outType =
|
||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(op, outType, self);
|
||||
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(op, outType, self);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -573,7 +581,8 @@ LogicalResult ConvertAtenOp<AtenSizeIntOp>::matchAndRewrite(
|
|||
} else {
|
||||
Value inputRank = rewriter.create<arith::ConstantOp>(
|
||||
op.getLoc(), rewriter.getI64IntegerAttr(selfType.getRank()));
|
||||
dim = toPositiveDimDynamic(rewriter, op.getLoc(), adaptor.getDim(), inputRank);
|
||||
dim = toPositiveDimDynamic(rewriter, op.getLoc(), adaptor.getDim(),
|
||||
inputRank);
|
||||
dim = rewriter.create<arith::IndexCastOp>(op.getLoc(),
|
||||
rewriter.getIndexType(), dim);
|
||||
}
|
||||
|
@ -589,9 +598,8 @@ LogicalResult ConvertAtenOp<AtenSizeIntOp>::matchAndRewrite(
|
|||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
|
||||
AtenWhereSelfOp op,
|
||||
OpAdaptor adaptor,
|
||||
ConversionPatternRewriter& rewriter) const {
|
||||
AtenWhereSelfOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value self = adaptor.getSelf();
|
||||
Value cond = adaptor.getCondition();
|
||||
Value other = adaptor.getOther();
|
||||
|
@ -605,8 +613,7 @@ LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
|
|||
return op.emitError("failed broadcast other and condition ranks");
|
||||
|
||||
rewriter.replaceOpWithNewOp<chlo::BroadcastSelectOp>(
|
||||
op,
|
||||
getTypeConverter()->convertType(op.getType()),
|
||||
op, getTypeConverter()->convertType(op.getType()),
|
||||
ArrayRef<Value>{cond, self, other});
|
||||
return success();
|
||||
}
|
||||
|
@ -623,7 +630,7 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
|
|||
.cast<RankedTensorType>();
|
||||
|
||||
if (options.enableStaticShape && selfTy.hasStaticShape()) {
|
||||
Value bcastOp = mhlo::promoteAndBroadcast(rewriter, self, outType);
|
||||
Value bcastOp = hlo::promoteAndBroadcast(rewriter, self, outType);
|
||||
rewriter.replaceOp(op, bcastOp);
|
||||
return success();
|
||||
}
|
||||
|
@ -670,7 +677,7 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
|
|||
op->getLoc(), ValueRange{bcastShapeVec});
|
||||
auto dimensionNumbers =
|
||||
llvm::to_vector<4>(llvm::seq<int64_t>(leadingRank, totalRank));
|
||||
rewriter.replaceOpWithNewOp<mhlo::DynamicBroadcastInDimOp>(
|
||||
rewriter.replaceOpWithNewOp<stablehlo::DynamicBroadcastInDimOp>(
|
||||
op, outType, self, bcastShapeTensor,
|
||||
rewriter.getI64TensorAttr(dimensionNumbers));
|
||||
}
|
||||
|
@ -708,8 +715,8 @@ LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
|
|||
RankedTensorType::get({static_cast<long int>(permValues.size())},
|
||||
rewriter.getI64Type()),
|
||||
permValues);
|
||||
rewriter.replaceOpWithNewOp<mhlo::TransposeOp>(op, outType, self,
|
||||
permutation);
|
||||
rewriter.replaceOpWithNewOp<stablehlo::TransposeOp>(op, outType, self,
|
||||
permutation);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -721,7 +728,7 @@ LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
|
|||
Value self = adaptor.getSelf();
|
||||
auto selfTy = self.getType().cast<TensorType>();
|
||||
if (selfTy && selfTy.getElementType().isa<mlir::FloatType>()) {
|
||||
rewriter.replaceOpWithNewOp<mhlo::TanhOp>(
|
||||
rewriter.replaceOpWithNewOp<stablehlo::TanhOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), self);
|
||||
return success();
|
||||
} else {
|
||||
|
@ -751,16 +758,16 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
|
|||
elements.mapValues(builtinTensorElemTy, [&](const APInt &v) {
|
||||
return APInt(bitWidth, v.getSExtValue());
|
||||
});
|
||||
rewriter.replaceOpWithNewOp<mhlo::ConstantOp>(op, resultType, valueAttr);
|
||||
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(op, resultType,
|
||||
valueAttr);
|
||||
return success();
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<mhlo::ConstantOp>(op, resultType,
|
||||
adaptor.getValue());
|
||||
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(op, resultType,
|
||||
adaptor.getValue());
|
||||
return success();
|
||||
}
|
||||
|
||||
|
||||
// AtenReciprocalOp
|
||||
// Reciprocal(x) = Div(1, x)
|
||||
template <>
|
||||
|
@ -777,7 +784,7 @@ LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(
|
|||
}
|
||||
|
||||
Value oneTensor = chlo::getConstantLike(rewriter, op->getLoc(), 1, input);
|
||||
rewriter.replaceOpWithNewOp<mhlo::DivOp>(op, outTy, oneTensor, input);
|
||||
rewriter.replaceOpWithNewOp<stablehlo::DivOp>(op, outTy, oneTensor, input);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -790,9 +797,9 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
|
|||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
auto outputElemType = outputType.getElementType();
|
||||
Value mhloTensor =
|
||||
mhlo::scalarToMhloTensor(rewriter, op, adaptor.getA(), outputElemType);
|
||||
rewriter.replaceOp(op, mhloTensor);
|
||||
Value stablehloTensor = hlo::scalarToStablehloTensor(
|
||||
rewriter, op, adaptor.getA(), outputElemType);
|
||||
rewriter.replaceOp(op, stablehloTensor);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -815,7 +822,6 @@ LogicalResult ConvertAtenOp<AtenContiguousOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
|
||||
// AtenReluOp
|
||||
// Relu(x) = Max(0, x)
|
||||
template <>
|
||||
|
@ -836,11 +842,10 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
|
|||
APFloat::getZero(lhsElemTy.cast<mlir::FloatType>().getFloatSemantics(),
|
||||
false),
|
||||
lhs);
|
||||
rewriter.replaceOpWithNewOp<mhlo::MaxOp>(op, lhs, zeroTensor);
|
||||
rewriter.replaceOpWithNewOp<stablehlo::MaxOp>(op, lhs, zeroTensor);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
||||
// Convert a Aten::GELU to HLO
|
||||
// Gelu(x) = x * 1/2 * [1 + erf(x/(sqrt(2)))]
|
||||
template <>
|
||||
|
@ -857,12 +862,12 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
|
|||
Value one = chlo::getConstantLike(rewriter, loc, 1.0, input);
|
||||
Value two = chlo::getConstantLike(rewriter, loc, 2.0, input);
|
||||
Value half = chlo::getConstantLike(rewriter, loc, 0.5, input);
|
||||
auto rsqrtTwo = rewriter.create<mlir::mhlo::RsqrtOp>(loc, two);
|
||||
auto erfElement = rewriter.create<mhlo::MulOp>(loc, input, rsqrtTwo);
|
||||
auto rsqrtTwo = rewriter.create<mlir::stablehlo::RsqrtOp>(loc, two);
|
||||
auto erfElement = rewriter.create<stablehlo::MulOp>(loc, input, rsqrtTwo);
|
||||
auto erf = rewriter.create<mlir::chlo::ErfOp>(loc, erfElement);
|
||||
auto erfAdd = rewriter.create<mhlo::AddOp>(loc, erf, one);
|
||||
auto halfMul = rewriter.create<mhlo::MulOp>(loc, erfAdd, half);
|
||||
rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, input, halfMul);
|
||||
auto erfAdd = rewriter.create<stablehlo::AddOp>(loc, erf, one);
|
||||
auto halfMul = rewriter.create<stablehlo::MulOp>(loc, erfAdd, half);
|
||||
rewriter.replaceOpWithNewOp<stablehlo::MulOp>(op, input, halfMul);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -881,7 +886,6 @@ LogicalResult ConvertAtenOp<AtenErfOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
|
||||
// AtenBatchNormOp
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
||||
|
@ -919,28 +923,28 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
|||
Value channelShape = rewriter.create<tensor::FromElementsOp>(
|
||||
op->getLoc(), ValueRange{channelDim});
|
||||
if (failed(checkNotNone(rewriter, op, weight))) {
|
||||
weight = mhlo::getConstantOfShape(
|
||||
weight = hlo::getConstantOfShape(
|
||||
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1),
|
||||
channelShape,
|
||||
RankedTensorType::get({inputTy.getShape()[1]},
|
||||
inputTy.getElementType()));
|
||||
}
|
||||
if (failed(checkNotNone(rewriter, op, bias))) {
|
||||
bias = mhlo::getConstantOfShape(
|
||||
bias = hlo::getConstantOfShape(
|
||||
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0),
|
||||
channelShape,
|
||||
RankedTensorType::get({inputTy.getShape()[1]},
|
||||
inputTy.getElementType()));
|
||||
}
|
||||
if (failed(checkNotNone(rewriter, op, runningVar))) {
|
||||
runningVar = mhlo::getConstantOfShape(
|
||||
runningVar = hlo::getConstantOfShape(
|
||||
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1),
|
||||
channelShape,
|
||||
RankedTensorType::get({inputTy.getShape()[1]},
|
||||
inputTy.getElementType()));
|
||||
}
|
||||
if (failed(checkNotNone(rewriter, op, runningMean))) {
|
||||
runningMean = mhlo::getConstantOfShape(
|
||||
runningMean = hlo::getConstantOfShape(
|
||||
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0),
|
||||
channelShape,
|
||||
RankedTensorType::get({inputTy.getShape()[1]},
|
||||
|
@ -983,10 +987,11 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
|||
Type outputTy = getTypeConverter()->convertType(op.getType());
|
||||
Type batchMeanOrVarTy =
|
||||
RankedTensorType::get(weightTy.getShape(), inputTy.getElementType());
|
||||
auto batchNormTrainingResult = rewriter.create<mhlo::BatchNormTrainingOp>(
|
||||
op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input,
|
||||
weight, bias, rewriter.getF32FloatAttr(eps),
|
||||
rewriter.getI64IntegerAttr(1));
|
||||
auto batchNormTrainingResult =
|
||||
rewriter.create<stablehlo::BatchNormTrainingOp>(
|
||||
op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input,
|
||||
weight, bias, rewriter.getF32FloatAttr(eps),
|
||||
rewriter.getI64IntegerAttr(1));
|
||||
rewriter.replaceOp(op, batchNormTrainingResult.getResult(0));
|
||||
return success();
|
||||
} else {
|
||||
|
@ -995,10 +1000,11 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
|||
inputTy.getShape().end()};
|
||||
castShape[1] = weightTy.getShape()[0];
|
||||
auto castTy = RankedTensorType::get(castShape, inputTy.getElementType());
|
||||
// Feature counts must match among operands of mhlo::BatchNormInferenceOp.
|
||||
// Feature counts must match among operands of
|
||||
// stablehlo::BatchNormInferenceOp.
|
||||
Value inputCasted =
|
||||
rewriter.create<tensor::CastOp>(op.getLoc(), castTy, input);
|
||||
Value output = rewriter.create<mhlo::BatchNormInferenceOp>(
|
||||
Value output = rewriter.create<stablehlo::BatchNormInferenceOp>(
|
||||
op.getLoc(), inputCasted.getType(), inputCasted, weight, bias,
|
||||
runningMean, runningVar,
|
||||
// 'epsilon' must satisfy constraint: 32-bit float attribute.
|
||||
|
@ -1008,7 +1014,6 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
// AtenNativeLayerNormOp
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
||||
|
@ -1076,21 +1081,21 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
|||
}
|
||||
SmallVector<int64_t> inputFlattenShape{1, numFeatureDimSize,
|
||||
numEmbeddingDimSize};
|
||||
SmallVector<int64_t> meanOrVarMhloOutShape{numFeatureDimSize};
|
||||
SmallVector<int64_t> meanOrVarStablehloOutShape{numFeatureDimSize};
|
||||
|
||||
auto mhloBatchNormOutTy =
|
||||
auto stablehloBatchNormOutTy =
|
||||
RankedTensorType::get(inputFlattenShape, inputTy.getElementType());
|
||||
auto mhloBathNormOutMeanOrVarTy =
|
||||
RankedTensorType::get(meanOrVarMhloOutShape, inputTy.getElementType());
|
||||
auto stablehloBathNormOutMeanOrVarTy = RankedTensorType::get(
|
||||
meanOrVarStablehloOutShape, inputTy.getElementType());
|
||||
|
||||
// Reshape input
|
||||
auto mhloInput = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
op->getLoc(), mhloBatchNormOutTy, input,
|
||||
mhlo::getConstTensor(rewriter, op, llvm::ArrayRef(inputFlattenShape),
|
||||
{static_cast<int64_t>(inputFlattenShape.size())})
|
||||
auto stablehloInput = rewriter.create<stablehlo::DynamicReshapeOp>(
|
||||
op->getLoc(), stablehloBatchNormOutTy, input,
|
||||
hlo::getConstTensor(rewriter, op, llvm::ArrayRef(inputFlattenShape),
|
||||
{static_cast<int64_t>(inputFlattenShape.size())})
|
||||
.value());
|
||||
|
||||
// Generate "scale" and "offset" Value for mhlo.BatchNormTrainingOp.
|
||||
// Generate "scale" and "offset" Value for stablehlo.BatchNormTrainingOp.
|
||||
SmallVector<APFloat> zeroConstVec(
|
||||
numFeatureDimSize, APFloat::getZero(inputTy.getElementType()
|
||||
.cast<mlir::FloatType>()
|
||||
|
@ -1103,16 +1108,18 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
|||
auto oneOrZeroConstType =
|
||||
RankedTensorType::get({numFeatureDimSize}, inputTy.getElementType());
|
||||
|
||||
Value scale = rewriter.create<mhlo::ConstantOp>(
|
||||
Value scale = rewriter.create<stablehlo::ConstantOp>(
|
||||
op->getLoc(), oneOrZeroConstType,
|
||||
DenseElementsAttr::get(oneOrZeroConstType, oneConstVec));
|
||||
Value offset = rewriter.create<mhlo::ConstantOp>(
|
||||
Value offset = rewriter.create<stablehlo::ConstantOp>(
|
||||
op->getLoc(), oneOrZeroConstType,
|
||||
DenseElementsAttr::get(oneOrZeroConstType, zeroConstVec));
|
||||
auto batchNormTrainingResult = rewriter.create<mhlo::BatchNormTrainingOp>(
|
||||
op->getLoc(), mhloBatchNormOutTy, mhloBathNormOutMeanOrVarTy,
|
||||
mhloBathNormOutMeanOrVarTy, mhloInput, scale, offset,
|
||||
rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(1));
|
||||
auto batchNormTrainingResult =
|
||||
rewriter.create<stablehlo::BatchNormTrainingOp>(
|
||||
op->getLoc(), stablehloBatchNormOutTy,
|
||||
stablehloBathNormOutMeanOrVarTy, stablehloBathNormOutMeanOrVarTy,
|
||||
stablehloInput, scale, offset, rewriter.getF32FloatAttr(eps),
|
||||
rewriter.getI64IntegerAttr(1));
|
||||
|
||||
// Reshape back
|
||||
auto outputTy =
|
||||
|
@ -1120,36 +1127,35 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
|||
auto outputMeanOrVarTy =
|
||||
getTypeConverter()->convertType(op.getType(1)).cast<RankedTensorType>();
|
||||
|
||||
auto output = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
auto output = rewriter.create<stablehlo::DynamicReshapeOp>(
|
||||
op->getLoc(), outputTy, batchNormTrainingResult.getResult(0),
|
||||
mhlo::getConstTensor(rewriter, op, outputTy.getShape(),
|
||||
{static_cast<int64_t>(outputTy.getShape().size())})
|
||||
hlo::getConstTensor(rewriter, op, outputTy.getShape(),
|
||||
{static_cast<int64_t>(outputTy.getShape().size())})
|
||||
.value());
|
||||
auto mean = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
auto mean = rewriter.create<stablehlo::DynamicReshapeOp>(
|
||||
op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(1),
|
||||
mhlo::getConstTensor(
|
||||
hlo::getConstTensor(
|
||||
rewriter, op, outputMeanOrVarTy.getShape(),
|
||||
{static_cast<int64_t>(outputMeanOrVarTy.getShape().size())})
|
||||
.value());
|
||||
auto var = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
auto var = rewriter.create<stablehlo::DynamicReshapeOp>(
|
||||
op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(2),
|
||||
mhlo::getConstTensor(
|
||||
hlo::getConstTensor(
|
||||
rewriter, op, outputMeanOrVarTy.getShape(),
|
||||
{static_cast<int64_t>(outputMeanOrVarTy.getShape().size())})
|
||||
.value());
|
||||
|
||||
// Apply affine transform: output x weight + bias [element-wise]
|
||||
auto bcastedWeight = mhlo::promoteAndBroadcast(rewriter, weight, outputTy);
|
||||
auto bcastedBias = mhlo::promoteAndBroadcast(rewriter, bias, outputTy);
|
||||
auto bcastedWeight = hlo::promoteAndBroadcast(rewriter, weight, outputTy);
|
||||
auto bcastedBias = hlo::promoteAndBroadcast(rewriter, bias, outputTy);
|
||||
auto outputMulWeight =
|
||||
rewriter.create<mhlo::MulOp>(op->getLoc(), output, bcastedWeight);
|
||||
auto finalOuput =
|
||||
rewriter.create<mhlo::AddOp>(op->getLoc(), outputMulWeight, bcastedBias);
|
||||
rewriter.create<stablehlo::MulOp>(op->getLoc(), output, bcastedWeight);
|
||||
auto finalOuput = rewriter.create<stablehlo::AddOp>(
|
||||
op->getLoc(), outputMulWeight, bcastedBias);
|
||||
rewriter.replaceOp(op, {finalOuput, mean, var});
|
||||
return success();
|
||||
}
|
||||
|
||||
|
||||
// AtenCatOp
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
|
||||
|
@ -1173,11 +1179,11 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
|
|||
|
||||
// Promote type
|
||||
for (auto &v : builtinTensors) {
|
||||
v = mhlo::promoteType(rewriter, v, outType);
|
||||
v = hlo::promoteType(rewriter, v, outType);
|
||||
}
|
||||
|
||||
size_t posDim = toPositiveDim(dim, outType.getRank());
|
||||
rewriter.replaceOpWithNewOp<mhlo::ConcatenateOp>(
|
||||
rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(
|
||||
op, outType, ValueRange(builtinTensors), posDim);
|
||||
return success();
|
||||
}
|
||||
|
@ -1225,7 +1231,8 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "this op should be folded as its `min` and `max` both are none");
|
||||
} else if (failed(checkNotNone(rewriter, op, minValue))) {
|
||||
maxValue = mhlo::scalarToMhloTensor(rewriter, op, maxValue, inputElemType);
|
||||
maxValue =
|
||||
hlo::scalarToStablehloTensor(rewriter, op, maxValue, inputElemType);
|
||||
auto minInfo = getMinValueOfDtype(op, inputElemType, rewriter);
|
||||
if (failed(minInfo)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -1233,7 +1240,8 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
|
|||
}
|
||||
minValue = *minInfo;
|
||||
} else if (failed(checkNotNone(rewriter, op, maxValue))) {
|
||||
minValue = mhlo::scalarToMhloTensor(rewriter, op, minValue, inputElemType);
|
||||
minValue =
|
||||
hlo::scalarToStablehloTensor(rewriter, op, minValue, inputElemType);
|
||||
auto maxInfo = getMaxValueOfDtype(op, inputElemType, rewriter);
|
||||
if (failed(maxInfo)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -1241,10 +1249,13 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
|
|||
}
|
||||
maxValue = *maxInfo;
|
||||
} else {
|
||||
minValue = mhlo::scalarToMhloTensor(rewriter, op, minValue, inputElemType);
|
||||
maxValue = mhlo::scalarToMhloTensor(rewriter, op, maxValue, inputElemType);
|
||||
minValue =
|
||||
hlo::scalarToStablehloTensor(rewriter, op, minValue, inputElemType);
|
||||
maxValue =
|
||||
hlo::scalarToStablehloTensor(rewriter, op, maxValue, inputElemType);
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<mhlo::ClampOp>(op, minValue, input, maxValue);
|
||||
rewriter.replaceOpWithNewOp<stablehlo::ClampOp>(op, minValue, input,
|
||||
maxValue);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -1266,24 +1277,27 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
|
|||
op, "unimplemented: only int or float dtype supported");
|
||||
}
|
||||
|
||||
Value start = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getStart(), dtype);
|
||||
Value end = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getEnd(), dtype);
|
||||
Value step = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getStep(), dtype);
|
||||
Value start =
|
||||
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getStart(), dtype);
|
||||
Value end =
|
||||
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getEnd(), dtype);
|
||||
Value step =
|
||||
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getStep(), dtype);
|
||||
|
||||
// Get length of the 1-d output tensor
|
||||
Value subOut = rewriter.create<mhlo::SubtractOp>(loc, end, start);
|
||||
Value divOut = rewriter.create<mhlo::DivOp>(loc, subOut, step);
|
||||
Value subOut = rewriter.create<stablehlo::SubtractOp>(loc, end, start);
|
||||
Value divOut = rewriter.create<stablehlo::DivOp>(loc, subOut, step);
|
||||
|
||||
Value resultLength = rewriter.create<mhlo::ReshapeOp>(
|
||||
Value resultLength = rewriter.create<stablehlo::ReshapeOp>(
|
||||
loc, RankedTensorType::get({1}, dtype), divOut);
|
||||
if (dtype.isa<mlir::FloatType>()) {
|
||||
resultLength = rewriter.create<mhlo::CeilOp>(loc, resultLength);
|
||||
resultLength = rewriter.create<mhlo::ConvertOp>(
|
||||
resultLength = rewriter.create<stablehlo::CeilOp>(loc, resultLength);
|
||||
resultLength = rewriter.create<stablehlo::ConvertOp>(
|
||||
loc, RankedTensorType::get({1}, rewriter.getI64Type()), resultLength);
|
||||
}
|
||||
|
||||
Value window =
|
||||
rewriter.create<mhlo::DynamicIotaOp>(loc, outType, resultLength, 0);
|
||||
rewriter.create<stablehlo::DynamicIotaOp>(loc, outType, resultLength, 0);
|
||||
DenseIntElementsAttr broadcastDimensions;
|
||||
Value mulOut = rewriter.create<chlo::BroadcastMulOp>(loc, window, step,
|
||||
broadcastDimensions);
|
||||
|
@ -1298,9 +1312,8 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
Location loc = op.getLoc();
|
||||
Value input = adaptor.getSelf();
|
||||
auto outType = this->getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.cast<TensorType>();
|
||||
auto outType =
|
||||
this->getTypeConverter()->convertType(op.getType()).cast<TensorType>();
|
||||
if (!outType) {
|
||||
return op.emitError("only tensor type is supported");
|
||||
}
|
||||
|
@ -1320,26 +1333,27 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
|
|||
Value negHalf = chlo::getConstantLike(rewriter, loc, -0.5, input);
|
||||
|
||||
// Compute
|
||||
Value kBeta0 = rewriter.create<mhlo::MulOp>(loc, outType, kAlpha, cstAlpha0);
|
||||
Value kBeta = rewriter.create<mhlo::MulOp>(loc, outType, kBeta0, half);
|
||||
Value erfArg =
|
||||
rewriter.create<mhlo::MulOp>(loc, outType, kAlpha, adaptor.getSelf());
|
||||
Value kBeta0 =
|
||||
rewriter.create<stablehlo::MulOp>(loc, outType, kAlpha, cstAlpha0);
|
||||
Value kBeta = rewriter.create<stablehlo::MulOp>(loc, outType, kBeta0, half);
|
||||
Value erfArg = rewriter.create<stablehlo::MulOp>(loc, outType, kAlpha,
|
||||
adaptor.getSelf());
|
||||
Value erf = rewriter.create<mlir::chlo::ErfOp>(loc, outType, erfArg);
|
||||
Value erfAdd = rewriter.create<mhlo::AddOp>(loc, outType, erf, one);
|
||||
Value cdf = rewriter.create<mhlo::MulOp>(loc, outType, erfAdd, half);
|
||||
Value inputSquared = rewriter.create<mhlo::MulOp>(
|
||||
Value erfAdd = rewriter.create<stablehlo::AddOp>(loc, outType, erf, one);
|
||||
Value cdf = rewriter.create<stablehlo::MulOp>(loc, outType, erfAdd, half);
|
||||
Value inputSquared = rewriter.create<stablehlo::MulOp>(
|
||||
loc, outType, adaptor.getSelf(), adaptor.getSelf());
|
||||
Value negHalfInputSquared =
|
||||
rewriter.create<mhlo::MulOp>(loc, outType, inputSquared, negHalf);
|
||||
rewriter.create<stablehlo::MulOp>(loc, outType, inputSquared, negHalf);
|
||||
Value expRes =
|
||||
rewriter.create<mhlo::ExpOp>(loc, outType, negHalfInputSquared);
|
||||
Value pdf = rewriter.create<mhlo::MulOp>(loc, outType, kBeta, expRes);
|
||||
rewriter.create<stablehlo::ExpOp>(loc, outType, negHalfInputSquared);
|
||||
Value pdf = rewriter.create<stablehlo::MulOp>(loc, outType, kBeta, expRes);
|
||||
Value pdfTimesInput =
|
||||
rewriter.create<mhlo::MulOp>(loc, outType, pdf, adaptor.getSelf());
|
||||
rewriter.create<stablehlo::MulOp>(loc, outType, pdf, adaptor.getSelf());
|
||||
Value pdfTimesInputAddCdf =
|
||||
rewriter.create<mhlo::AddOp>(loc, outType, pdfTimesInput, cdf);
|
||||
rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, outType, adaptor.getGradOutput(),
|
||||
pdfTimesInputAddCdf);
|
||||
rewriter.create<stablehlo::AddOp>(loc, outType, pdfTimesInput, cdf);
|
||||
rewriter.replaceOpWithNewOp<stablehlo::MulOp>(
|
||||
op, outType, adaptor.getGradOutput(), pdfTimesInputAddCdf);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -1366,9 +1380,9 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
|
||||
void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToMhloOptions &options) {
|
||||
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
|
||||
target.addIllegalOp<AtenTransposeIntOp>();
|
||||
|
@ -1376,23 +1390,24 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
|
|||
target.addIllegalOp<RuntimeAssertOp>();
|
||||
patterns.add<ConvertRuntimeAssertOp>(typeConverter, context);
|
||||
|
||||
#define INSERT_UNARY_PATTERN(AtenOp, MhloOp) \
|
||||
#define INSERT_UNARY_PATTERN(AtenOp, StablehloOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenUnaryOp<AtenOp, MhloOp>>(typeConverter, context)
|
||||
INSERT_UNARY_PATTERN(AtenCloneOp, mhlo::CopyOp);
|
||||
INSERT_UNARY_PATTERN(AtenNegOp, mhlo::NegOp);
|
||||
INSERT_UNARY_PATTERN(AtenLogicalNotOp, mhlo::NotOp);
|
||||
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, mhlo::NotOp);
|
||||
patterns.add<ConvertAtenUnaryOp<AtenOp, StablehloOp>>(typeConverter, context)
|
||||
INSERT_UNARY_PATTERN(AtenCloneOp, stablehlo::ConvertOp);
|
||||
INSERT_UNARY_PATTERN(AtenNegOp, stablehlo::NegOp);
|
||||
INSERT_UNARY_PATTERN(AtenLogicalNotOp, stablehlo::NotOp);
|
||||
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, stablehlo::NotOp);
|
||||
#undef INSERT_UNARY_PATTERN
|
||||
|
||||
#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, MhloOp) \
|
||||
#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, StablehloOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, MhloOp>>(typeConverter, context)
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, mhlo::LogOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, mhlo::ExpOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, mhlo::SqrtOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, mhlo::RsqrtOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenSigmoidOp, mhlo::LogisticOp);
|
||||
patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, StablehloOp>>(typeConverter, \
|
||||
context)
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, stablehlo::LogOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, stablehlo::ExpOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, stablehlo::SqrtOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, stablehlo::RsqrtOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenSigmoidOp, stablehlo::LogisticOp);
|
||||
#undef INSERT_UNARY_FPONLY_PATTERN
|
||||
|
||||
#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
|
||||
|
@ -1482,10 +1497,10 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
|
|||
INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
|
||||
#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, MhloOp) \
|
||||
#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenBinaryBroadcastOp<AtenOp, MhloOp>>(typeConverter, \
|
||||
context)
|
||||
patterns.add<ConvertAtenBinaryBroadcastOp<AtenOp, StablehloOp>>( \
|
||||
typeConverter, context)
|
||||
INSERT_BINARY_BROADCAST_PATTERN(AtenMaximumOp, chlo::BroadcastMaxOp);
|
||||
INSERT_BINARY_BROADCAST_PATTERN(AtenMinimumOp, chlo::BroadcastMinOp);
|
||||
INSERT_BINARY_BROADCAST_PATTERN(Aten__And__TensorOp, chlo::BroadcastAndOp);
|
|
@ -0,0 +1,29 @@
|
|||
add_mlir_conversion_library(TorchMLIRTorchToStablehlo
|
||||
TorchToStablehlo.cpp
|
||||
StablehloLegalizeUtils.cpp
|
||||
Basic.cpp
|
||||
Gather.cpp
|
||||
Linear.cpp
|
||||
ViewLike.cpp
|
||||
Reduction.cpp
|
||||
Pooling.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToStablehlo
|
||||
|
||||
DEPENDS
|
||||
TorchMLIRConversionPassIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRPass
|
||||
MLIRBufferTransforms
|
||||
StablehloOps
|
||||
TorchMLIRTorchDialect
|
||||
TorchMLIRConversionUtils
|
||||
)
|
||||
|
||||
torch_mlir_target_includes(TorchMLIRTorchToStablehlo)
|
|
@ -7,14 +7,15 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "./MhloLegalizeUtils.h"
|
||||
#include "./PopulatePatterns.h"
|
||||
#include "mhlo/IR/hlo_ops.h"
|
||||
#include "PopulatePatterns.h"
|
||||
#include "StablehloLegalizeUtils.h"
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "stablehlo/dialect/StablehloOps.h"
|
||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
|
@ -24,7 +25,7 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
using namespace mlir::torch::torch_to_mhlo;
|
||||
using namespace mlir::torch::torch_to_stablehlo;
|
||||
|
||||
namespace {
|
||||
Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
|
||||
|
@ -69,7 +70,7 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
|
|||
SmallVector<int64_t, 4> startIndexMap(1, axis);
|
||||
// indexVecDim
|
||||
int64_t indexVecDim = indicesRank;
|
||||
auto dimsAttr = mhlo::GatherDimensionNumbersAttr::get(
|
||||
auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get(
|
||||
rewriter.getContext(),
|
||||
/*offsetDims=*/offsetDims,
|
||||
/*collapsedSliceDims=*/collapsedSliceDims,
|
||||
|
@ -91,17 +92,18 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
|
|||
auto outputTy =
|
||||
RankedTensorType::get(outputShape, inputRankTy.getElementType());
|
||||
return rewriter
|
||||
.create<mhlo::DynamicGatherOp>(loc, outputTy, input, indices,
|
||||
sliceSizesTensor, dimsAttr)
|
||||
.create<stablehlo::DynamicGatherOp>(loc, outputTy, input, indices,
|
||||
sliceSizesTensor, dimsAttr)
|
||||
.getResult();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html
|
||||
// Ref:
|
||||
// https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html
|
||||
// padding_idx (int, optional)
|
||||
// – If specified, the entries at padding_idx do not contribute to the gradient;
|
||||
// therefore, the embedding vector at padding_idx is not updated during training,
|
||||
// i.e. it remains as a fixed “pad”.
|
||||
// – If specified, the entries at padding_idx do not contribute to the
|
||||
// gradient; therefore, the embedding vector at padding_idx is not updated
|
||||
// during training, i.e. it remains as a fixed “pad”.
|
||||
// scale_grad_by_freq (boolean, optional)
|
||||
// – If given, this will scale gradients by the inverse of frequency of the
|
||||
// words in the mini-batch. Default False.
|
||||
|
@ -139,7 +141,7 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
|
|||
|
||||
Value output = gatherTensorAlongSingleAxis(
|
||||
rewriter, op, weight, adaptor.getIndices(), 0, options.dimSizeIndexBits);
|
||||
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
|
||||
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), output);
|
||||
|
||||
return success();
|
||||
|
@ -161,7 +163,7 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
|
|||
Value output = gatherTensorAlongSingleAxis(
|
||||
rewriter, op, self, adaptor.getIndex(), dim, options.dimSizeIndexBits);
|
||||
|
||||
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
|
||||
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), output);
|
||||
|
||||
return success();
|
||||
|
@ -200,7 +202,7 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
|
|||
|
||||
auto options = getOptions();
|
||||
auto indexShapeInfo =
|
||||
mhlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits);
|
||||
hlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits);
|
||||
if (failed(indexShapeInfo)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to get dim sizes of `index` param");
|
||||
|
@ -223,15 +225,15 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
|
|||
SmallVector<Value> toConcat;
|
||||
for (int64_t i = 0; i < inputType.getRank(); ++i) {
|
||||
if (i == dim) {
|
||||
toConcat.push_back(rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
toConcat.push_back(rewriter.create<stablehlo::DynamicReshapeOp>(
|
||||
loc, toConcatIndexType, index, toConcatIndexShape));
|
||||
} else {
|
||||
toConcat.push_back(rewriter.create<mhlo::DynamicIotaOp>(
|
||||
toConcat.push_back(rewriter.create<stablehlo::DynamicIotaOp>(
|
||||
loc, toConcatIndexType, toConcatIndexShape,
|
||||
rewriter.getI64IntegerAttr(i)));
|
||||
}
|
||||
}
|
||||
auto gatherIndicies = rewriter.create<mhlo::ConcatenateOp>(
|
||||
auto gatherIndicies = rewriter.create<stablehlo::ConcatenateOp>(
|
||||
loc, toConcat, static_cast<uint64_t>(inputType.getRank()));
|
||||
SmallVector<int64_t> sliceSizes(inputType.getRank(), 1);
|
||||
|
||||
|
@ -243,22 +245,22 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
|
|||
startIndexMap.push_back(i);
|
||||
}
|
||||
|
||||
auto dimsAttr = mhlo::GatherDimensionNumbersAttr::get(
|
||||
auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get(
|
||||
rewriter.getContext(),
|
||||
/*offsetDims=*/{},
|
||||
/*collapsedSliceDims=*/collapsedDims,
|
||||
/*startIndexMap=*/startIndexMap,
|
||||
/*indexVecDim=*/indexVecDim);
|
||||
|
||||
rewriter.replaceOpWithNewOp<mhlo::GatherOp>(
|
||||
rewriter.replaceOpWithNewOp<stablehlo::GatherOp>(
|
||||
op, input, gatherIndicies, dimsAttr,
|
||||
rewriter.getI64TensorAttr(sliceSizes));
|
||||
return success();
|
||||
}
|
||||
|
||||
void mlir::torch::torch_to_mhlo::populateGatherOpPatternsAndLegality(
|
||||
void mlir::torch::torch_to_stablehlo::populateGatherOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToMhloOptions &options) {
|
||||
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
|
||||
#define INSERT_ATENOP_PATTERN(AtenOp) \
|
|
@ -7,15 +7,16 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "./MhloLegalizeUtils.h"
|
||||
#include "./PopulatePatterns.h"
|
||||
#include "mhlo/IR/hlo_ops.h"
|
||||
#include "PopulatePatterns.h"
|
||||
#include "StablehloLegalizeUtils.h"
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "stablehlo/dialect/ChloOps.h"
|
||||
#include "stablehlo/dialect/StablehloOps.h"
|
||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
|
@ -25,7 +26,7 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
using namespace mlir::torch::torch_to_mhlo;
|
||||
using namespace mlir::torch::torch_to_stablehlo;
|
||||
|
||||
namespace {
|
||||
Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
|
||||
|
@ -33,7 +34,7 @@ Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
|
|||
ArrayRef<int64_t> broadcastDims) {
|
||||
auto tensorTy = tensor.getType().dyn_cast<RankedTensorType>();
|
||||
auto loc = op->getLoc();
|
||||
Value mhloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
|
||||
Value stablehloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
|
||||
|
||||
RankedTensorType outTy =
|
||||
RankedTensorType::get(shape, tensorTy.getElementType());
|
||||
|
@ -43,8 +44,8 @@ Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
|
|||
rewriter.getIntegerType(64));
|
||||
auto broadcastAttr = DenseIntElementsAttr::get(attrTy, broadcastDims);
|
||||
|
||||
auto broadcast = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
|
||||
loc, outTy, tensor, mhloShape, broadcastAttr);
|
||||
auto broadcast = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
|
||||
loc, outTy, tensor, stablehloShape, broadcastAttr);
|
||||
return broadcast;
|
||||
}
|
||||
|
||||
|
@ -52,7 +53,7 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input,
|
|||
ArrayRef<int64_t> inpTransDims) {
|
||||
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||
auto rank = inputTy.getRank();
|
||||
auto transDims = mhlo::toPositiveDims(inpTransDims, rank);
|
||||
auto transDims = hlo::toPositiveDims(inpTransDims, rank);
|
||||
auto inpShape = inputTy.getShape();
|
||||
std::vector<int64_t> newShape;
|
||||
newShape.reserve(rank);
|
||||
|
@ -66,8 +67,8 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input,
|
|||
auto permuteAttr = DenseIntElementsAttr::get(attrTy, transDims);
|
||||
|
||||
auto outTy = RankedTensorType::get(newShape, inputTy.getElementType());
|
||||
auto result = rewriter.create<mhlo::TransposeOp>(op->getLoc(), outTy, input,
|
||||
permuteAttr);
|
||||
auto result = rewriter.create<stablehlo::TransposeOp>(op->getLoc(), outTy,
|
||||
input, permuteAttr);
|
||||
return result.getResult();
|
||||
}
|
||||
|
||||
|
@ -119,10 +120,12 @@ RankedTensorType castContractingDim(PatternRewriter &rewriter, Operation *op,
|
|||
}
|
||||
|
||||
// set result dimensions
|
||||
if (lhsResultDim < static_cast<int64_t>(lhsShape.size()) && lhsResultDim >= 0) {
|
||||
if (lhsResultDim < static_cast<int64_t>(lhsShape.size()) &&
|
||||
lhsResultDim >= 0) {
|
||||
outShape.push_back(lhsShape[lhsResultDim]);
|
||||
}
|
||||
if (rhsResultDim < static_cast<int64_t>(rhsShape.size()) && rhsResultDim >= 0) {
|
||||
if (rhsResultDim < static_cast<int64_t>(rhsShape.size()) &&
|
||||
rhsResultDim >= 0) {
|
||||
outShape.push_back(rhsShape[rhsResultDim]);
|
||||
}
|
||||
return RankedTensorType::get(outShape, lhsTy.getElementType());
|
||||
|
@ -151,10 +154,10 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
|
|||
std::vector<int64_t> newShape(rhsShape.begin(),
|
||||
rhsShape.begin() + leadingRank);
|
||||
newShape.insert(newShape.end(), lhsShape.begin(), lhsShape.end());
|
||||
auto newDimSizes = *mhlo::getDimSizesOfTensor(
|
||||
rewriter, op, rhs, leadingDims, dimSizeIndexBits);
|
||||
auto newDimSizes = *hlo::getDimSizesOfTensor(rewriter, op, rhs, leadingDims,
|
||||
dimSizeIndexBits);
|
||||
auto lhsDimSizes =
|
||||
*mhlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits);
|
||||
*hlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits);
|
||||
newDimSizes.insert(newDimSizes.end(), lhsDimSizes.begin(),
|
||||
lhsDimSizes.end());
|
||||
lhs = getBroadcastTensor(rewriter, op, lhs, newShape, newDimSizes,
|
||||
|
@ -163,10 +166,10 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
|
|||
std::vector<int64_t> newShape(lhsShape.begin(),
|
||||
lhsShape.begin() + leadingRank);
|
||||
newShape.insert(newShape.end(), rhsShape.begin(), rhsShape.end());
|
||||
auto newDimSizes = *mhlo::getDimSizesOfTensor(
|
||||
rewriter, op, lhs, leadingDims, dimSizeIndexBits);
|
||||
auto newDimSizes = *hlo::getDimSizesOfTensor(rewriter, op, lhs, leadingDims,
|
||||
dimSizeIndexBits);
|
||||
auto rhsDimSizes =
|
||||
*mhlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits);
|
||||
*hlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits);
|
||||
newDimSizes.insert(newDimSizes.end(), rhsDimSizes.begin(),
|
||||
rhsDimSizes.end());
|
||||
rhs = getBroadcastTensor(rewriter, op, rhs, newShape, newDimSizes,
|
||||
|
@ -218,8 +221,8 @@ public:
|
|||
if (lhsRank <= 2 && rhsRank <= 2) {
|
||||
auto tensorType =
|
||||
ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType());
|
||||
output = rewriter.create<mhlo::DotOp>(op->getLoc(), tensorType, lhs, rhs,
|
||||
nullptr);
|
||||
output = rewriter.create<stablehlo::DotOp>(op->getLoc(), tensorType, lhs,
|
||||
rhs, nullptr);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -253,8 +256,8 @@ public:
|
|||
lhsContractingDim = nBatchDims;
|
||||
}
|
||||
|
||||
mhlo::DotDimensionNumbersAttr dotDimensionNumbers =
|
||||
mhlo::DotDimensionNumbersAttr::get(
|
||||
stablehlo::DotDimensionNumbersAttr dotDimensionNumbers =
|
||||
stablehlo::DotDimensionNumbersAttr::get(
|
||||
rewriter.getContext(),
|
||||
/*lhsBatchingDimensions=*/batchDims,
|
||||
/*rhsBatchingDimensions=*/batchDims,
|
||||
|
@ -264,8 +267,8 @@ public:
|
|||
castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim,
|
||||
lhsContractingDim, rhsContractingDim);
|
||||
output = rewriter
|
||||
.create<mhlo::DotGeneralOp>(op->getLoc(), outTy, lhs, rhs,
|
||||
dotDimensionNumbers, nullptr)
|
||||
.create<stablehlo::DotGeneralOp>(op->getLoc(), outTy, lhs, rhs,
|
||||
dotDimensionNumbers, nullptr)
|
||||
.getResult();
|
||||
return success();
|
||||
}
|
||||
|
@ -312,7 +315,7 @@ public:
|
|||
|
||||
if (!lhsTy || !rhsTy)
|
||||
return op.emitError(
|
||||
"only ranked tensor types are supported in MHLO matmul");
|
||||
"only ranked tensor types are supported in StableHLO matmul");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -335,7 +338,7 @@ public:
|
|||
|
||||
if (!lhsTy || !rhsTy)
|
||||
return op.emitError(
|
||||
"only ranked tensor types are supported in MHLO matmul");
|
||||
"only ranked tensor types are supported in StableHLO matmul");
|
||||
|
||||
auto lhsRank = lhsTy.getRank();
|
||||
auto rhsRank = rhsTy.getRank();
|
||||
|
@ -371,7 +374,7 @@ public:
|
|||
|
||||
if (!lhsTy || !rhsTy)
|
||||
return op.emitError(
|
||||
"only ranked tensor types are supported in MHLO matmul");
|
||||
"only ranked tensor types are supported in StableHLO matmul");
|
||||
|
||||
auto lhsRank = lhsTy.getRank();
|
||||
auto rhsRank = rhsTy.getRank();
|
||||
|
@ -398,10 +401,10 @@ public:
|
|||
auto bias = adaptor.getBias();
|
||||
auto biasTy = bias.getType();
|
||||
|
||||
// MHLO does not mandate that elementwise op tensors need to be ranked.
|
||||
// StableHLO does not mandate that elementwise op tensors need to be ranked.
|
||||
if (!biasTy.template isa<Torch::NoneType>() &&
|
||||
!biasTy.template isa<RankedTensorType>())
|
||||
return op.emitError("only ranked tensor types are supported in MHLO "
|
||||
return op.emitError("only ranked tensor types are supported in StableHLO "
|
||||
"matmul for bias tensor");
|
||||
|
||||
// weight.T
|
||||
|
@ -427,14 +430,14 @@ public:
|
|||
auto outTy =
|
||||
castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim,
|
||||
lhsContractingDim, rhsContractingDim);
|
||||
mhlo::DotDimensionNumbersAttr dotDimensionNumbers =
|
||||
mhlo::DotDimensionNumbersAttr::get(
|
||||
stablehlo::DotDimensionNumbersAttr dotDimensionNumbers =
|
||||
stablehlo::DotDimensionNumbersAttr::get(
|
||||
rewriter.getContext(),
|
||||
/*lhsBatchingDimensions=*/batchDims,
|
||||
/*rhsBatchingDimensions=*/batchDims,
|
||||
/*lhsContractingDimensions=*/{lhsContractingDim},
|
||||
/*rhsContractingDimensions=*/{rhsContractingDim});
|
||||
Value matmulOutput = rewriter.create<mhlo::DotGeneralOp>(
|
||||
Value matmulOutput = rewriter.create<stablehlo::DotGeneralOp>(
|
||||
op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr);
|
||||
|
||||
Value matmulPlusBias = matmulOutput;
|
||||
|
@ -464,7 +467,7 @@ public:
|
|||
auto weightElemTy = weightTy.getElementType();
|
||||
auto rank = weightTy.getRank();
|
||||
const auto &options = getOptions();
|
||||
SmallVector<Value> weightShapeVec = *mhlo::getDimSizesOfTensor(
|
||||
SmallVector<Value> weightShapeVec = *hlo::getDimSizesOfTensor(
|
||||
rewriter, op, weight, options.dimSizeIndexBits);
|
||||
auto weightShape = weightTy.getShape();
|
||||
SmallVector<int64_t> weightShapeInt(rank);
|
||||
|
@ -488,7 +491,7 @@ public:
|
|||
}
|
||||
Value weightShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
||||
op->getLoc(), weightShapeVec);
|
||||
weight = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
weight = rewriter.create<stablehlo::DynamicReshapeOp>(
|
||||
op->getLoc(), RankedTensorType::get(weightShapeInt, weightElemTy),
|
||||
weight, weightShapeTensor);
|
||||
|
||||
|
@ -497,7 +500,7 @@ public:
|
|||
for (int64_t i = 0; i <= rank; i++)
|
||||
transposeDims[i] = i;
|
||||
std::swap(transposeDims[1], transposeDims[0]);
|
||||
weight = rewriter.create<mhlo::TransposeOp>(
|
||||
weight = rewriter.create<stablehlo::TransposeOp>(
|
||||
op->getLoc(), weight, rewriter.getI64TensorAttr(transposeDims));
|
||||
|
||||
// 3. [IC//G, G, OC, H, W, ...] => [IC//G, G*OC, H, W, ...]
|
||||
|
@ -509,7 +512,7 @@ public:
|
|||
weightShapeVec[1] = OCMulGValue;
|
||||
weightShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
||||
op->getLoc(), weightShapeVec);
|
||||
weight = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
weight = rewriter.create<stablehlo::DynamicReshapeOp>(
|
||||
op->getLoc(), RankedTensorType::get(weightShapeInt, weightElemTy),
|
||||
weight, weightShapeTensor);
|
||||
return weight;
|
||||
|
@ -544,25 +547,27 @@ public:
|
|||
}
|
||||
|
||||
// Prepare for transposed convolution
|
||||
SmallVector<int64_t> mhloStrideVec(nSpatialDims, 1);
|
||||
DenseIntElementsAttr mhloStride = rewriter.getI64TensorAttr(mhloStrideVec);
|
||||
SmallVector<int64_t> mhloPaddingVec(nSpatialDims * 2, 0);
|
||||
SmallVector<int64_t> stablehloStrideVec(nSpatialDims, 1);
|
||||
DenseIntElementsAttr stablehloStride =
|
||||
rewriter.getI64TensorAttr(stablehloStrideVec);
|
||||
SmallVector<int64_t> stablehloPaddingVec(nSpatialDims * 2, 0);
|
||||
for (int i = 0; i < nSpatialDims; ++i) {
|
||||
int64_t padInt = dilation[i] * (weightShape[i + 2] - 1) - padding[i];
|
||||
mhloPaddingVec[i * 2] = padInt;
|
||||
mhloPaddingVec[i * 2 + 1] = padInt;
|
||||
stablehloPaddingVec[i * 2] = padInt;
|
||||
stablehloPaddingVec[i * 2 + 1] = padInt;
|
||||
}
|
||||
DenseIntElementsAttr mhloPadding = DenseIntElementsAttr::get(
|
||||
DenseIntElementsAttr stablehloPadding = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({nSpatialDims, 2}, rewriter.getI64Type()),
|
||||
mhloPaddingVec);
|
||||
SmallVector<int64_t> mhloLhsDilationVec(nSpatialDims);
|
||||
std::copy(stride.begin(), stride.end(), mhloLhsDilationVec.begin());
|
||||
DenseIntElementsAttr mhloLhsDilation =
|
||||
rewriter.getI64TensorAttr(mhloLhsDilationVec);
|
||||
SmallVector<int64_t> mhloRhsDilationVec(nSpatialDims);
|
||||
std::copy(dilation.begin(), dilation.end(), mhloRhsDilationVec.begin());
|
||||
DenseIntElementsAttr mhloRhsDilation =
|
||||
rewriter.getI64TensorAttr(mhloRhsDilationVec);
|
||||
stablehloPaddingVec);
|
||||
SmallVector<int64_t> stablehloLhsDilationVec(nSpatialDims);
|
||||
std::copy(stride.begin(), stride.end(), stablehloLhsDilationVec.begin());
|
||||
DenseIntElementsAttr stablehloLhsDilation =
|
||||
rewriter.getI64TensorAttr(stablehloLhsDilationVec);
|
||||
SmallVector<int64_t> stablehloRhsDilationVec(nSpatialDims);
|
||||
std::copy(dilation.begin(), dilation.end(),
|
||||
stablehloRhsDilationVec.begin());
|
||||
DenseIntElementsAttr stablehloRhsDilation =
|
||||
rewriter.getI64TensorAttr(stablehloRhsDilationVec);
|
||||
|
||||
DenseElementsAttr windowReversal;
|
||||
ArrayAttr precisionConfig;
|
||||
|
@ -571,8 +576,8 @@ public:
|
|||
for (int i = 0; i < nSpatialDims; ++i) {
|
||||
spatialDims.push_back(i + 2);
|
||||
}
|
||||
mhlo::ConvDimensionNumbersAttr dimensionNumbers =
|
||||
mhlo::ConvDimensionNumbersAttr::get(
|
||||
stablehlo::ConvDimensionNumbersAttr dimensionNumbers =
|
||||
stablehlo::ConvDimensionNumbersAttr::get(
|
||||
/*context=*/rewriter.getContext(), /*inputBatchDimension=*/0,
|
||||
/*inputFeatureDimension=*/1,
|
||||
/*inputSpatialDimensions=*/spatialDims,
|
||||
|
@ -583,17 +588,18 @@ public:
|
|||
/*outputSpatialDimensions=*/spatialDims);
|
||||
|
||||
// Reverse and transpose weight
|
||||
weight = rewriter.create<mhlo::ReverseOp>(
|
||||
weight = rewriter.create<stablehlo::ReverseOp>(
|
||||
op->getLoc(), weight, rewriter.getI64TensorAttr(spatialDims));
|
||||
if (groups != 1) {
|
||||
weight = reshapeConvWeight(rewriter, op, weight, groups);
|
||||
}
|
||||
|
||||
// Create transposed convolution
|
||||
auto transposedConvOp = rewriter.create<mhlo::ConvolutionOp>(
|
||||
op->getLoc(), convOutTy, input, weight, mhloStride, mhloPadding,
|
||||
mhloLhsDilation, mhloRhsDilation, windowReversal, dimensionNumbers,
|
||||
static_cast<uint64_t>(groups), 1, precisionConfig);
|
||||
auto transposedConvOp = rewriter.create<stablehlo::ConvolutionOp>(
|
||||
op->getLoc(), convOutTy, input, weight, stablehloStride,
|
||||
stablehloPadding, stablehloLhsDilation, stablehloRhsDilation,
|
||||
windowReversal, dimensionNumbers, static_cast<uint64_t>(groups), 1,
|
||||
precisionConfig);
|
||||
|
||||
// Handle output padding
|
||||
if (!needHandleOutputPadding) {
|
||||
|
@ -605,8 +611,8 @@ public:
|
|||
std::copy(outputPadding.begin(), outputPadding.end(),
|
||||
edgePaddingHighVec.begin() + 2);
|
||||
Value paddingValue =
|
||||
mhlo::getConstTensor<float>(rewriter, op, {0.0}, {}).value();
|
||||
paddingValue = mhlo::promoteType(rewriter, paddingValue, inputTy);
|
||||
hlo::getConstTensor<float>(rewriter, op, {0.0}, {}).value();
|
||||
paddingValue = hlo::promoteType(rewriter, paddingValue, inputTy);
|
||||
mlir::DenseIntElementsAttr edgePaddingLow =
|
||||
rewriter.getI64VectorAttr(edgePaddingLowVec);
|
||||
mlir::DenseIntElementsAttr edgePaddingHigh =
|
||||
|
@ -614,7 +620,7 @@ public:
|
|||
mlir::DenseIntElementsAttr interiorPadding =
|
||||
rewriter.getI64VectorAttr(interiorPaddingVec);
|
||||
|
||||
auto paddedOutput = rewriter.create<mhlo::PadOp>(
|
||||
auto paddedOutput = rewriter.create<stablehlo::PadOp>(
|
||||
op->getLoc(), outType, transposedConvOp, paddingValue, edgePaddingLow,
|
||||
edgePaddingHigh, interiorPadding);
|
||||
|
||||
|
@ -628,22 +634,22 @@ public:
|
|||
ArrayRef<int64_t> dilation, int64_t groups) const {
|
||||
int64_t nDims = outType.getRank();
|
||||
|
||||
// Get mhlo::ConvolutionOp attributes
|
||||
DenseIntElementsAttr mhloWindowStride = DenseIntElementsAttr::get(
|
||||
// Get stablehlo::ConvolutionOp attributes
|
||||
DenseIntElementsAttr stablehloWindowStride = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<long int>(stride.size())},
|
||||
rewriter.getI64Type()),
|
||||
stride);
|
||||
std::vector<int64_t> mhloPaddingVec;
|
||||
std::vector<int64_t> stablehloPaddingVec;
|
||||
for (size_t i = 0; i < padding.size(); i++) {
|
||||
mhloPaddingVec.emplace_back(padding[i]);
|
||||
mhloPaddingVec.emplace_back(padding[i]);
|
||||
stablehloPaddingVec.emplace_back(padding[i]);
|
||||
stablehloPaddingVec.emplace_back(padding[i]);
|
||||
}
|
||||
DenseIntElementsAttr mhloPadding = DenseIntElementsAttr::get(
|
||||
DenseIntElementsAttr stablehloPadding = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(
|
||||
{static_cast<long int>(padding.size()), static_cast<long int>(2)},
|
||||
rewriter.getI64Type()),
|
||||
mhloPaddingVec);
|
||||
DenseIntElementsAttr mhloRhsDilation = DenseIntElementsAttr::get(
|
||||
stablehloPaddingVec);
|
||||
DenseIntElementsAttr stablehloRhsDilation = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<long int>(dilation.size())},
|
||||
rewriter.getI64Type()),
|
||||
dilation);
|
||||
|
@ -651,8 +657,8 @@ public:
|
|||
for (int64_t i = 2; i < nDims; i++) {
|
||||
spatialDimensions.emplace_back(i);
|
||||
}
|
||||
mhlo::ConvDimensionNumbersAttr dimensionNumbers =
|
||||
mhlo::ConvDimensionNumbersAttr::get(
|
||||
stablehlo::ConvDimensionNumbersAttr dimensionNumbers =
|
||||
stablehlo::ConvDimensionNumbersAttr::get(
|
||||
/*context=*/rewriter.getContext(), /*inputBatchDimension=*/0,
|
||||
/*inputFeatureDimension=*/1,
|
||||
/*inputSpatialDimensions=*/spatialDimensions,
|
||||
|
@ -662,17 +668,18 @@ public:
|
|||
/*outputBatchDimension=*/0, /*outputFeatureDimension=*/1,
|
||||
/*outputSpatialDimensions=*/spatialDimensions);
|
||||
|
||||
// mhlo::ConvolutionOp's optional attributes, leave them as default
|
||||
DenseIntElementsAttr mhloLhsDilation;
|
||||
// stablehlo::ConvolutionOp's optional attributes, leave them as default
|
||||
DenseIntElementsAttr stablehloLhsDilation;
|
||||
DenseElementsAttr windowReversal;
|
||||
ArrayAttr precisionConfig;
|
||||
|
||||
auto mhloConvOp = rewriter.create<mhlo::ConvolutionOp>(
|
||||
op->getLoc(), outType, input, weight, mhloWindowStride, mhloPadding,
|
||||
mhloLhsDilation, mhloRhsDilation, windowReversal, dimensionNumbers,
|
||||
static_cast<uint64_t>(groups), 1, precisionConfig);
|
||||
auto stablehloConvOp = rewriter.create<stablehlo::ConvolutionOp>(
|
||||
op->getLoc(), outType, input, weight, stablehloWindowStride,
|
||||
stablehloPadding, stablehloLhsDilation, stablehloRhsDilation,
|
||||
windowReversal, dimensionNumbers, static_cast<uint64_t>(groups), 1,
|
||||
precisionConfig);
|
||||
|
||||
return mhloConvOp.getResult();
|
||||
return stablehloConvOp.getResult();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
|
@ -754,21 +761,22 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
Value mhloConvResult;
|
||||
Value stablehloConvResult;
|
||||
if (transposed) {
|
||||
mhloConvResult = convertTransposedConv(
|
||||
stablehloConvResult = convertTransposedConv(
|
||||
op, rewriter, outTy, input, weight, stride, padding, dilation,
|
||||
outputPadding, groups, needHandleOutputPadding);
|
||||
} else {
|
||||
mhloConvResult = convertNormalConv(op, rewriter, outTy, input, weight,
|
||||
stride, padding, dilation, groups);
|
||||
stablehloConvResult =
|
||||
convertNormalConv(op, rewriter, outTy, input, weight, stride, padding,
|
||||
dilation, groups);
|
||||
}
|
||||
|
||||
auto bias = adaptor.getBias();
|
||||
|
||||
// No bias provided
|
||||
if (failed(checkNotNone(rewriter, op, op.getBias()))) {
|
||||
rewriter.replaceOp(op, mhloConvResult);
|
||||
rewriter.replaceOp(op, stablehloConvResult);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -790,21 +798,21 @@ public:
|
|||
llvm::to_vector<4>(llvm::seq<int64_t>(-nSpatialDims, 0));
|
||||
|
||||
const auto &options = getOptions();
|
||||
bias = *mhlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims,
|
||||
options.dimSizeIndexBits);
|
||||
bias = mhlo::promoteType(rewriter, bias, outTy);
|
||||
bias = *hlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims,
|
||||
options.dimSizeIndexBits);
|
||||
bias = hlo::promoteType(rewriter, bias, outTy);
|
||||
|
||||
DenseIntElementsAttr bcastDimensions;
|
||||
rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(op, outTy, mhloConvResult,
|
||||
bias, bcastDimensions);
|
||||
rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
|
||||
op, outTy, stablehloConvResult, bias, bcastDimensions);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::torch::torch_to_mhlo::populateLinearOpPatternsAndLegality(
|
||||
void mlir::torch::torch_to_stablehlo::populateLinearOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToMhloOptions &options) {
|
||||
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
|
||||
#define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \
|
|
@ -7,15 +7,16 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "./MhloLegalizeUtils.h"
|
||||
#include "./PopulatePatterns.h"
|
||||
#include "mhlo/IR/hlo_ops.h"
|
||||
#include "PopulatePatterns.h"
|
||||
#include "StablehloLegalizeUtils.h"
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "stablehlo/dialect/ChloOps.h"
|
||||
#include "stablehlo/dialect/StablehloOps.h"
|
||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
|
@ -28,7 +29,7 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
using namespace mlir::torch::torch_to_mhlo;
|
||||
using namespace mlir::torch::torch_to_stablehlo;
|
||||
|
||||
static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
|
||||
PatternRewriter &rewriter) {
|
||||
|
@ -40,14 +41,14 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
|
|||
constType, {APFloat::getZero(
|
||||
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
|
||||
/*negative=*/false)});
|
||||
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
} else if (elementTy.isa<mlir::IntegerType>() &&
|
||||
elementTy.getIntOrFloatBitWidth() != 8) {
|
||||
auto constAttr = DenseElementsAttr::get(
|
||||
constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
|
||||
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -58,15 +59,15 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
|
|||
constType, {APFloat::getLargest(
|
||||
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
|
||||
/*negative=*/true)});
|
||||
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
} else if (elementTy.isa<mlir::IntegerType>() &&
|
||||
elementTy.getIntOrFloatBitWidth() != 8) {
|
||||
auto constAttr = DenseElementsAttr::get(
|
||||
constType,
|
||||
{APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())});
|
||||
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
}
|
||||
}
|
||||
op->emitError("unimplemented lowering in AtenPoolingOp");
|
||||
|
@ -116,42 +117,43 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dOp>::matchAndRewrite(
|
|||
|
||||
// prepend 1 to kernelSize, stride, dilation until they are of same rank as
|
||||
// input
|
||||
SmallVector<int64_t> mhloStride(inputRank, 1);
|
||||
SmallVector<int64_t> mhloDilation(inputRank, 1);
|
||||
SmallVector<int64_t> mhloKernelSize(inputRank, 1);
|
||||
SmallVector<int64_t> mhloPadding(inputRank * 2, 0);
|
||||
SmallVector<int64_t> stablehloStride(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloDilation(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
|
||||
std::copy(dilation.begin(), dilation.end(),
|
||||
mhloDilation.begin() + inputRank - 2);
|
||||
std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2);
|
||||
stablehloDilation.begin() + inputRank - 2);
|
||||
std::copy(stride.begin(), stride.end(),
|
||||
stablehloStride.begin() + inputRank - 2);
|
||||
std::copy(kernelSize.begin(), kernelSize.end(),
|
||||
mhloKernelSize.begin() + inputRank - 2);
|
||||
stablehloKernelSize.begin() + inputRank - 2);
|
||||
|
||||
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
|
||||
|
||||
mhloPadding[mhloPadding.size() - 4] = padding[0];
|
||||
mhloPadding[mhloPadding.size() - 3] = padding[0];
|
||||
mhloPadding[mhloPadding.size() - 2] = padding[1];
|
||||
mhloPadding[mhloPadding.size() - 1] = padding[1];
|
||||
stablehloPadding[stablehloPadding.size() - 4] = padding[0];
|
||||
stablehloPadding[stablehloPadding.size() - 3] = padding[0];
|
||||
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
|
||||
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
|
||||
|
||||
DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())},
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
|
||||
rewriter.getI64Type()),
|
||||
mhloKernelSize);
|
||||
stablehloKernelSize);
|
||||
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(mhloStride.size())},
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
|
||||
rewriter.getI64Type()),
|
||||
mhloStride);
|
||||
stablehloStride);
|
||||
DenseIntElementsAttr baseDilations;
|
||||
DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())},
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
|
||||
rewriter.getI64Type()),
|
||||
mhloDilation);
|
||||
stablehloDilation);
|
||||
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(
|
||||
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
|
||||
rewriter.getI64Type()),
|
||||
mhloPadding);
|
||||
auto reduceWindowOp = rewriter.create<mhlo::ReduceWindowOp>(
|
||||
stablehloPadding);
|
||||
auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
|
||||
op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
|
||||
baseDilations, windowDilations, pad);
|
||||
|
||||
|
@ -168,8 +170,8 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dOp>::matchAndRewrite(
|
|||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(&block);
|
||||
Value result =
|
||||
rewriter.create<mhlo::MaxOp>(op->getLoc(), *firstArg, *secondArg);
|
||||
rewriter.create<mhlo::ReturnOp>(op->getLoc(), result);
|
||||
rewriter.create<stablehlo::MaxOp>(op->getLoc(), *firstArg, *secondArg);
|
||||
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), result);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, reduceWindowOp.getResults());
|
||||
|
@ -221,45 +223,46 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
|||
|
||||
// prepend 1 to kernelSize, stride, dilation until they are of same rank as
|
||||
// input
|
||||
SmallVector<int64_t> mhloStride(inputRank, 1);
|
||||
SmallVector<int64_t> mhloDilation(inputRank, 1);
|
||||
SmallVector<int64_t> mhloKernelSize(inputRank, 1);
|
||||
SmallVector<int64_t> mhloPadding(inputRank * 2, 0);
|
||||
SmallVector<int64_t> stablehloStride(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloDilation(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
|
||||
std::copy(dilation.begin(), dilation.end(),
|
||||
mhloDilation.begin() + inputRank - 2);
|
||||
std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2);
|
||||
stablehloDilation.begin() + inputRank - 2);
|
||||
std::copy(stride.begin(), stride.end(),
|
||||
stablehloStride.begin() + inputRank - 2);
|
||||
std::copy(kernelSize.begin(), kernelSize.end(),
|
||||
mhloKernelSize.begin() + inputRank - 2);
|
||||
stablehloKernelSize.begin() + inputRank - 2);
|
||||
|
||||
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
|
||||
|
||||
mhloPadding[mhloPadding.size() - 4] = padding[0];
|
||||
mhloPadding[mhloPadding.size() - 3] = padding[0];
|
||||
mhloPadding[mhloPadding.size() - 2] = padding[1];
|
||||
mhloPadding[mhloPadding.size() - 1] = padding[1];
|
||||
stablehloPadding[stablehloPadding.size() - 4] = padding[0];
|
||||
stablehloPadding[stablehloPadding.size() - 3] = padding[0];
|
||||
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
|
||||
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
|
||||
|
||||
DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())},
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
|
||||
rewriter.getI64Type()),
|
||||
mhloKernelSize);
|
||||
stablehloKernelSize);
|
||||
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(mhloStride.size())},
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
|
||||
rewriter.getI64Type()),
|
||||
mhloStride);
|
||||
stablehloStride);
|
||||
DenseIntElementsAttr baseDilations;
|
||||
DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())},
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
|
||||
rewriter.getI64Type()),
|
||||
mhloDilation);
|
||||
stablehloDilation);
|
||||
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(
|
||||
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
|
||||
rewriter.getI64Type()),
|
||||
mhloPadding);
|
||||
stablehloPadding);
|
||||
|
||||
const auto &options = getOptions();
|
||||
auto inputShapeInfo =
|
||||
mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
||||
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
||||
if (failed(inputShapeInfo)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to get dimension sizes of the input");
|
||||
|
@ -289,7 +292,7 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
|||
|
||||
auto initIndexTensor =
|
||||
rewriter
|
||||
.create<mhlo::DynamicIotaOp>(
|
||||
.create<stablehlo::DynamicIotaOp>(
|
||||
op->getLoc(),
|
||||
RankedTensorType::get(initIndexShapeForType,
|
||||
rewriter.getI64Type()),
|
||||
|
@ -298,15 +301,15 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
|||
|
||||
auto indexTensor =
|
||||
rewriter
|
||||
.create<mhlo::DynamicReshapeOp>(
|
||||
.create<stablehlo::DynamicReshapeOp>(
|
||||
op->getLoc(),
|
||||
RankedTensorType::get(inputShape, rewriter.getI64Type()),
|
||||
initIndexTensor, inputShapeTensor)
|
||||
.getResult();
|
||||
|
||||
Value initIdx = mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
|
||||
Value initIdx = hlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
|
||||
|
||||
auto reduceWindowOp = rewriter.create<mhlo::ReduceWindowOp>(
|
||||
auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
|
||||
op->getLoc(), mlir::TypeRange{outValTy, outIdxTy},
|
||||
mlir::ValueRange{input, indexTensor}, mlir::ValueRange{initVal, initIdx},
|
||||
windowDimensions, windowStrides, baseDilations, windowDilations, pad);
|
||||
|
@ -326,43 +329,43 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
|||
auto *secondValArg = std::next(firstIdxArg);
|
||||
auto *secondIdxArg = std::next(secondValArg);
|
||||
|
||||
mhlo::ComparisonTypeAttr compareTypeAttr;
|
||||
stablehlo::ComparisonTypeAttr compareTypeAttr;
|
||||
if (inputTy.getElementType().isa<mlir::FloatType>()) {
|
||||
compareTypeAttr = mhlo::ComparisonTypeAttr::get(
|
||||
rewriter.getContext(), mhlo::ComparisonType::FLOAT);
|
||||
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
|
||||
rewriter.getContext(), stablehlo::ComparisonType::FLOAT);
|
||||
} else if (inputTy.getElementType().isa<mlir::IntegerType>()) {
|
||||
compareTypeAttr = mhlo::ComparisonTypeAttr::get(
|
||||
rewriter.getContext(), mhlo::ComparisonType::SIGNED);
|
||||
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
|
||||
rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
|
||||
}
|
||||
mhlo::ComparisonDirectionAttr compareGeDirectionAttr =
|
||||
mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
|
||||
mhlo::ComparisonDirection::GE);
|
||||
mhlo::ComparisonDirectionAttr compareEqDirectionAttr =
|
||||
mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
|
||||
mhlo::ComparisonDirection::EQ);
|
||||
stablehlo::ComparisonDirectionAttr compareGeDirectionAttr =
|
||||
stablehlo::ComparisonDirectionAttr::get(
|
||||
rewriter.getContext(), stablehlo::ComparisonDirection::GE);
|
||||
stablehlo::ComparisonDirectionAttr compareEqDirectionAttr =
|
||||
stablehlo::ComparisonDirectionAttr::get(
|
||||
rewriter.getContext(), stablehlo::ComparisonDirection::EQ);
|
||||
|
||||
{
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(&block);
|
||||
|
||||
Value compareGeResult = rewriter.create<mhlo::CompareOp>(
|
||||
Value compareGeResult = rewriter.create<stablehlo::CompareOp>(
|
||||
op->getLoc(), compareResultType, *firstValArg, *secondValArg,
|
||||
compareGeDirectionAttr, compareTypeAttr);
|
||||
Value retValResult = rewriter.create<mhlo::SelectOp>(
|
||||
Value retValResult = rewriter.create<stablehlo::SelectOp>(
|
||||
op->getLoc(), compareGeResult, *firstValArg, *secondValArg);
|
||||
|
||||
// Get smaller index if compared values are equal.
|
||||
Value compareEqResult = rewriter.create<mhlo::CompareOp>(
|
||||
Value compareEqResult = rewriter.create<stablehlo::CompareOp>(
|
||||
op->getLoc(), compareResultType, *firstValArg, *secondValArg,
|
||||
compareEqDirectionAttr, compareTypeAttr);
|
||||
Value minIdx =
|
||||
rewriter.create<mhlo::MinOp>(op->getLoc(), *firstIdxArg, *secondIdxArg);
|
||||
Value idxWithGeVal = rewriter.create<mhlo::SelectOp>(
|
||||
Value minIdx = rewriter.create<stablehlo::MinOp>(op->getLoc(), *firstIdxArg,
|
||||
*secondIdxArg);
|
||||
Value idxWithGeVal = rewriter.create<stablehlo::SelectOp>(
|
||||
op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg);
|
||||
Value retIdxResult = rewriter.create<mhlo::SelectOp>(
|
||||
Value retIdxResult = rewriter.create<stablehlo::SelectOp>(
|
||||
op->getLoc(), compareEqResult, minIdx, idxWithGeVal);
|
||||
|
||||
rewriter.create<mhlo::ReturnOp>(
|
||||
rewriter.create<stablehlo::ReturnOp>(
|
||||
op->getLoc(), mlir::ValueRange{retValResult, retIdxResult});
|
||||
}
|
||||
|
||||
|
@ -419,41 +422,42 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
|
|||
|
||||
// prepend 1 to kernelSize, stride, dilation until they are of same rank as
|
||||
// input
|
||||
SmallVector<int64_t> mhloStride(inputRank, 1);
|
||||
SmallVector<int64_t> mhloDilation(inputRank, 1);
|
||||
SmallVector<int64_t> mhloKernelSize(inputRank, 1);
|
||||
SmallVector<int64_t> mhloPadding(inputRank * 2, 0);
|
||||
SmallVector<int64_t> stablehloStride(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloDilation(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
|
||||
|
||||
std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2);
|
||||
std::copy(stride.begin(), stride.end(),
|
||||
stablehloStride.begin() + inputRank - 2);
|
||||
std::copy(kernelSize.begin(), kernelSize.end(),
|
||||
mhloKernelSize.begin() + inputRank - 2);
|
||||
mhloPadding[mhloPadding.size() - 4] = padding[0];
|
||||
mhloPadding[mhloPadding.size() - 3] = padding[0];
|
||||
mhloPadding[mhloPadding.size() - 2] = padding[1];
|
||||
mhloPadding[mhloPadding.size() - 1] = padding[1];
|
||||
stablehloKernelSize.begin() + inputRank - 2);
|
||||
stablehloPadding[stablehloPadding.size() - 4] = padding[0];
|
||||
stablehloPadding[stablehloPadding.size() - 3] = padding[0];
|
||||
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
|
||||
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
|
||||
|
||||
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
|
||||
|
||||
DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())},
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
|
||||
rewriter.getI64Type()),
|
||||
mhloKernelSize);
|
||||
stablehloKernelSize);
|
||||
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(mhloStride.size())},
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
|
||||
rewriter.getI64Type()),
|
||||
mhloStride);
|
||||
stablehloStride);
|
||||
DenseIntElementsAttr baseDilations;
|
||||
DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())},
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
|
||||
rewriter.getI64Type()),
|
||||
mhloDilation);
|
||||
stablehloDilation);
|
||||
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(
|
||||
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
|
||||
rewriter.getI64Type()),
|
||||
mhloPadding);
|
||||
stablehloPadding);
|
||||
|
||||
auto reduceWindowSum = rewriter.create<mhlo::ReduceWindowOp>(
|
||||
auto reduceWindowSum = rewriter.create<stablehlo::ReduceWindowOp>(
|
||||
op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
|
||||
baseDilations, windowDilations, pad);
|
||||
|
||||
|
@ -471,39 +475,39 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
|
|||
rewriter.setInsertionPointToStart(&sumBlock);
|
||||
|
||||
Value sumResult =
|
||||
rewriter.create<mhlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
|
||||
rewriter.create<mhlo::ReturnOp>(op->getLoc(), sumResult);
|
||||
rewriter.create<stablehlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
|
||||
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult);
|
||||
}
|
||||
|
||||
// Use kernel size as the divisor
|
||||
if (countIncludePad) {
|
||||
Value divisor = mhlo::getConstTensor<int64_t>(
|
||||
Value divisor = hlo::getConstTensor<int64_t>(
|
||||
rewriter, op, {kernelSize[0] * kernelSize[1]}, {})
|
||||
.value();
|
||||
divisor = mhlo::promoteType(rewriter, divisor, outTy);
|
||||
divisor = hlo::promoteType(rewriter, divisor, outTy);
|
||||
DenseIntElementsAttr bcastDimensions;
|
||||
rewriter.replaceOpWithNewOp<mlir::chlo::BroadcastDivOp>(
|
||||
op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Use another mhlo.ReduceWindowOp to get the divisor
|
||||
// Use another stablehlo.ReduceWindowOp to get the divisor
|
||||
Value windowSizeConst =
|
||||
mhlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
|
||||
windowSizeConst = mhlo::promoteType(rewriter, windowSizeConst, outTy);
|
||||
hlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
|
||||
windowSizeConst = hlo::promoteType(rewriter, windowSizeConst, outTy);
|
||||
const auto &options = getOptions();
|
||||
auto inputShapeVec =
|
||||
*mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
||||
*hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
||||
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
||||
op->getLoc(), inputShapeVec);
|
||||
|
||||
windowSizeConst = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
|
||||
windowSizeConst = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
|
||||
op->getLoc(),
|
||||
RankedTensorType::get(inputTy.getShape(), outTy.getElementType()),
|
||||
windowSizeConst, inputShapeTensor, rewriter.getI64TensorAttr({}));
|
||||
|
||||
Value zero = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
|
||||
auto reduceWindowSize = rewriter.create<mhlo::ReduceWindowOp>(
|
||||
auto reduceWindowSize = rewriter.create<stablehlo::ReduceWindowOp>(
|
||||
op->getLoc(), RankedTensorType::get(outShape, inputElemTy),
|
||||
windowSizeConst, zero, windowDimensions, windowStrides, baseDilations,
|
||||
windowDilations, pad);
|
||||
|
@ -522,11 +526,11 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
|
|||
rewriter.setInsertionPointToStart(&sizeBlock);
|
||||
|
||||
Value sumResult =
|
||||
rewriter.create<mhlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
|
||||
rewriter.create<mhlo::ReturnOp>(op->getLoc(), sumResult);
|
||||
rewriter.create<stablehlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
|
||||
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult);
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<mhlo::DivOp>(
|
||||
rewriter.replaceOpWithNewOp<stablehlo::DivOp>(
|
||||
op, outTy, reduceWindowSum.getResult(0), reduceWindowSize.getResult(0));
|
||||
return success();
|
||||
}
|
||||
|
@ -560,33 +564,33 @@ LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite(
|
|||
|
||||
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
|
||||
|
||||
SmallVector<int64_t> mhloKernelSize(inputRank, 1);
|
||||
mhloKernelSize[dim] = inputShape[dim];
|
||||
SmallVector<int64_t> mhloStride(inputRank, 1);
|
||||
SmallVector<int64_t> mhloDilation(inputRank, 1);
|
||||
SmallVector<int64_t> mhloPadding(inputRank * 2, 0);
|
||||
mhloPadding[dim * 2] = inputShape[dim] - 1;
|
||||
SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
|
||||
stablehloKernelSize[dim] = inputShape[dim];
|
||||
SmallVector<int64_t> stablehloStride(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloDilation(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
|
||||
stablehloPadding[dim * 2] = inputShape[dim] - 1;
|
||||
|
||||
DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())},
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
|
||||
rewriter.getI64Type()),
|
||||
mhloKernelSize);
|
||||
stablehloKernelSize);
|
||||
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(mhloStride.size())},
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
|
||||
rewriter.getI64Type()),
|
||||
mhloStride);
|
||||
stablehloStride);
|
||||
DenseIntElementsAttr baseDilations;
|
||||
DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())},
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
|
||||
rewriter.getI64Type()),
|
||||
mhloDilation);
|
||||
stablehloDilation);
|
||||
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(
|
||||
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
|
||||
rewriter.getI64Type()),
|
||||
mhloPadding);
|
||||
stablehloPadding);
|
||||
|
||||
auto reduceWindowSum = rewriter.create<mhlo::ReduceWindowOp>(
|
||||
auto reduceWindowSum = rewriter.create<stablehlo::ReduceWindowOp>(
|
||||
op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
|
||||
baseDilations, windowDilations, pad);
|
||||
|
||||
|
@ -604,17 +608,17 @@ LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite(
|
|||
rewriter.setInsertionPointToStart(&sumBlock);
|
||||
|
||||
Value sumResult =
|
||||
rewriter.create<mhlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
|
||||
rewriter.create<mhlo::ReturnOp>(op->getLoc(), sumResult);
|
||||
rewriter.create<stablehlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
|
||||
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, reduceWindowSum.getResults());
|
||||
return success();
|
||||
}
|
||||
|
||||
void mlir::torch::torch_to_mhlo::populatePoolingOpPatternsAndLegality(
|
||||
void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToMhloOptions &options) {
|
||||
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
target.addIllegalOp<AtenMaxPool2dOp>();
|
||||
patterns.add<ConvertAtenOp<AtenMaxPool2dOp>>(typeConverter, context, options);
|
|
@ -0,0 +1,69 @@
|
|||
//===------------------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef TORCHMLIR_LIB_CONVERSION_TORCHTOSTABLEHLO_POPULATEPATTERNS_H
|
||||
#define TORCHMLIR_LIB_CONVERSION_TORCHTOSTABLEHLO_POPULATEPATTERNS_H
|
||||
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace torch {
|
||||
namespace torch_to_stablehlo {
|
||||
|
||||
struct TorchToStablehloOptions {
|
||||
bool enableStaticShape = false;
|
||||
size_t dimSizeIndexBits = 64;
|
||||
};
|
||||
|
||||
template <typename AtenOpT>
|
||||
class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||
ConvertAtenOp(TypeConverter &typeConverter, MLIRContext *context,
|
||||
const TorchToStablehloOptions &options)
|
||||
: OpConversionPattern<AtenOpT>(typeConverter, context) {
|
||||
this->options = options;
|
||||
}
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
return rewriter.notifyMatchFailure(op, "haven't been implemented");
|
||||
}
|
||||
const TorchToStablehloOptions &getOptions() const { return options; }
|
||||
|
||||
private:
|
||||
TorchToStablehloOptions options;
|
||||
};
|
||||
|
||||
void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
ConversionTarget &target,
|
||||
const TorchToStablehloOptions &options);
|
||||
void populateViewLikeOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToStablehloOptions &options);
|
||||
void populateGatherOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToStablehloOptions &options);
|
||||
void populateReductionOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToStablehloOptions &options);
|
||||
void populateLinearOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToStablehloOptions &options);
|
||||
|
||||
void populatePoolingOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToStablehloOptions &options);
|
||||
|
||||
} // namespace torch_to_stablehlo
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TORCHMLIR_LIB_CONVERSION_TORCHTOSTABLEHLO_POPULATEPATTERNS_H
|
|
@ -7,14 +7,15 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "./MhloLegalizeUtils.h"
|
||||
#include "./PopulatePatterns.h"
|
||||
#include "mhlo/IR/hlo_ops.h"
|
||||
#include "PopulatePatterns.h"
|
||||
#include "StablehloLegalizeUtils.h"
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "stablehlo/dialect/StablehloOps.h"
|
||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
|
@ -25,7 +26,7 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
using namespace mlir::torch::torch_to_mhlo;
|
||||
using namespace mlir::torch::torch_to_stablehlo;
|
||||
|
||||
static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
|
||||
PatternRewriter &rewriter) {
|
||||
|
@ -36,14 +37,14 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
|
|||
constType, {APFloat::getZero(
|
||||
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
|
||||
/*negative=*/false)});
|
||||
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
} else if (elementTy.isa<mlir::IntegerType>() &&
|
||||
elementTy.getIntOrFloatBitWidth() != 8) {
|
||||
auto constAttr = DenseElementsAttr::get(
|
||||
constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
|
||||
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -53,15 +54,15 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
|
|||
constType, {APFloat::getLargest(
|
||||
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
|
||||
/*negative=*/true)});
|
||||
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
} else if (elementTy.isa<mlir::IntegerType>() &&
|
||||
elementTy.getIntOrFloatBitWidth() != 8) {
|
||||
auto constAttr = DenseElementsAttr::get(
|
||||
constType,
|
||||
{APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())});
|
||||
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -90,9 +91,9 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
|
|||
return std::nullopt;
|
||||
Value initIndex;
|
||||
if (dimSizeIndexBits == 32) {
|
||||
initIndex = mhlo::getConstTensor<int32_t>(rewriter, op, {0}, {}).value();
|
||||
initIndex = hlo::getConstTensor<int32_t>(rewriter, op, {0}, {}).value();
|
||||
} else {
|
||||
initIndex = mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
|
||||
initIndex = hlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
|
||||
}
|
||||
|
||||
DenseIntElementsAttr dimensions = DenseIntElementsAttr::get(
|
||||
|
@ -100,13 +101,13 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
|
|||
|
||||
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
||||
op->getLoc(), inputShapeVec);
|
||||
auto indexTensor = rewriter.create<mhlo::DynamicIotaOp>(
|
||||
auto indexTensor = rewriter.create<stablehlo::DynamicIotaOp>(
|
||||
op->getLoc(),
|
||||
RankedTensorType::get(inputShape,
|
||||
rewriter.getIntegerType(dimSizeIndexBits)),
|
||||
inputShapeTensor, static_cast<uint64_t>(dim));
|
||||
|
||||
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
|
||||
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
|
||||
op->getLoc(), ValueRange{input, indexTensor},
|
||||
ValueRange{
|
||||
initValue,
|
||||
|
@ -114,7 +115,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
|
|||
},
|
||||
dimensions);
|
||||
|
||||
Block &block = mhloReduceOp.getBody().emplaceBlock();
|
||||
Block &block = stablehloReduceOp.getBody().emplaceBlock();
|
||||
|
||||
// Add block arguments
|
||||
auto blockValArgumentType =
|
||||
|
@ -133,46 +134,46 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
|
|||
auto *secondValArg = std::next(firstIdxArg);
|
||||
auto *secondIdxArg = std::next(secondValArg);
|
||||
|
||||
mhlo::ComparisonTypeAttr compareTypeAttr;
|
||||
stablehlo::ComparisonTypeAttr compareTypeAttr;
|
||||
if (inputTy.getElementType().isa<mlir::FloatType>()) {
|
||||
compareTypeAttr = mhlo::ComparisonTypeAttr::get(
|
||||
rewriter.getContext(), mhlo::ComparisonType::FLOAT);
|
||||
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
|
||||
rewriter.getContext(), stablehlo::ComparisonType::FLOAT);
|
||||
} else if (inputTy.getElementType().isa<mlir::IntegerType>()) {
|
||||
compareTypeAttr = mhlo::ComparisonTypeAttr::get(
|
||||
rewriter.getContext(), mhlo::ComparisonType::SIGNED);
|
||||
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
|
||||
rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
|
||||
}
|
||||
mhlo::ComparisonDirectionAttr compareGeDirectionAttr =
|
||||
mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
|
||||
mhlo::ComparisonDirection::GE);
|
||||
mhlo::ComparisonDirectionAttr compareEqDirectionAttr =
|
||||
mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
|
||||
mhlo::ComparisonDirection::EQ);
|
||||
stablehlo::ComparisonDirectionAttr compareGeDirectionAttr =
|
||||
stablehlo::ComparisonDirectionAttr::get(
|
||||
rewriter.getContext(), stablehlo::ComparisonDirection::GE);
|
||||
stablehlo::ComparisonDirectionAttr compareEqDirectionAttr =
|
||||
stablehlo::ComparisonDirectionAttr::get(
|
||||
rewriter.getContext(), stablehlo::ComparisonDirection::EQ);
|
||||
|
||||
{
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(&block);
|
||||
|
||||
Value compareGeResult = rewriter.create<mhlo::CompareOp>(
|
||||
Value compareGeResult = rewriter.create<stablehlo::CompareOp>(
|
||||
op->getLoc(), compareResultType, *firstValArg, *secondValArg,
|
||||
compareGeDirectionAttr, compareTypeAttr);
|
||||
Value retValResult = rewriter.create<mhlo::SelectOp>(
|
||||
Value retValResult = rewriter.create<stablehlo::SelectOp>(
|
||||
op->getLoc(), compareGeResult, *firstValArg, *secondValArg);
|
||||
|
||||
// get smaller index value if compared nums are equal.
|
||||
Value compareEqResult = rewriter.create<mhlo::CompareOp>(
|
||||
Value compareEqResult = rewriter.create<stablehlo::CompareOp>(
|
||||
op->getLoc(), compareResultType, *firstValArg, *secondValArg,
|
||||
compareEqDirectionAttr, compareTypeAttr);
|
||||
Value minIdx =
|
||||
rewriter.create<mhlo::MinOp>(op->getLoc(), *firstIdxArg, *secondIdxArg);
|
||||
Value idxWithGeVal = rewriter.create<mhlo::SelectOp>(
|
||||
Value minIdx = rewriter.create<stablehlo::MinOp>(op->getLoc(), *firstIdxArg,
|
||||
*secondIdxArg);
|
||||
Value idxWithGeVal = rewriter.create<stablehlo::SelectOp>(
|
||||
op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg);
|
||||
Value retIdxResult = rewriter.create<mhlo::SelectOp>(
|
||||
Value retIdxResult = rewriter.create<stablehlo::SelectOp>(
|
||||
op->getLoc(), compareEqResult, minIdx, idxWithGeVal);
|
||||
|
||||
rewriter.create<mhlo::ReturnOp>(
|
||||
rewriter.create<stablehlo::ReturnOp>(
|
||||
op->getLoc(), mlir::ValueRange{retValResult, retIdxResult});
|
||||
}
|
||||
return mhloReduceOp.getResults();
|
||||
return stablehloReduceOp.getResults();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -196,7 +197,8 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
|
|||
Value input = adaptor.getSelf();
|
||||
auto inputTy = input.getType().template cast<RankedTensorType>();
|
||||
if (!inputTy) {
|
||||
return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only Tensor types supported in StableHLO");
|
||||
}
|
||||
|
||||
auto inputElemTy = inputTy.getElementType();
|
||||
|
@ -209,7 +211,7 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
|
|||
inputElemTy.getIntOrFloatBitWidth() == 8) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "IntegerType with bitwidth 8 unsupported in convertion from "
|
||||
"AtenArgmaxOp to MHLO");
|
||||
"AtenArgmaxOp to StableHLO");
|
||||
}
|
||||
|
||||
int64_t dim;
|
||||
|
@ -228,15 +230,15 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
|
|||
|
||||
const auto &options = getOptions();
|
||||
auto inputShapeInfo =
|
||||
mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
||||
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
||||
if (failed(inputShapeInfo)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to get dimension sizes of the input");
|
||||
}
|
||||
auto inputShapeVec = *inputShapeInfo;
|
||||
auto mhloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, dim,
|
||||
options.dimSizeIndexBits)
|
||||
.value();
|
||||
auto stablehloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec,
|
||||
dim, options.dimSizeIndexBits)
|
||||
.value();
|
||||
|
||||
if (keepDim) {
|
||||
auto outShapeVec = inputShapeVec;
|
||||
|
@ -247,13 +249,13 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
|
|||
|
||||
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
||||
op->getLoc(), outShapeVec);
|
||||
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
|
||||
op, typeConverter->convertType(op.getType()), mhloReduceResults[1],
|
||||
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
|
||||
op, typeConverter->convertType(op.getType()), stablehloReduceResults[1],
|
||||
outShapeTensor);
|
||||
return success();
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, mhloReduceResults[1]);
|
||||
rewriter.replaceOp(op, stablehloReduceResults[1]);
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
@ -267,7 +269,8 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
|
|||
Value input = adaptor.getSelf();
|
||||
auto inputTy = input.getType().template dyn_cast<RankedTensorType>();
|
||||
if (!inputTy) {
|
||||
return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only Tensor types supported in StableHLO");
|
||||
}
|
||||
auto inputElemTy = inputTy.getElementType();
|
||||
if (!inputElemTy.isIntOrFloat()) {
|
||||
|
@ -279,7 +282,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
|
|||
inputElemTy.getIntOrFloatBitWidth() == 8) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "IntegerType with bitwidth 8 unsupported in convertion from "
|
||||
"AtenMaxDimOp to MHLO");
|
||||
"AtenMaxDimOp to StableHLO");
|
||||
}
|
||||
|
||||
RankedTensorType valResultType = getTypeConverter()
|
||||
|
@ -308,15 +311,15 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
|
|||
|
||||
const auto &options = getOptions();
|
||||
auto inputShapeInfo =
|
||||
mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
||||
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
||||
if (failed(inputShapeInfo)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to get dimension sizes of the input");
|
||||
}
|
||||
auto inputShapeVec = *inputShapeInfo;
|
||||
auto mhloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, dim,
|
||||
options.dimSizeIndexBits)
|
||||
.value();
|
||||
auto stablehloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec,
|
||||
dim, options.dimSizeIndexBits)
|
||||
.value();
|
||||
|
||||
if (keepDim) {
|
||||
auto outShapeVec = inputShapeVec;
|
||||
|
@ -327,15 +330,21 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
|
|||
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
||||
op->getLoc(), outShapeVec);
|
||||
|
||||
auto mhloReduceValueResult = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
op->getLoc(), valResultType, mhloReduceResults[0], outShapeTensor);
|
||||
auto mhloReduceIndexResult = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
op->getLoc(), idxResultType, mhloReduceResults[1], outShapeTensor);
|
||||
rewriter.replaceOp(op, {mhloReduceValueResult, mhloReduceIndexResult});
|
||||
auto stablehloReduceValueResult =
|
||||
rewriter.create<stablehlo::DynamicReshapeOp>(
|
||||
op->getLoc(), valResultType, stablehloReduceResults[0],
|
||||
outShapeTensor);
|
||||
auto stablehloReduceIndexResult =
|
||||
rewriter.create<stablehlo::DynamicReshapeOp>(
|
||||
op->getLoc(), idxResultType, stablehloReduceResults[1],
|
||||
outShapeTensor);
|
||||
rewriter.replaceOp(
|
||||
op, {stablehloReduceValueResult, stablehloReduceIndexResult});
|
||||
return success();
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, {mhloReduceResults[0], mhloReduceResults[1]});
|
||||
rewriter.replaceOp(op,
|
||||
{stablehloReduceResults[0], stablehloReduceResults[1]});
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
@ -352,12 +361,14 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
|
|||
->convertType(op.getType())
|
||||
.template dyn_cast<RankedTensorType>();
|
||||
if (!inputTy) {
|
||||
return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only Tensor types supported in StableHLO");
|
||||
}
|
||||
if (inputTy.getElementType() != outTy.getElementType()) {
|
||||
// Use output element type as computation type.
|
||||
auto dstElemTy = outTy.getElementType();
|
||||
input = rewriter.create<mhlo::ConvertOp>(op->getLoc(), input, dstElemTy);
|
||||
input =
|
||||
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
|
||||
inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||
}
|
||||
auto inputElemTy = inputTy.getElementType();
|
||||
|
@ -370,7 +381,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
|
|||
inputElemTy.getIntOrFloatBitWidth() == 8) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "IntegerType with bitwidth 8 unsupported in convertion from "
|
||||
"AtenSumOp to MHLO");
|
||||
"AtenSumOp to StableHLO");
|
||||
}
|
||||
|
||||
SmallVector<int64_t> dims;
|
||||
|
@ -379,13 +390,14 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
|
|||
}
|
||||
Value initValue =
|
||||
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
|
||||
if (!initValue) return failure();
|
||||
if (!initValue)
|
||||
return failure();
|
||||
|
||||
llvm::sort(dims.begin(), dims.end());
|
||||
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
|
||||
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
|
||||
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
|
||||
|
||||
Block &block = mhloReduceOp.getBody().emplaceBlock();
|
||||
Block &block = stablehloReduceOp.getBody().emplaceBlock();
|
||||
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
|
||||
|
||||
block.addArgument(blockArgumentTy, op->getLoc());
|
||||
|
@ -397,13 +409,13 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
|
|||
{
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(&block);
|
||||
Value addResult = rewriter.create<mhlo::AddOp>(
|
||||
Value addResult = rewriter.create<stablehlo::AddOp>(
|
||||
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
|
||||
rewriter.create<mhlo::ReturnOp>(op->getLoc(), addResult);
|
||||
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult);
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy,
|
||||
mhloReduceOp.getResults());
|
||||
stablehloReduceOp.getResults());
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
@ -417,7 +429,8 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
|
|||
Value input = adaptor.getSelf();
|
||||
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||
if (!inputTy) {
|
||||
return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only Tensor types supported in StableHLO");
|
||||
}
|
||||
auto inputElemTy = inputTy.getElementType();
|
||||
if (!inputElemTy.isIntOrFloat()) {
|
||||
|
@ -429,7 +442,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
|
|||
inputElemTy.getIntOrFloatBitWidth() == 8) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "IntegerType with bitwidth 8 unsupported in convertion from "
|
||||
"AtenMaxOp to MHLO");
|
||||
"AtenMaxOp to StableHLO");
|
||||
}
|
||||
|
||||
SmallVector<int64_t> dims;
|
||||
|
@ -439,12 +452,13 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
|
|||
|
||||
Value initValue =
|
||||
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
|
||||
if (!initValue) return failure();
|
||||
if (!initValue)
|
||||
return failure();
|
||||
llvm::sort(dims.begin(), dims.end());
|
||||
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
|
||||
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
|
||||
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
|
||||
|
||||
Block &block = mhloReduceOp.getBody().emplaceBlock();
|
||||
Block &block = stablehloReduceOp.getBody().emplaceBlock();
|
||||
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
|
||||
|
||||
block.addArgument(blockArgumentTy, op->getLoc());
|
||||
|
@ -456,14 +470,14 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
|
|||
{
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(&block);
|
||||
Value maxResult = rewriter.create<mhlo::MaxOp>(
|
||||
Value maxResult = rewriter.create<stablehlo::MaxOp>(
|
||||
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
|
||||
rewriter.create<mhlo::ReturnOp>(op->getLoc(), maxResult);
|
||||
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), maxResult);
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()),
|
||||
mhloReduceOp.getResults());
|
||||
stablehloReduceOp.getResults());
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
@ -480,12 +494,14 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
|
|||
->convertType(op.getType())
|
||||
.template dyn_cast<RankedTensorType>();
|
||||
if (!inputTy) {
|
||||
return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only Tensor types supported in StableHLO");
|
||||
}
|
||||
if (inputTy.getElementType() != outTy.getElementType()) {
|
||||
// Use output element type as computation type.
|
||||
auto dstElemTy = outTy.getElementType();
|
||||
input = rewriter.create<mhlo::ConvertOp>(op->getLoc(), input, dstElemTy);
|
||||
input =
|
||||
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
|
||||
inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||
}
|
||||
auto inputElemTy = inputTy.getElementType();
|
||||
|
@ -499,7 +515,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
|
|||
inputElemTy.getIntOrFloatBitWidth() == 8) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "IntegerType with bitwidth 8 unsupported in convertion from "
|
||||
"AtenSumDimIntListOp to MHLO");
|
||||
"AtenSumDimIntListOp to StableHLO");
|
||||
}
|
||||
|
||||
SmallVector<int64_t> inputDims;
|
||||
|
@ -525,13 +541,14 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
|
|||
}
|
||||
Value initValue =
|
||||
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
|
||||
if (!initValue) return failure();
|
||||
if (!initValue)
|
||||
return failure();
|
||||
|
||||
llvm::sort(dims.begin(), dims.end());
|
||||
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
|
||||
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
|
||||
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
|
||||
|
||||
Region ®ion = mhloReduceOp.getBody();
|
||||
Region ®ion = stablehloReduceOp.getBody();
|
||||
Block &block = region.emplaceBlock();
|
||||
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
|
||||
|
||||
|
@ -544,15 +561,15 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
|
|||
{
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(&block);
|
||||
Value addResult = rewriter.create<mhlo::AddOp>(
|
||||
Value addResult = rewriter.create<stablehlo::AddOp>(
|
||||
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
|
||||
rewriter.create<mhlo::ReturnOp>(op->getLoc(), addResult);
|
||||
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult);
|
||||
}
|
||||
|
||||
if (keepDim) {
|
||||
const auto &options = getOptions();
|
||||
auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input,
|
||||
options.dimSizeIndexBits);
|
||||
auto outShapeInfo =
|
||||
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
||||
if (failed(outShapeInfo)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to get dimension sizes of the input");
|
||||
|
@ -567,26 +584,27 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
|
|||
}
|
||||
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
||||
op->getLoc(), outShapeVec);
|
||||
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
|
||||
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()),
|
||||
mhloReduceOp.getResult(0), outShapeTensor);
|
||||
stablehloReduceOp.getResult(0), outShapeTensor);
|
||||
return success();
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy,
|
||||
mhloReduceOp.getResults());
|
||||
stablehloReduceOp.getResults());
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// AtenFrobeniusNormDimOp
|
||||
// aten.frobenius_norm.dim => mhlo.reduce(calculate square sum along given dims)
|
||||
// + mhlo.sqrt
|
||||
// aten.frobenius_norm.dim => stablehlo.reduce(calculate square sum along given
|
||||
// dims)
|
||||
// + stablehlo.sqrt
|
||||
namespace {
|
||||
template <>
|
||||
LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
|
||||
AtenFrobeniusNormDimOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
const TorchToMhloOptions &options = getOptions();
|
||||
const TorchToStablehloOptions &options = getOptions();
|
||||
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputType = input.getType().dyn_cast<RankedTensorType>();
|
||||
|
@ -614,7 +632,7 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
|
|||
}
|
||||
}
|
||||
|
||||
// Sort the dims in ascending order, making the conversion
|
||||
// Sort the dims in ascending order, making the conversion
|
||||
// stable with unordered dims.
|
||||
std::sort(dims.begin(), dims.end());
|
||||
|
||||
|
@ -624,14 +642,14 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
|
|||
op, "non-const bool `keepdim` is not supported");
|
||||
}
|
||||
|
||||
auto squareOp = rewriter.create<mhlo::MulOp>(op->getLoc(), input, input);
|
||||
auto squareOp = rewriter.create<stablehlo::MulOp>(op->getLoc(), input, input);
|
||||
|
||||
auto initValue = createInitialValueForReduceOp(op, inputElemType, rewriter);
|
||||
if (!initValue) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto reduceOp = rewriter.create<mhlo::ReduceOp>(
|
||||
auto reduceOp = rewriter.create<stablehlo::ReduceOp>(
|
||||
op->getLoc(), squareOp.getResult(), initValue,
|
||||
rewriter.getI64TensorAttr(dims));
|
||||
|
||||
|
@ -649,30 +667,32 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
|
|||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(&block);
|
||||
|
||||
auto addResult = rewriter.create<mhlo::AddOp>(op->getLoc(), firstArgument,
|
||||
secondArgument);
|
||||
rewriter.create<mhlo::ReturnOp>(op->getLoc(), addResult.getResult());
|
||||
auto addResult = rewriter.create<stablehlo::AddOp>(
|
||||
op->getLoc(), firstArgument, secondArgument);
|
||||
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult.getResult());
|
||||
}
|
||||
|
||||
auto output =
|
||||
rewriter.create<mhlo::SqrtOp>(op->getLoc(), reduceOp.getResult(0));
|
||||
rewriter.create<stablehlo::SqrtOp>(op->getLoc(), reduceOp.getResult(0));
|
||||
|
||||
if (keepDim) {
|
||||
auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
||||
auto outShapeInfo =
|
||||
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
||||
if (failed(outShapeInfo)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to get dimension sizes of the input");
|
||||
}
|
||||
auto outShapeVec = *outShapeInfo;
|
||||
auto one = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op->getLoc(), rewriter.getIntegerAttr(
|
||||
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
|
||||
op->getLoc(),
|
||||
rewriter.getIntegerAttr(
|
||||
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
|
||||
for (int64_t i : dims) {
|
||||
outShapeVec[i] = one;
|
||||
}
|
||||
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
||||
op->getLoc(), outShapeVec);
|
||||
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
|
||||
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), output,
|
||||
outShapeTensor);
|
||||
return success();
|
||||
|
@ -682,9 +702,9 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
|
|||
}
|
||||
} // namespace
|
||||
|
||||
void mlir::torch::torch_to_mhlo::populateReductionOpPatternsAndLegality(
|
||||
void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToMhloOptions &options) {
|
||||
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
#define INSERT_ATEN_REDUCTION_OP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
|
@ -7,11 +7,12 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "./MhloLegalizeUtils.h"
|
||||
#include "mhlo/IR/hlo_ops.h"
|
||||
#include "StablehloLegalizeUtils.h"
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
||||
#include "stablehlo/dialect/StablehloOps.h"
|
||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
#include <numeric>
|
||||
|
@ -21,27 +22,27 @@ using namespace mlir::torch;
|
|||
using namespace mlir::torch::Torch;
|
||||
|
||||
namespace mlir {
|
||||
namespace mhlo {
|
||||
namespace hlo {
|
||||
|
||||
// Create a 32-bit float constant operator from a float
|
||||
Value getMhloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
|
||||
float val) {
|
||||
Value getStablehloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
|
||||
float val) {
|
||||
auto const_type = RankedTensorType::get({}, rewriter.getF32Type());
|
||||
auto const_attr = DenseElementsAttr::get(const_type, val);
|
||||
|
||||
auto const_op =
|
||||
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
|
||||
auto const_op = rewriter.create<stablehlo::ConstantOp>(
|
||||
op->getLoc(), const_type, const_attr);
|
||||
return const_op.getResult();
|
||||
}
|
||||
|
||||
// Create a 64-bit float constant operator from a double
|
||||
Value getMhloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op,
|
||||
double val) {
|
||||
Value getStablehloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op,
|
||||
double val) {
|
||||
auto const_type = RankedTensorType::get({}, rewriter.getF64Type());
|
||||
auto const_attr = DenseElementsAttr::get(const_type, val);
|
||||
|
||||
auto const_op =
|
||||
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
|
||||
auto const_op = rewriter.create<stablehlo::ConstantOp>(
|
||||
op->getLoc(), const_type, const_attr);
|
||||
return const_op.getResult();
|
||||
}
|
||||
|
||||
|
@ -65,8 +66,8 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
|
|||
RankedTensorType::get(shape, rewriter.getIntegerType(sizeof(T) * 8));
|
||||
auto const_attr = DenseElementsAttr::get(const_type, vec);
|
||||
|
||||
auto const_op =
|
||||
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
|
||||
auto const_op = rewriter.create<stablehlo::ConstantOp>(
|
||||
op->getLoc(), const_type, const_attr);
|
||||
return const_op.getResult();
|
||||
}
|
||||
|
||||
|
@ -88,8 +89,8 @@ std::optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,
|
|||
shape, rewriter.getIntegerType(vec[0].getBitWidth()));
|
||||
auto const_attr = DenseElementsAttr::get(const_type, vec);
|
||||
|
||||
auto const_op =
|
||||
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
|
||||
auto const_op = rewriter.create<stablehlo::ConstantOp>(
|
||||
op->getLoc(), const_type, const_attr);
|
||||
return const_op.getResult();
|
||||
}
|
||||
|
||||
|
@ -111,8 +112,8 @@ std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
|
|||
auto const_type = RankedTensorType::get(shape, rewriter.getF32Type());
|
||||
auto const_attr = DenseElementsAttr::get(const_type, vec);
|
||||
|
||||
auto const_op =
|
||||
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
|
||||
auto const_op = rewriter.create<stablehlo::ConstantOp>(
|
||||
op->getLoc(), const_type, const_attr);
|
||||
return const_op.getResult();
|
||||
}
|
||||
|
||||
|
@ -133,8 +134,8 @@ std::optional<Value> getConstTensor<double>(PatternRewriter &rewriter,
|
|||
auto const_type = RankedTensorType::get(shape, rewriter.getF64Type());
|
||||
auto const_attr = DenseElementsAttr::get(const_type, vec);
|
||||
|
||||
auto const_op =
|
||||
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
|
||||
auto const_op = rewriter.create<stablehlo::ConstantOp>(
|
||||
op->getLoc(), const_type, const_attr);
|
||||
return const_op.getResult();
|
||||
}
|
||||
|
||||
|
@ -169,18 +170,18 @@ Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
|
|||
T val, Type dtype, llvm::ArrayRef<int64_t> dshape) {
|
||||
auto const_type = RankedTensorType::get(dshape, dtype);
|
||||
auto const_attr = SplatElementsAttr::get(const_type, val);
|
||||
auto const_op =
|
||||
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
|
||||
auto const_op = rewriter.create<stablehlo::ConstantOp>(
|
||||
op->getLoc(), const_type, const_attr);
|
||||
return const_op.getResult();
|
||||
}
|
||||
|
||||
Value scalarToMhloTensor(ConversionPatternRewriter &rewriter, Operation *op,
|
||||
Value scalarValue, Type dtype) {
|
||||
Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
|
||||
Operation *op, Value scalarValue, Type dtype) {
|
||||
auto tensor = rewriter.create<tensor::FromElementsOp>(
|
||||
op->getLoc(), ArrayRef<Value>{scalarValue});
|
||||
auto dtype_tensor =
|
||||
rewriter.create<mhlo::ConvertOp>(op->getLoc(), tensor, dtype);
|
||||
return rewriter.create<mhlo::ReshapeOp>(
|
||||
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), tensor, dtype);
|
||||
return rewriter.create<stablehlo::ReshapeOp>(
|
||||
op->getLoc(), RankedTensorType::get(mlir::ArrayRef<int64_t>{}, dtype),
|
||||
dtype_tensor);
|
||||
}
|
||||
|
@ -192,7 +193,8 @@ Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) {
|
|||
if (in_type.getElementType() != outType.getElementType()) {
|
||||
TensorType promotedType =
|
||||
in_type.cloneWith(in_type.getShape(), outType.getElementType());
|
||||
return rewriter.create<mhlo::ConvertOp>(op->getLoc(), promotedType, input);
|
||||
return rewriter.create<stablehlo::ConvertOp>(op->getLoc(), promotedType,
|
||||
input);
|
||||
}
|
||||
return input;
|
||||
}
|
||||
|
@ -210,8 +212,8 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
|
|||
if (in_type.getElementType() != outType.getElementType()) {
|
||||
TensorType promoted_type =
|
||||
in_type.cloneWith(in_type.getShape(), outType.getElementType());
|
||||
input =
|
||||
rewriter.create<mhlo::ConvertOp>(op->getLoc(), promoted_type, input);
|
||||
input = rewriter.create<stablehlo::ConvertOp>(op->getLoc(), promoted_type,
|
||||
input);
|
||||
}
|
||||
|
||||
ArrayRef<int64_t> inShape = in_type.getShape();
|
||||
|
@ -245,8 +247,8 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
|
|||
RankedTensorType::get({static_cast<long int>(bcastDims.size())},
|
||||
rewriter.getI64Type()),
|
||||
bcastDims);
|
||||
auto bcast_op = rewriter.create<mhlo::BroadcastInDimOp>(op->getLoc(), outType,
|
||||
input, bcast_attr);
|
||||
auto bcast_op = rewriter.create<stablehlo::BroadcastInDimOp>(
|
||||
op->getLoc(), outType, input, bcast_attr);
|
||||
return bcast_op.getResult();
|
||||
}
|
||||
|
||||
|
@ -348,8 +350,8 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
|
|||
}
|
||||
|
||||
auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
|
||||
auto mhloShape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
|
||||
return rewriter.create<mhlo::DynamicReshapeOp>(loc, outTy, tensor, mhloShape)
|
||||
auto shape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
|
||||
return rewriter.create<stablehlo::DynamicReshapeOp>(loc, outTy, tensor, shape)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
|
@ -357,11 +359,11 @@ Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
|
|||
const APFloat &constant, Value shape,
|
||||
TensorType outType) {
|
||||
auto constAttr = rewriter.getFloatAttr(outType.getElementType(), constant);
|
||||
auto constTensor = rewriter.create<mhlo::ConstantOp>(loc, constAttr);
|
||||
auto constTensor = rewriter.create<stablehlo::ConstantOp>(loc, constAttr);
|
||||
return rewriter
|
||||
.create<mhlo::DynamicBroadcastInDimOp>(loc, outType, constTensor, shape,
|
||||
rewriter.getI64TensorAttr({}))
|
||||
.create<stablehlo::DynamicBroadcastInDimOp>(
|
||||
loc, outType, constTensor, shape, rewriter.getI64TensorAttr({}))
|
||||
.getResult();
|
||||
}
|
||||
} // namespace mhlo
|
||||
} // namespace hlo
|
||||
} // namespace mlir
|
|
@ -7,8 +7,8 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H
|
||||
#define TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H
|
||||
#ifndef TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_STABLEHLOLEGALIZEUTILS_H
|
||||
#define TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_STABLEHLOLEGALIZEUTILS_H
|
||||
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
@ -18,22 +18,22 @@
|
|||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace mhlo {
|
||||
namespace hlo {
|
||||
|
||||
using mlir::ConversionPatternRewriter;
|
||||
|
||||
// Create a 32-bit float constant operator from a float
|
||||
Value getMhloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
|
||||
float val);
|
||||
Value getStablehloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
|
||||
float val);
|
||||
|
||||
// Create a 64-bit float constant operator from a double
|
||||
Value getMhloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op,
|
||||
double val);
|
||||
Value getStablehloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op,
|
||||
double val);
|
||||
|
||||
// Templated function to create a constant op for given type and shape.
|
||||
// T: storage C type.
|
||||
// Default template creates a constant tensor in T.
|
||||
// To create INT48 MHLO constant, need to pass in llvm::APInt instead.
|
||||
// To create INT48 StableHLO constant, need to pass in llvm::APInt instead.
|
||||
template <typename T>
|
||||
std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
|
||||
ArrayRef<T> vec, ArrayRef<int64_t> shape);
|
||||
|
@ -42,8 +42,8 @@ template <typename T>
|
|||
Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
|
||||
T val, Type dtype, llvm::ArrayRef<int64_t> dshape);
|
||||
|
||||
Value scalarToMhloTensor(ConversionPatternRewriter &rewriter, Operation *op,
|
||||
Value scalarValue, Type dtype);
|
||||
Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
|
||||
Operation *op, Value scalarValue, Type dtype);
|
||||
|
||||
Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType);
|
||||
|
||||
|
@ -71,7 +71,7 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
|
|||
Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
|
||||
const APFloat &constant, Value shape,
|
||||
TensorType outType);
|
||||
} // namespace mhlo
|
||||
} // namespace hlo
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H
|
||||
#endif // TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_STABLEHLOLEGALIZEUTILS_H
|
|
@ -7,17 +7,18 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "./PopulatePatterns.h"
|
||||
#include "mhlo/IR/hlo_ops.h"
|
||||
#include "PopulatePatterns.h"
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Traits.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "stablehlo/dialect/ChloOps.h"
|
||||
#include "stablehlo/dialect/StablehloOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
|
@ -30,17 +31,18 @@ using namespace mlir::torch::Torch;
|
|||
|
||||
namespace {
|
||||
|
||||
class ConvertTorchToMhlo : public ConvertTorchToMhloBase<ConvertTorchToMhlo> {
|
||||
class ConvertTorchToStablehlo
|
||||
: public ConvertTorchToStablehloBase<ConvertTorchToStablehlo> {
|
||||
public:
|
||||
ConvertTorchToMhlo() = default;
|
||||
ConvertTorchToMhlo(bool enableStaticShape, bool enableI32Index) {
|
||||
ConvertTorchToStablehlo() = default;
|
||||
ConvertTorchToStablehlo(bool enableStaticShape, bool enableI32Index) {
|
||||
this->enableStaticShape = enableStaticShape;
|
||||
this->enableI32Index = enableI32Index;
|
||||
}
|
||||
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<chlo::ChloDialect>();
|
||||
registry.insert<mhlo::MhloDialect>();
|
||||
registry.insert<stablehlo::StablehloDialect>();
|
||||
registry.insert<tensor::TensorDialect>();
|
||||
registry.insert<arith::ArithDialect>();
|
||||
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
||||
|
@ -48,7 +50,7 @@ public:
|
|||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
ConversionTarget target(*context);
|
||||
target.addLegalDialect<chlo::ChloDialect, mhlo::MhloDialect,
|
||||
target.addLegalDialect<chlo::ChloDialect, stablehlo::StablehloDialect,
|
||||
tensor::TensorDialect, arith::ArithDialect>();
|
||||
|
||||
TypeConverter typeConverter;
|
||||
|
@ -57,20 +59,20 @@ public:
|
|||
|
||||
RewritePatternSet patterns(context);
|
||||
|
||||
torch_to_mhlo::TorchToMhloOptions options{enableStaticShape,
|
||||
enableI32Index ? 32u : 64u};
|
||||
torch_to_mhlo::populateBasicOpPatternsAndLegality(typeConverter, patterns,
|
||||
target, options);
|
||||
torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
|
||||
torch_to_stablehlo::TorchToStablehloOptions options{
|
||||
enableStaticShape, enableI32Index ? 32u : 64u};
|
||||
torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
||||
typeConverter, patterns, target, options);
|
||||
torch_to_mhlo::populateGatherOpPatternsAndLegality(typeConverter, patterns,
|
||||
target, options);
|
||||
torch_to_mhlo::populateReductionOpPatternsAndLegality(
|
||||
torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
|
||||
typeConverter, patterns, target, options);
|
||||
torch_to_stablehlo::populateGatherOpPatternsAndLegality(
|
||||
typeConverter, patterns, target, options);
|
||||
torch_to_stablehlo::populateReductionOpPatternsAndLegality(
|
||||
typeConverter, patterns, target, options);
|
||||
torch_to_stablehlo::populateLinearOpPatternsAndLegality(
|
||||
typeConverter, patterns, target, options);
|
||||
torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
|
||||
typeConverter, patterns, target, options);
|
||||
torch_to_mhlo::populateLinearOpPatternsAndLegality(typeConverter, patterns,
|
||||
target, options);
|
||||
torch_to_mhlo::populatePoolingOpPatternsAndLegality(typeConverter, patterns,
|
||||
target, options);
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns)))) {
|
||||
|
@ -82,13 +84,13 @@ public:
|
|||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
mlir::torch::createConvertTorchToMhloPass() {
|
||||
return std::make_unique<ConvertTorchToMhlo>(false, false);
|
||||
mlir::torch::createConvertTorchToStablehloPass() {
|
||||
return std::make_unique<ConvertTorchToStablehlo>(false, false);
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
mlir::torch::createConvertTorchToMhloPass(bool enableStaticShape,
|
||||
bool enableI32Index) {
|
||||
return std::make_unique<ConvertTorchToMhlo>(enableStaticShape,
|
||||
enableI32Index);
|
||||
mlir::torch::createConvertTorchToStablehloPass(bool enableStaticShape,
|
||||
bool enableI32Index) {
|
||||
return std::make_unique<ConvertTorchToStablehlo>(enableStaticShape,
|
||||
enableI32Index);
|
||||
}
|
|
@ -7,14 +7,15 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "./MhloLegalizeUtils.h"
|
||||
#include "./PopulatePatterns.h"
|
||||
#include "mhlo/IR/hlo_ops.h"
|
||||
#include "PopulatePatterns.h"
|
||||
#include "StablehloLegalizeUtils.h"
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "stablehlo/dialect/StablehloOps.h"
|
||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
|
@ -28,7 +29,7 @@ using namespace mlir;
|
|||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
using namespace mlir::torch::TorchConversion;
|
||||
using namespace mlir::torch::torch_to_mhlo;
|
||||
using namespace mlir::torch::torch_to_stablehlo;
|
||||
|
||||
namespace {
|
||||
// A dimension index from torch.dialect might outside the range [0, dimSize].
|
||||
|
@ -100,7 +101,7 @@ Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op,
|
|||
auto stridesTensor =
|
||||
rewriter.create<tensor::FromElementsOp>(loc, strides).getResult();
|
||||
|
||||
return rewriter.create<mhlo::RealDynamicSliceOp>(
|
||||
return rewriter.create<stablehlo::RealDynamicSliceOp>(
|
||||
loc, outTy, input, startTensor, endTensor, stridesTensor);
|
||||
}
|
||||
|
||||
|
@ -144,7 +145,7 @@ FailureOr<Value> getDynamicSlice(PatternRewriter &rewriter, Operation *op,
|
|||
step = rewriter.create<arith::TruncIOp>(loc, intType, step);
|
||||
}
|
||||
FailureOr<SmallVector<Value, 4>> dimSizesInfo =
|
||||
mhlo::getDimSizesOfTensor(rewriter, op, input, dimSizeIndexBits);
|
||||
hlo::getDimSizesOfTensor(rewriter, op, input, dimSizeIndexBits);
|
||||
if (failed(dimSizesInfo))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to get dimension sizes of the input");
|
||||
|
@ -179,7 +180,7 @@ public:
|
|||
auto loc = op.getLoc();
|
||||
auto newRank = dimSizes.size();
|
||||
if (newRank == 0 || rankType.getRank() == 0) {
|
||||
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
|
||||
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
|
||||
op,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()),
|
||||
|
@ -214,17 +215,18 @@ public:
|
|||
numel);
|
||||
|
||||
if (dimSizes.size() == 0) {
|
||||
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
|
||||
op,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()),
|
||||
adaptor.getSelf());
|
||||
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
|
||||
op,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()),
|
||||
adaptor.getSelf());
|
||||
return success();
|
||||
}
|
||||
Value mhloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
|
||||
Value computedShape = rewriter.create<mhlo::ComputeReshapeShapeOp>(
|
||||
loc, mhloShape.getType(), numel, mhloShape);
|
||||
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
|
||||
Value stablehloShape =
|
||||
rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
|
||||
Value computedShape = rewriter.create<stablehlo::ComputeReshapeShapeOp>(
|
||||
loc, stablehloShape.getType(), numel, stablehloShape);
|
||||
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
|
||||
op,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()),
|
||||
|
@ -315,21 +317,21 @@ LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
|
|||
dims.push_back(r);
|
||||
}
|
||||
if (dims.size() == 0) {
|
||||
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
|
||||
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), self);
|
||||
return success();
|
||||
}
|
||||
|
||||
auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims,
|
||||
options.dimSizeIndexBits);
|
||||
auto newDimSizesInfo = hlo::getDimSizesOfTensor(rewriter, op, self, dims,
|
||||
options.dimSizeIndexBits);
|
||||
if (failed(newDimSizesInfo))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to get dimension sizes of the input");
|
||||
auto newDimSizes = *newDimSizesInfo;
|
||||
auto mhloShape =
|
||||
auto stablehloShape =
|
||||
rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
|
||||
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), self, mhloShape);
|
||||
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), self, stablehloShape);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -365,20 +367,20 @@ LogicalResult ConvertAtenOp<AtenSqueezeDimOp>::matchAndRewrite(
|
|||
std::iota(dims.begin(), dims.end(), 0);
|
||||
dims.erase(dims.begin() + dim);
|
||||
if (dims.size() == 0) {
|
||||
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
|
||||
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), self);
|
||||
return success();
|
||||
}
|
||||
auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims,
|
||||
options.dimSizeIndexBits);
|
||||
auto newDimSizesInfo = hlo::getDimSizesOfTensor(rewriter, op, self, dims,
|
||||
options.dimSizeIndexBits);
|
||||
if (failed(newDimSizesInfo))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to get dimension sizes of the input");
|
||||
auto newDimSizes = *newDimSizesInfo;
|
||||
auto mhloShape =
|
||||
auto stablehloShape =
|
||||
rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
|
||||
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), self, mhloShape);
|
||||
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), self, stablehloShape);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -395,8 +397,8 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
|
|||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||
return op->emitError("dim must be a Scalar constant");
|
||||
|
||||
auto unsqzTensorInfo = mhlo::unsqueezeTensor(rewriter, op, adaptor.getSelf(),
|
||||
{dim}, options.dimSizeIndexBits);
|
||||
auto unsqzTensorInfo = hlo::unsqueezeTensor(rewriter, op, adaptor.getSelf(),
|
||||
{dim}, options.dimSizeIndexBits);
|
||||
if (failed(unsqzTensorInfo))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"failed to create unsqueezed tensor");
|
||||
|
@ -405,9 +407,9 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
|
||||
void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToMhloOptions &options) {
|
||||
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
|
||||
#define INSERT_ATENOP_PATTERN(AtenOp) \
|
|
@ -11,7 +11,7 @@ set(LinkedLibs MLIRIR
|
|||
TorchMLIRTorchConversionToMLProgram
|
||||
MLIRMemRefTransforms)
|
||||
|
||||
if(TORCH_MLIR_ENABLE_MHLO)
|
||||
if(TORCH_MLIR_ENABLE_STABLEHLO)
|
||||
list(APPEND LinkedLibs ChloPasses)
|
||||
endif()
|
||||
|
||||
|
@ -21,7 +21,7 @@ add_mlir_library(TorchMLIRTorchConversionPasses
|
|||
Passes.cpp
|
||||
VerifyLinalgOnTensorsBackendContract.cpp
|
||||
VerifyTosaBackendContract.cpp
|
||||
VerifyMhloBackendContract.cpp
|
||||
VerifyStablehloBackendContract.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/TorchConversion/Transforms
|
||||
|
|
|
@ -21,9 +21,8 @@
|
|||
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
|
||||
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
|
||||
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO
|
||||
#include "mhlo/transforms/passes.h"
|
||||
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
||||
#endif
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||
|
||||
|
@ -53,12 +52,13 @@ void mlir::torch::registerTorchConversionPasses() {
|
|||
"Pipeline lowering torch backend contract to TOSA backend "
|
||||
"contract.",
|
||||
TorchConversion::createTorchBackendToTosaBackendPipeline);
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO
|
||||
mlir::PassPipelineRegistration<TorchConversion::MhloBackendPipelineOptions>(
|
||||
"torch-backend-to-mhlo-backend-pipeline",
|
||||
"Pipeline lowering torch backend contract to MHLO backend "
|
||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||
mlir::PassPipelineRegistration<
|
||||
TorchConversion::StablehloBackendPipelineOptions>(
|
||||
"torch-backend-to-stablehlo-backend-pipeline",
|
||||
"Pipeline lowering torch backend contract to StableHLO backend "
|
||||
"contract.",
|
||||
TorchConversion::createTorchBackendToMhloBackendPipeline);
|
||||
TorchConversion::createTorchBackendToStablehloBackendPipeline);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
@ -121,11 +121,12 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
|
|||
pm.addPass(TorchConversion::createVerifyTosaBackendContractPass());
|
||||
}
|
||||
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO
|
||||
void TorchConversion::createTorchBackendToMhloBackendPipeline(
|
||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||
void TorchConversion::createTorchBackendToStablehloBackendPipeline(
|
||||
OpPassManager &pm,
|
||||
const TorchConversion::MhloBackendPipelineOptions &options) {
|
||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToMhloPass(
|
||||
const TorchConversion::StablehloBackendPipelineOptions &options) {
|
||||
// Generate Stablehlo ops.
|
||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToStablehloPass(
|
||||
options.enableStaticShape, options.enableI32Index));
|
||||
|
||||
// Clean up any non-canonical code introduced above..
|
||||
|
@ -133,21 +134,13 @@ void TorchConversion::createTorchBackendToMhloBackendPipeline(
|
|||
// The resolution of `dim` ops tends to create identical ops. CSE them.
|
||||
pm.addNestedPass<func::FuncOp>(createCSEPass());
|
||||
|
||||
// Convert CHLO ops to MHLO ops
|
||||
pm.addNestedPass<func::FuncOp>(mhlo::createChloLegalizeToHloPass());
|
||||
// Clean up any non-canonical code introduced above..
|
||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||
// The resolution of `dim` ops tends to create identical ops. CSE them.
|
||||
pm.addNestedPass<func::FuncOp>(createCSEPass());
|
||||
|
||||
// Finish the type conversion from `torch` types to the types of the
|
||||
// MHLO backend contract.
|
||||
// StableHLO backend contract.
|
||||
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
|
||||
pm.addNestedPass<func::FuncOp>(
|
||||
TorchConversion::createFinalizingBackendTypeConversionPass());
|
||||
// Verify that we have lowered to the form that MHLO backends
|
||||
// expect. This fails compilation (signalPassFailure) if the IR is not in the
|
||||
// correct form.
|
||||
pm.addPass(TorchConversion::createVerifyMhloBackendContractPass());
|
||||
|
||||
// Verify that we have lowered to Stablehlo and Chlo ops.
|
||||
pm.addPass(TorchConversion::createVerifyStablehloBackendContractPass());
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -6,10 +6,9 @@
|
|||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO
|
||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||
#include "PassDetail.h"
|
||||
|
||||
#include "mhlo/IR/hlo_ops.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
|
@ -18,6 +17,7 @@
|
|||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "stablehlo/dialect/ChloOps.h"
|
||||
#include "stablehlo/dialect/StablehloOps.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -25,17 +25,15 @@ using namespace mlir::torch;
|
|||
using namespace mlir::torch::TorchConversion;
|
||||
|
||||
namespace {
|
||||
class VerifyMhloBackendContractPass
|
||||
: public VerifyMhloBackendContractBase<VerifyMhloBackendContractPass> {
|
||||
class VerifyStablehloBackendContractPass
|
||||
: public VerifyStablehloBackendContractBase<
|
||||
VerifyStablehloBackendContractPass> {
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
auto module = getOperation();
|
||||
TypeConverter converter;
|
||||
converter.addConversion([](Type type) -> Type {
|
||||
auto elemTy = type;
|
||||
if (isa<TensorType>(type)) {
|
||||
if (isa<TensorType>(type))
|
||||
elemTy = type.cast<TensorType>().getElementType();
|
||||
}
|
||||
if (BaseMemRefType::isValidElementType(elemTy))
|
||||
return type;
|
||||
return nullptr;
|
||||
|
@ -43,6 +41,7 @@ class VerifyMhloBackendContractPass
|
|||
|
||||
auto opHasLegalTypes = [&](Operation *op) { return converter.isLegal(op); };
|
||||
|
||||
MLIRContext *context = &getContext();
|
||||
ConversionTarget target(*context);
|
||||
|
||||
// Structural operations.
|
||||
|
@ -50,26 +49,16 @@ class VerifyMhloBackendContractPass
|
|||
// Shape operations.
|
||||
target.addDynamicallyLegalOp<shape::ShapeOfOp>(opHasLegalTypes);
|
||||
|
||||
target.addLegalDialect<mhlo::MhloDialect>();
|
||||
target.addLegalDialect<chlo::ChloDialect>();
|
||||
target.addLegalDialect<stablehlo::StablehloDialect>();
|
||||
target.addLegalDialect<tensor::TensorDialect>();
|
||||
target.addLegalDialect<arith::ArithDialect>();
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
if (failed(applyFullConversion(module, target, std::move(patterns)))) {
|
||||
// We avoid `module.emitError()` so that mlir-print-op-on-diagnostics
|
||||
// doesn't unnecessarily spew out the entire module.
|
||||
emitError(module.getLoc())
|
||||
<< "Module does not conform to the MHLO backend contract. "
|
||||
"See dialect conversion legality information above.";
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::torch::TorchConversion::createVerifyMhloBackendContractPass() {
|
||||
return std::make_unique<VerifyMhloBackendContractPass>();
|
||||
mlir::torch::TorchConversion::createVerifyStablehloBackendContractPass() {
|
||||
return std::make_unique<VerifyStablehloBackendContractPass>();
|
||||
}
|
||||
#endif // TORCH_MLIR_ENABLE_MHLO
|
||||
#endif // TORCH_MLIR_ENABLE_STABLEHLO
|
|
@ -20,6 +20,10 @@
|
|||
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
|
||||
#include "torch-mlir/RefBackend/Passes.h"
|
||||
|
||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||
#include "mhlo/transforms/passes.h"
|
||||
#endif
|
||||
|
||||
void mlir::torch::registerAllDialects(mlir::DialectRegistry ®istry) {
|
||||
registry.insert<mlir::func::FuncDialect>();
|
||||
registry.insert<mlir::torch::Torch::TorchDialect>();
|
||||
|
@ -34,4 +38,11 @@ void mlir::torch::registerAllPasses() {
|
|||
mlir::torch::registerConversionPasses();
|
||||
mlir::torch::RefBackend::registerRefBackendPasses();
|
||||
mlir::torch::TMTensor::registerPasses();
|
||||
|
||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||
mlir::mhlo::registerSymbolicShapeOptimizationPass();
|
||||
mlir::mhlo::registerStablehloLegalizeToHloPass();
|
||||
mlir::mhlo::registerChloLegalizeToHloPass();
|
||||
mlir::mhlo::registerHloLegalizeToLinalgPass();
|
||||
#endif // TORCH_MLIR_ENABLE_STABLEHLO
|
||||
}
|
||||
|
|
|
@ -44,9 +44,9 @@ class OutputType(Enum):
|
|||
# as taking the `TORCH` output type and lowering it to TOSA.
|
||||
TOSA = "tosa"
|
||||
|
||||
# This output type consists of `mhlo` dialect ops. It can be thought of
|
||||
# as taking the `TORCH` output type and lowering it to MHLO.
|
||||
MHLO = "mhlo"
|
||||
# This output type consists of `stablehlo` dialect ops. It can be thought of
|
||||
# as taking the `TORCH` output type and lowering it to StableHLO.
|
||||
STABLEHLO = "stablehlo"
|
||||
|
||||
# Raw output of the JIT IR importer. This is not expected to be useful
|
||||
# for end-users, but can be convenient for development or reporting bugs.
|
||||
|
@ -242,7 +242,7 @@ class ExampleArgs:
|
|||
BACKEND_LEGAL_OPS = {
|
||||
OutputType.TOSA: ['torch.aten.flatten.using_ints', 'torch.aten.native_layer_norm', 'torch.aten.linear'],
|
||||
OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints', ],
|
||||
OutputType.MHLO: [],
|
||||
OutputType.STABLEHLO: [],
|
||||
}
|
||||
|
||||
|
||||
|
@ -290,7 +290,7 @@ def compile(model: torch.nn.Module,
|
|||
|
||||
# We only allow `backend_legal_ops` to be specified for the `"torch"`
|
||||
# output type because the other output types actually invoke their
|
||||
# respective backends (Linalg, TOSA, or MHLO), and those backends have
|
||||
# respective backends (Linalg, TOSA, or STABLEHLO), and those backends have
|
||||
# very specific requirements about the ops which are legal.
|
||||
# See `BACKEND_LEGAL_OPS` for more details.
|
||||
if backend_legal_ops is not None:
|
||||
|
@ -404,14 +404,14 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
|
|||
print(mb.module)
|
||||
return mb.module
|
||||
|
||||
elif output_type == OutputType.MHLO:
|
||||
elif output_type == OutputType.STABLEHLO:
|
||||
run_pipeline_with_repro_report(
|
||||
mb.module,
|
||||
"builtin.module(torch-backend-to-mhlo-backend-pipeline)",
|
||||
"Lowering Torch Backend IR -> MHLO Backend IR")
|
||||
"builtin.module(torch-backend-to-stablehlo-backend-pipeline)",
|
||||
"Lowering Torch Backend IR -> StableHLO Backend IR")
|
||||
if verbose:
|
||||
print("\n====================")
|
||||
print("MHLO Backend IR")
|
||||
print("StableHLO Backend IR")
|
||||
print(mb.module)
|
||||
return mb.module
|
||||
raise Exception(f"Unknown OutputType: {output_type}")
|
||||
|
|
|
@ -7,6 +7,6 @@ from .lazy_tensor_core import LazyTensorCoreTestConfig
|
|||
from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig
|
||||
from .native_torch import NativeTorchTestConfig
|
||||
from .torchscript import TorchScriptTestConfig
|
||||
from .mhlo_backend import MhloBackendTestConfig
|
||||
from .stablehlo_backend import StablehloBackendTestConfig
|
||||
from .tosa_backend import TosaBackendTestConfig
|
||||
from .torchdynamo import TorchDynamoTestConfig
|
||||
|
|
|
@ -8,12 +8,8 @@ from typing import Any
|
|||
import torch
|
||||
import torch_mlir
|
||||
|
||||
from torch_mlir_e2e_test.mhlo_backends.abc import MhloBackend
|
||||
from torch_mlir_e2e_test.framework import (
|
||||
TestConfig,
|
||||
Trace,
|
||||
TraceItem
|
||||
)
|
||||
from torch_mlir_e2e_test.stablehlo_backends.abc import StablehloBackend
|
||||
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
|
||||
from torch_mlir_e2e_test.utils import convert_annotations_to_placeholders
|
||||
from .utils import (
|
||||
recursively_convert_to_numpy,
|
||||
|
@ -21,20 +17,20 @@ from .utils import (
|
|||
)
|
||||
|
||||
|
||||
class MhloBackendTestConfig(TestConfig):
|
||||
class StablehloBackendTestConfig(TestConfig):
|
||||
"""Base class for TestConfig's that are implemented with linalg-on-tensors.
|
||||
|
||||
This class handles all the common lowering that torch-mlir does before
|
||||
reaching the linalg-on-tensors abstraction level.
|
||||
"""
|
||||
def __init__(self, backend: MhloBackend):
|
||||
|
||||
def __init__(self, backend: StablehloBackend):
|
||||
super().__init__()
|
||||
self.backend = backend
|
||||
|
||||
def compile(self, program: torch.nn.Module) -> Any:
|
||||
example_args = convert_annotations_to_placeholders(program.forward)
|
||||
module = torch_mlir.compile(
|
||||
program, example_args, output_type="mhlo")
|
||||
module = torch_mlir.compile(program, example_args, output_type="stablehlo")
|
||||
|
||||
return self.backend.compile(module)
|
||||
|
||||
|
@ -46,7 +42,6 @@ class MhloBackendTestConfig(TestConfig):
|
|||
outputs = getattr(backend_module, item.symbol)(*numpy_inputs)
|
||||
output = recursively_convert_from_numpy(outputs)
|
||||
result.append(
|
||||
TraceItem(symbol=item.symbol,
|
||||
inputs=item.inputs,
|
||||
output=output))
|
||||
TraceItem(symbol=item.symbol, inputs=item.inputs, output=output)
|
||||
)
|
||||
return result
|
|
@ -10,29 +10,30 @@ import torch
|
|||
|
||||
from torch_mlir.ir import Module
|
||||
|
||||
# A type shared between the result of `MhloBackend.compile` and the
|
||||
# input to `MhloBackend.load`. Each backend will likely have a
|
||||
# A type shared between the result of `StablehloBackend.compile` and the
|
||||
# input to `StablehloBackend.load`. Each backend will likely have a
|
||||
# different definition of this type.
|
||||
CompiledArtifact = TypeVar('CompiledArtifact')
|
||||
CompiledArtifact = TypeVar("CompiledArtifact")
|
||||
|
||||
# A wrapper around a backend-specific loaded program representation
|
||||
# that uniformly translates the `x.method(...)` interface expected of
|
||||
# Torch modules into appropriate lower-level operations.
|
||||
Invoker = TypeVar('Invoker')
|
||||
Invoker = TypeVar("Invoker")
|
||||
|
||||
|
||||
class MhloBackend(abc.ABC):
|
||||
"""The interface to an MHLO backend.
|
||||
class StablehloBackend(abc.ABC):
|
||||
"""The interface to an StableHLO backend.
|
||||
|
||||
Backends are recommended to raise meaningful exceptions in case of error,
|
||||
ideally with easy reproduction instructions.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def compile(self, module: Module) -> CompiledArtifact:
|
||||
"""Compile the provided MLIR module into a compiled artifact.
|
||||
|
||||
The module adheres to the MHLO backend contract
|
||||
(see the VerifyMhloBackendContract pass).
|
||||
The module adheres to the StableHLO backend contract
|
||||
(see the VerifyStablehloBackendContract pass).
|
||||
|
||||
The compiled artifact can be any type, but must be correctly
|
||||
interpreted by the `load` method.
|
|
@ -7,28 +7,32 @@ from torch_mlir.ir import *
|
|||
from torch_mlir.passmanager import *
|
||||
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
|
||||
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
|
||||
RefBackendLinalgOnTensorsBackend,
|
||||
)
|
||||
|
||||
from .abc import MhloBackend
|
||||
from .abc import StablehloBackend
|
||||
|
||||
__all__ = [
|
||||
"LinalgOnTensorsMhloBackend",
|
||||
"LinalgOnTensorsStablehloBackend",
|
||||
]
|
||||
|
||||
class LinalgOnTensorsMhloBackend(MhloBackend):
|
||||
"""Main entry-point for the linalg-on-tensors based MHLO backend.
|
||||
|
||||
class LinalgOnTensorsStablehloBackend(StablehloBackend):
|
||||
"""Main entry-point for the linalg-on-tensors based StableHLO backend.
|
||||
|
||||
This currently uses the linalg-on-tensors RefBackend for actual execution.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.refbackend = RefBackendLinalgOnTensorsBackend()
|
||||
|
||||
def compile(self, imported_module: Module):
|
||||
"""Compiles an imported module that satisfied the MHLO backend contract.
|
||||
"""Compiles an imported module that satisfied the StableHLO backend contract.
|
||||
|
||||
Args:
|
||||
imported_module: The MLIR module consisting of funcs in the MHLO
|
||||
imported_module: The MLIR module consisting of funcs in the StableHLO
|
||||
dialect.
|
||||
Returns:
|
||||
An opaque, backend specific compiled artifact object that can be
|
||||
|
@ -36,8 +40,9 @@ class LinalgOnTensorsMhloBackend(MhloBackend):
|
|||
"""
|
||||
run_pipeline_with_repro_report(
|
||||
imported_module,
|
||||
"builtin.module(func.func(symbolic-shape-optimization),func.func(hlo-legalize-to-linalg),func.func(canonicalize))",
|
||||
"Lowering MLIR-HLO to Linalg-on-Tensors")
|
||||
"builtin.module(func.func(chlo-legalize-to-hlo),stablehlo-legalize-to-hlo,func.func(canonicalize,cse,symbolic-shape-optimization,hlo-legalize-to-linalg,canonicalize))",
|
||||
"Lowering StableHLO to Linalg-on-Tensors",
|
||||
)
|
||||
return self.refbackend.compile(imported_module)
|
||||
|
||||
def load(self, module):
|
|
@ -1,7 +1,7 @@
|
|||
llvm_canonicalize_cmake_booleans(
|
||||
MLIR_ENABLE_BINDINGS_PYTHON
|
||||
TORCH_MLIR_ENABLE_JIT_IR_IMPORTER
|
||||
TORCH_MLIR_ENABLE_MHLO
|
||||
TORCH_MLIR_ENABLE_STABLEHLO
|
||||
)
|
||||
|
||||
configure_lit_site_cfg(
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s
|
||||
// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s
|
||||
|
||||
|
||||
// -----
|
||||
|
@ -7,7 +7,7 @@
|
|||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[T1:.*]] = mhlo.copy %[[T0]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = stablehlo.convert %[[T0]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.clone$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
|
@ -19,7 +19,7 @@ func.func @torch.aten.clone$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vt
|
|||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> {
|
||||
// CHECK: %[[VAL_0:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[VAL_0:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<f32> -> !torch.vtensor<[],f32>
|
||||
// CHECK: return %[[VAL_1]] : !torch.vtensor<[],f32>
|
||||
func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> {
|
||||
|
@ -30,7 +30,7 @@ func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> {
|
|||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> {
|
||||
// CHECK: %[[VAL_0:.*]] = mhlo.constant dense<1> : tensor<2xi64>
|
||||
// CHECK: %[[VAL_0:.*]] = stablehlo.constant dense<1> : tensor<2xi64>
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<2xi64> -> !torch.vtensor<[2],si64>
|
||||
// CHECK: return %[[VAL_1]] : !torch.vtensor<[2],si64>
|
||||
func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> {
|
||||
|
@ -45,8 +45,8 @@ func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> {
|
|||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_i64 %[[INT1]]
|
||||
// CHECK: %[[T1:.*]] = tensor.from_elements %[[T0]] : tensor<1xi64>
|
||||
// CHECK: %[[T2:.*]] = mhlo.convert %[[T1]] : tensor<1xi64>
|
||||
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xi64>) -> tensor<i64>
|
||||
// CHECK: %[[T2:.*]] = stablehlo.convert %[[T1]] : tensor<1xi64>
|
||||
// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xi64>) -> tensor<i64>
|
||||
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<i64> -> !torch.vtensor<[],si64>
|
||||
// CHECK: return %[[T4]] : !torch.vtensor<[],si64>
|
||||
func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[], si64> {
|
||||
|
@ -75,7 +75,7 @@ func.func @torch.aten.contiguous(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vt
|
|||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = "chlo.constant_like"(%[[VAL_1]]) {value = 1.000000e+00 : f32} : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = mhlo.divide %[[VAL_2]], %[[VAL_1]] : tensor<?x?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = stablehlo.divide %[[VAL_2]], %[[VAL_1]] : tensor<?x?x?xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
|
||||
// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?,?],f32>
|
||||
func.func @torch.aten.reciprocal(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
|
||||
|
@ -91,7 +91,7 @@ func.func @torch.aten.reciprocal(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.v
|
|||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],f32> -> tensor<4x3xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[VAL_4:.*]] = "mhlo.transpose"(%[[VAL_1]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x3xf32>) -> tensor<3x4xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = stablehlo.transpose %[[VAL_1]], dims = [1, 0] : (tensor<4x3xf32>) -> tensor<3x4xf32>
|
||||
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
|
||||
// CHECK: return %[[VAL_5]] : !torch.vtensor<[3,4],f32>
|
||||
func.func @torch.aten.transpose$basic(%arg0: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,4],f32> {
|
||||
|
@ -118,7 +118,7 @@ func.func @torch.aten.transpose$basic(%arg0: !torch.vtensor<[4,3],f32>) -> !torc
|
|||
// CHECK: %[[VAL_7:.*]] = arith.constant 1 : index
|
||||
// CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1:.*]], %[[VAL_7]] : tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_9:.*]] = tensor.from_elements %[[VAL_4]], %[[VAL_6]], %[[VAL_8]] : tensor<3xindex>
|
||||
// CHECK: %[[VAL_10:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_1]], %[[VAL_9]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<3xindex>) -> tensor<8x4x?xf32>
|
||||
// CHECK: %[[VAL_10:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_1]], %[[VAL_9]], dims = [1, 2] : (tensor<?x?xf32>, tensor<3xindex>) -> tensor<8x4x?xf32>
|
||||
// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<8x4x?xf32> -> !torch.vtensor<[8,4,?],f32>
|
||||
// CHECK: return %[[VAL_11]] : !torch.vtensor<[8,4,?],f32>
|
||||
func.func @torch.aten.broadcast_to$dynamic_implicit(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[8,4,?],f32> {
|
||||
|
@ -135,15 +135,15 @@ func.func @torch.aten.broadcast_to$dynamic_implicit(%arg0: !torch.vtensor<[?,?],
|
|||
// CHECK-LABEL: func.func @torch.aten.batch_norm$training(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor<?x3x?x?xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<3xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<3xf32>
|
||||
// CHECK: %true = torch.constant.bool true
|
||||
// CHECK: %float1.000000e-01 = torch.constant.float 1.000000e-01
|
||||
// CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05
|
||||
// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
|
||||
// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x3x?x?xf32>
|
||||
// CHECK: %[[VAL_6:.*]] = tensor.from_elements %[[VAL_5]] : tensor<1xindex>
|
||||
// CHECK: %[[VAL_7:.*]], %[[VAL_8:.*]], %[[VAL_9:.*]] = "mhlo.batch_norm_training"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>)
|
||||
// CHECK: %[[VAL_7:.*]], %[[VAL_8:.*]], %[[VAL_9:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>)
|
||||
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
|
||||
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,3,?,?],f32>
|
||||
func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
|
||||
|
@ -161,8 +161,8 @@ func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>)
|
|||
// CHECK-LABEL: func.func @torch.aten.batch_norm$inference(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor<?x3x?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32>
|
||||
// CHECK: %[[T2:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32>
|
||||
// CHECK: %[[T1:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<3xf32>
|
||||
// CHECK: %[[T2:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<3xf32>
|
||||
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[FLOAT1:.*]].000000e-01 = torch.constant.float 1.000000e-01
|
||||
|
@ -171,7 +171,7 @@ func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>)
|
|||
// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x3x?x?xf32>
|
||||
// CHECK: %[[T4:.*]] = tensor.from_elements %[[T3]] : tensor<1xindex>
|
||||
// CHECK: %[[T5:.*]] = tensor.cast %[[T0]] : tensor<?x3x?x?xf32> to tensor<?x3x?x?xf32>
|
||||
// CHECK: %[[T6:.*]] = "mhlo.batch_norm_inference"(%[[T5]], %[[T2]], %[[T1]], %[[T1]], %[[T2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<?x3x?x?xf32>
|
||||
// CHECK: %[[T6:.*]] = "stablehlo.batch_norm_inference"(%[[T5]], %[[T2]], %[[T1]], %[[T1]], %[[T2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<?x3x?x?xf32>
|
||||
// CHECK: %[[T7:.*]] = tensor.cast %[[T6]] : tensor<?x3x?x?xf32> to tensor<?x3x?x?xf32>
|
||||
// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
|
||||
// CHECK: return %[[T8]] : !torch.vtensor<[?,3,?,?],f32>
|
||||
|
@ -192,19 +192,19 @@ func.func @torch.aten.batch_norm$inference(%arg0: !torch.vtensor<[?,3,?,?],f32>)
|
|||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor<?x3x?x?xf32>
|
||||
// CHECK: %none = torch.constant.none
|
||||
// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<3xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<3xf32>
|
||||
// CHECK: %true = torch.constant.bool true
|
||||
// CHECK: %float1.000000e-01 = torch.constant.float 1.000000e-01
|
||||
// CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05
|
||||
// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
|
||||
// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x3x?x?xf32>
|
||||
// CHECK: %[[VAL_6:.*]] = tensor.from_elements %[[VAL_5]] : tensor<1xindex>
|
||||
// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[VAL_8:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_7]], %[[VAL_6]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<3xf32>
|
||||
// CHECK: %[[VAL_9:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[VAL_10:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_9]], %[[VAL_6]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<3xf32>
|
||||
// CHECK: %[[VAL_11:.*]], %[[VAL_12:.*]], %[[VAL_13:.*]] = "mhlo.batch_norm_training"(%[[VAL_1]], %[[VAL_8]], %[[VAL_10]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>)
|
||||
// CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[VAL_8:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_6]], dims = [] : (tensor<f32>, tensor<1xindex>) -> tensor<3xf32>
|
||||
// CHECK: %[[VAL_9:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[VAL_10:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_9]], %[[VAL_6]], dims = [] : (tensor<f32>, tensor<1xindex>) -> tensor<3xf32>
|
||||
// CHECK: %[[VAL_11:.*]], %[[VAL_12:.*]], %[[VAL_13:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_8]], %[[VAL_10]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>)
|
||||
// CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
|
||||
// CHECK: return %[[VAL_14]] : !torch.vtensor<[?,3,?,?],f32>
|
||||
func.func @torch.aten.batch_norm$no_bias_weight(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
|
||||
|
@ -222,28 +222,28 @@ func.func @torch.aten.batch_norm$no_bias_weight(%arg0: !torch.vtensor<[?,3,?,?],
|
|||
// CHECK-LABEL: func @torch.aten.native_layer_norm(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,7,4,5],f32>) -> !torch.vtensor<[3,7,4,5],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,7,4,5],f32> -> tensor<3x7x4x5xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<4x5xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<4x5xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<4x5xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<4x5xf32>
|
||||
// CHECK: %int4 = torch.constant.int 4
|
||||
// CHECK: %int5 = torch.constant.int 5
|
||||
// CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05
|
||||
// CHECK: %true = torch.constant.bool true
|
||||
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<[1, 21, 20]> : tensor<3xi64>
|
||||
// CHECK: %[[VAL_6:.*]] = mhlo.dynamic_reshape %[[VAL_1]], %[[VAL_5]] : (tensor<3x7x4x5xf32>, tensor<3xi64>) -> tensor<1x21x20xf32>
|
||||
// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor<21xf32>
|
||||
// CHECK: %[[VAL_8:.*]] = mhlo.constant dense<0.000000e+00> : tensor<21xf32>
|
||||
// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]] = "mhlo.batch_norm_training"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) -> (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>)
|
||||
// CHECK: %[[VAL_12:.*]] = mhlo.constant dense<[3, 7, 4, 5]> : tensor<4xi64>
|
||||
// CHECK: %[[VAL_13:.*]] = mhlo.dynamic_reshape %[[VAL_9]], %[[VAL_12]] : (tensor<1x21x20xf32>, tensor<4xi64>) -> tensor<3x7x4x5xf32>
|
||||
// CHECK: %[[VAL_14:.*]] = mhlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64>
|
||||
// CHECK: %[[VAL_15:.*]] = mhlo.dynamic_reshape %[[VAL_10]], %[[VAL_14]] : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32>
|
||||
// CHECK: %[[VAL_16:.*]] = mhlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64>
|
||||
// CHECK: %[[VAL_17:.*]] = mhlo.dynamic_reshape %[[VAL_11]], %[[VAL_16]] : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32>
|
||||
// CHECK: %[[VAL_18:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_3]]) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32>
|
||||
// CHECK: %[[VAL_19:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_2]]) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32>
|
||||
// CHECK: %[[VAL_20:.*]] = mhlo.multiply %[[VAL_13]], %[[VAL_18]] : tensor<3x7x4x5xf32>
|
||||
// CHECK: %[[VAL_21:.*]] = mhlo.add %[[VAL_20]], %[[VAL_19]] : tensor<3x7x4x5xf32>
|
||||
// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<[1, 21, 20]> : tensor<3xi64>
|
||||
// CHECK: %[[VAL_6:.*]] = stablehlo.dynamic_reshape %[[VAL_1]], %[[VAL_5]] : (tensor<3x7x4x5xf32>, tensor<3xi64>) -> tensor<1x21x20xf32>
|
||||
// CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<21xf32>
|
||||
// CHECK: %[[VAL_8:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<21xf32>
|
||||
// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]] = "stablehlo.batch_norm_training"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) -> (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>)
|
||||
// CHECK: %[[VAL_12:.*]] = stablehlo.constant dense<[3, 7, 4, 5]> : tensor<4xi64>
|
||||
// CHECK: %[[VAL_13:.*]] = stablehlo.dynamic_reshape %[[VAL_9]], %[[VAL_12]] : (tensor<1x21x20xf32>, tensor<4xi64>) -> tensor<3x7x4x5xf32>
|
||||
// CHECK: %[[VAL_14:.*]] = stablehlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64>
|
||||
// CHECK: %[[VAL_15:.*]] = stablehlo.dynamic_reshape %[[VAL_10]], %[[VAL_14]] : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32>
|
||||
// CHECK: %[[VAL_16:.*]] = stablehlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64>
|
||||
// CHECK: %[[VAL_17:.*]] = stablehlo.dynamic_reshape %[[VAL_11]], %[[VAL_16]] : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32>
|
||||
// CHECK: %[[VAL_18:.*]] = stablehlo.broadcast_in_dim %[[VAL_3]], dims = [2, 3] : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32>
|
||||
// CHECK: %[[VAL_19:.*]] = stablehlo.broadcast_in_dim %[[VAL_2]], dims = [2, 3] : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32>
|
||||
// CHECK: %[[VAL_20:.*]] = stablehlo.multiply %[[VAL_13]], %[[VAL_18]] : tensor<3x7x4x5xf32>
|
||||
// CHECK: %[[VAL_21:.*]] = stablehlo.add %[[VAL_20]], %[[VAL_19]] : tensor<3x7x4x5xf32>
|
||||
// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21:.*]] : tensor<3x7x4x5xf32> -> !torch.vtensor<[3,7,4,5],f32>
|
||||
// CHECK: return %[[VAL_22]] : !torch.vtensor<[3,7,4,5],f32>
|
||||
func.func @torch.aten.native_layer_norm(%arg0: !torch.vtensor<[3,7,4,5],f32>) -> !torch.vtensor<[3,7,4,5],f32> {
|
||||
|
@ -267,8 +267,8 @@ func.func @torch.aten.native_layer_norm(%arg0: !torch.vtensor<[3,7,4,5],f32>) ->
|
|||
// CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[ARG0]], %[[ARG1]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],si32>) -> !torch.list<vtensor>
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T2:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
|
||||
// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : (tensor<?x?xi32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T4:.*]] = "mhlo.concatenate"(%[[T1]], %[[T3]]) {dimension = 0 : i64} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T3:.*]] = stablehlo.convert %[[T2]] : (tensor<?x?xi32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T4:.*]] = stablehlo.concatenate %[[T1]], %[[T3]], dim = 0 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.cat$convert(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],f32> {
|
||||
|
@ -287,7 +287,7 @@ func.func @torch.aten.cat$convert(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torc
|
|||
// CHECK: %[[VAL_0:.*]] = torch.prim.ListConstruct %[[ARG_0]], %[[ARG_1]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list<vtensor>
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = "mhlo.concatenate"(%[[VAL_1]], %[[VAL_2]]) {dimension = 0 : i64} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = stablehlo.concatenate %[[VAL_1]], %[[VAL_2]], dim = 0 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.cat(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s
|
||||
// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.gelu(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
|
@ -7,12 +7,12 @@
|
|||
// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) {value = 1.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T2:.*]] = "chlo.constant_like"(%[[T0]]) {value = 2.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T3:.*]] = "chlo.constant_like"(%[[T0]]) {value = 5.000000e-01 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T4:.*]] = mhlo.rsqrt %[[T2]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T5:.*]] = mhlo.multiply %[[T0]], %[[T4]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T4:.*]] = stablehlo.rsqrt %[[T2]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T5:.*]] = stablehlo.multiply %[[T0]], %[[T4]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T6:.*]] = chlo.erf %[[T5]] : tensor<?x?xf32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T7:.*]] = mhlo.add %[[T6]], %[[T1]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T8:.*]] = mhlo.multiply %[[T7]], %[[T3]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T9:.*]] = mhlo.multiply %[[T0]], %[[T8]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T7:.*]] = stablehlo.add %[[T6]], %[[T1]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T8:.*]] = stablehlo.multiply %[[T7]], %[[T3]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T9:.*]] = stablehlo.multiply %[[T0]], %[[T8]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T10]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.gelu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
|
@ -26,7 +26,7 @@ func.func @torch.aten.gelu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
|
|||
// CHECK-LABEL: func.func @torch.aten.tanh$basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = mhlo.tanh %[[T0]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = stablehlo.tanh %[[T0]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
|
@ -39,7 +39,7 @@ func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte
|
|||
// CHECK-LABEL: func.func @torch.aten.log$basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = mhlo.log %[[T0]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = stablehlo.log %[[T0]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.log$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
|
@ -52,7 +52,7 @@ func.func @torch.aten.log$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten
|
|||
// CHECK-LABEL: func.func @torch.aten.exp$basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = mhlo.exponential %[[T0]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = stablehlo.exponential %[[T0]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.exp$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
|
@ -65,7 +65,7 @@ func.func @torch.aten.exp$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten
|
|||
// CHECK-LABEL: func.func @torch.aten.neg$basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = mhlo.negate %[[T0]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = stablehlo.negate %[[T0]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.neg$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
|
@ -78,7 +78,7 @@ func.func @torch.aten.neg$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten
|
|||
// CHECK-LABEL: func.func @torch.aten.rsqrt$basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = mhlo.rsqrt %[[T0]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = stablehlo.rsqrt %[[T0]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
|
@ -91,7 +91,7 @@ func.func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vt
|
|||
// CHECK-LABEL: func.func @torch.aten.sigmoid$basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = mhlo.logistic %[[T0]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = stablehlo.logistic %[[T0]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.sigmoid$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
|
@ -108,8 +108,8 @@ func.func @torch.aten.sigmoid$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.
|
|||
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
|
||||
// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T4:.*]] = chlo.broadcast_add %[[T0]], %[[T3]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
|
||||
|
@ -130,11 +130,11 @@ func.func @torch.aten.addscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc
|
|||
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
|
||||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
|
||||
// CHECK: %[[T3:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[FROM_ELEMENTS_0:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64>
|
||||
// CHECK: %[[T5:.*]] = mhlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T6:.*]] = mhlo.reshape %[[T5]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T5:.*]] = stablehlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T6:.*]] = stablehlo.reshape %[[T5]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T7:.*]] = chlo.broadcast_multiply %[[T4]], %[[T6]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: %[[T8:.*]] = chlo.broadcast_add %[[T0]], %[[T7]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
|
@ -171,8 +171,8 @@ func.func @torch.aten.addtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
|
|||
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
|
||||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64>
|
||||
// CHECK: %[[T3:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T5:.*]] = chlo.broadcast_multiply %[[T1]], %[[T4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T6:.*]] = chlo.broadcast_add %[[T0]], %[[T5]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
|
@ -190,7 +190,7 @@ func.func @torch.aten.addtensor$alpha(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
|
|||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[T2:.*]] = mhlo.convert %[[T0]] : (tensor<?x?xi32>) -> tensor<?x?xi64>
|
||||
// CHECK: %[[T2:.*]] = stablehlo.convert %[[T0]] : (tensor<?x?xi32>) -> tensor<?x?xi64>
|
||||
// CHECK: %[[T3:.*]] = chlo.broadcast_add %[[T2]], %[[T1]] : (tensor<?x?xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
|
||||
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
|
||||
// CHECK: return %[[T4]] : !torch.vtensor<[?,?],si64>
|
||||
|
@ -209,8 +209,8 @@ func.func @torch.aten.addtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1
|
|||
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
|
||||
// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T4:.*]] = chlo.broadcast_subtract %[[T0]], %[[T3]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
|
||||
|
@ -230,8 +230,8 @@ func.func @torch.aten.subscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc
|
|||
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
|
||||
// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T4:.*]] = chlo.broadcast_subtract %[[T3]], %[[T0]] : (tensor<f32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
|
||||
|
@ -252,11 +252,11 @@ func.func @torch.aten.rsubscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !tor
|
|||
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
|
||||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
|
||||
// CHECK: %[[T3:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[FROM_ELEMENTS_0:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64>
|
||||
// CHECK: %[[T5:.*]] = mhlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T6:.*]] = mhlo.reshape %[[T5]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T5:.*]] = stablehlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T6:.*]] = stablehlo.reshape %[[T5]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T7:.*]] = chlo.broadcast_multiply %[[T4]], %[[T6]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: %[[T8:.*]] = chlo.broadcast_subtract %[[T0]], %[[T7]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
|
@ -293,8 +293,8 @@ func.func @torch.aten.subtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
|
|||
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
|
||||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64>
|
||||
// CHECK: %[[T3:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T5:.*]] = chlo.broadcast_multiply %[[T1]], %[[T4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T6:.*]] = chlo.broadcast_subtract %[[T0]], %[[T5]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
|
@ -312,7 +312,7 @@ func.func @torch.aten.subtensor$alpha(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
|
|||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[T2:.*]] = mhlo.convert %[[T0]] : (tensor<?x?xi32>) -> tensor<?x?xi64>
|
||||
// CHECK: %[[T2:.*]] = stablehlo.convert %[[T0]] : (tensor<?x?xi32>) -> tensor<?x?xi64>
|
||||
// CHECK: %[[T3:.*]] = chlo.broadcast_subtract %[[T2]], %[[T1]] : (tensor<?x?xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
|
||||
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
|
||||
// CHECK: return %[[T4]] : !torch.vtensor<[?,?],si64>
|
||||
|
@ -330,8 +330,8 @@ func.func @torch.aten.subtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1
|
|||
// CHECK: %[[INT9:.*]] = torch.constant.int 9
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
|
||||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
|
||||
// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T4:.*]] = chlo.broadcast_multiply %[[T0]], %[[T3]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
|
||||
|
@ -363,8 +363,8 @@ func.func @torch.aten.multensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
|
|||
// CHECK: %[[INT9:.*]] = torch.constant.int 9
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
|
||||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
|
||||
// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T4:.*]] = chlo.broadcast_divide %[[T0]], %[[T3]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
|
||||
|
@ -396,8 +396,8 @@ func.func @torch.aten.divtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
|
|||
// CHECK: %[[INT3:.*]] = torch.constant.int 3
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT3]]
|
||||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
|
||||
// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T4:.*]] = chlo.broadcast_compare %[[T0]], %[[T3]] {compare_type = #chlo<comparison_type FLOAT>, comparison_direction = #chlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
|
||||
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
|
||||
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],i1>
|
||||
|
@ -471,7 +471,7 @@ func.func @torch.aten.ne.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.
|
|||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT0]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T2:.*]] = "mhlo.transpose"(%[[T0]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x64xf32>) -> tensor<64x4xf32>
|
||||
// CHECK: %[[T2:.*]] = stablehlo.transpose %[[T0]], dims = [1, 0] : (tensor<4x64xf32>) -> tensor<64x4xf32>
|
||||
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<64x4xf32> -> !torch.vtensor<[64,4],f32>
|
||||
// CHECK: return %[[T3]] : !torch.vtensor<[64,4],f32>
|
||||
func.func @torch.aten.permute$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[64,4],f32> {
|
||||
|
@ -488,7 +488,7 @@ func.func @torch.aten.permute$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch
|
|||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) {value = 0.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T2:.*]] = mhlo.maximum %[[T0]], %[[T1]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T2:.*]] = stablehlo.maximum %[[T0]], %[[T1]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T3]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.relu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
|
@ -503,11 +503,11 @@ func.func @torch.aten.relu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
|
|||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_f64 %[[ARG1]]
|
||||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xf64>
|
||||
// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xf64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xf64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[FROM_ELEMENTS_0:.*]] = tensor.from_elements %[[T1]] : tensor<1xf64>
|
||||
// CHECK: %[[T4:.*]] = mhlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xf64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T5:.*]] = mhlo.reshape %[[T4]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T4:.*]] = stablehlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xf64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T5:.*]] = stablehlo.reshape %[[T4]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T6:.*]] = chlo.broadcast_multiply %[[T3]], %[[T5]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: %[[T7:.*]] = chlo.broadcast_add %[[T0]], %[[T6]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
|
@ -525,8 +525,8 @@ func.func @torch.aten.addscalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1
|
|||
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T2:.*]] = torch_c.to_f64 %[[ARG2]]
|
||||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]] : tensor<1xf64>
|
||||
// CHECK: %[[T3:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xf64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xf64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T5:.*]] = chlo.broadcast_multiply %[[T1]], %[[T4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T6:.*]] = chlo.broadcast_add %[[T0]], %[[T5]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
|
@ -543,8 +543,8 @@ func.func @torch.aten.addtensor$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1
|
|||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[ARG1]]
|
||||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
|
||||
// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T4:.*]] = chlo.broadcast_multiply %[[T0]], %[[T3]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
|
||||
|
@ -560,8 +560,8 @@ func.func @torch.aten.mulscalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1
|
|||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[ARG1]]
|
||||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
|
||||
// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T4:.*]] = chlo.broadcast_divide %[[T0]], %[[T3]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
|
||||
|
@ -577,8 +577,8 @@ func.func @torch.aten.divscalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1
|
|||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[ARG1]]
|
||||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
|
||||
// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T4:.*]] = chlo.broadcast_compare %[[T0]], %[[T3]] {compare_type = #chlo<comparison_type FLOAT>, comparison_direction = #chlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
|
||||
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
|
||||
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],i1>
|
||||
|
@ -595,10 +595,10 @@ func.func @torch.aten.gt.scalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1
|
|||
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[STR:.*]] = torch.constant.str "trunc"
|
||||
// CHECK: %[[T2:.*]] = chlo.broadcast_divide %[[T0]], %[[T1]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[T3:.*]] = mhlo.sign %[[T2]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[T4:.*]] = mhlo.abs %[[T2]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[T5:.*]] = mhlo.floor %[[T4]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[T6:.*]] = mhlo.multiply %[[T3]], %[[T5]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[T3:.*]] = stablehlo.sign %[[T2]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[T4:.*]] = stablehlo.abs %[[T2]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[T5:.*]] = stablehlo.floor %[[T4]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[T6:.*]] = stablehlo.multiply %[[T3]], %[[T5]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
|
||||
// CHECK: return %[[T7]] : !torch.vtensor<[?,?,?,?],f32>
|
||||
func.func @torch.aten.div.Tensor_mode$trunc(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
|
||||
|
@ -615,7 +615,7 @@ func.func @torch.aten.div.Tensor_mode$trunc(%arg0: !torch.vtensor<[?,?,?,?],f32>
|
|||
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[STR:.*]] = torch.constant.str "floor"
|
||||
// CHECK: %[[T2:.*]] = chlo.broadcast_divide %[[T0]], %[[T1]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[T3:.*]] = mhlo.floor %[[T2]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[T3:.*]] = stablehlo.floor %[[T2]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
|
||||
// CHECK: return %[[T4]] : !torch.vtensor<[?,?,?,?],f32>
|
||||
func.func @torch.aten.div.Tensor_mode$floor(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s
|
||||
// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.index_select$basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,4],f32>, %[[ARG1:.*]]: !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,4],f32> {
|
||||
|
@ -10,8 +10,8 @@
|
|||
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x4xf32>
|
||||
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
|
||||
// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
|
||||
// CHECK: %[[T5:.*]] = "mhlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #mhlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false} : (tensor<?x4xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<2x4xf32>
|
||||
// CHECK: %[[T6:.*]] = mhlo.convert %[[T5]] : tensor<2x4xf32>
|
||||
// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false} : (tensor<?x4xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<2x4xf32>
|
||||
// CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor<2x4xf32>
|
||||
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32>
|
||||
// CHECK: return %[[T7]] : !torch.vtensor<[2,4],f32>
|
||||
func.func @torch.aten.index_select$basic(%arg0: !torch.vtensor<[?,4],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,4],f32> {
|
||||
|
@ -31,8 +31,8 @@ func.func @torch.aten.index_select$basic(%arg0: !torch.vtensor<[?,4],f32>, %arg1
|
|||
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
|
||||
// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
|
||||
// CHECK: %[[T5:.*]] = "mhlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #mhlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false} : (tensor<?x?xf32>, tensor<?xi64>, tensor<2xi64>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T6:.*]] = mhlo.convert %[[T5]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false} : (tensor<?x?xf32>, tensor<?xi64>, tensor<2xi64>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T7]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.embedding$basic(%weight: !torch.vtensor<[?,?],f32>, %indices: !torch.vtensor<[?], si64>) -> !torch.vtensor<[?,?],f32> {
|
||||
|
@ -53,8 +53,8 @@ func.func @torch.aten.embedding$basic(%weight: !torch.vtensor<[?,?],f32>, %indic
|
|||
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
|
||||
// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
|
||||
// CHECK: %[[T5:.*]] = "mhlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #mhlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false} : (tensor<?x?xf32>, tensor<?x1xi64>, tensor<2xi64>) -> tensor<?x1x?xf32>
|
||||
// CHECK: %[[T6:.*]] = mhlo.convert %[[T5]] : tensor<?x1x?xf32>
|
||||
// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false} : (tensor<?x?xf32>, tensor<?x1xi64>, tensor<2xi64>) -> tensor<?x1x?xf32>
|
||||
// CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor<?x1x?xf32>
|
||||
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x1x?xf32> -> !torch.vtensor<[?,1,?],f32>
|
||||
// CHECK: return %[[T7]] : !torch.vtensor<[?,1,?],f32>
|
||||
func.func @torch.aten.embedding$rank_two_indices(%weight: !torch.vtensor<[?,?],f32>, %indices: !torch.vtensor<[?,1], si64>) -> !torch.vtensor<[?,1,?],f32> {
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s
|
||||
// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.mm$basic$static(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,3],f32>, %[[ARG1:.*]]: !torch.vtensor<[3,3],f32>) -> !torch.vtensor<[2,3],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[3,3],f32> -> tensor<3x3xf32>
|
||||
// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<2x3xf32>, tensor<3x3xf32>) -> tensor<2x3xf32>
|
||||
// CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor<2x3xf32>, tensor<3x3xf32>) -> tensor<2x3xf32>
|
||||
// CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<2x3xf32> to tensor<2x3xf32>
|
||||
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
|
||||
// CHECK: return %[[T4]] : !torch.vtensor<[2,3],f32>
|
||||
|
@ -19,7 +19,7 @@ func.func @torch.aten.mm$basic$static(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !
|
|||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3],f32>, %[[ARG1:.*]]: !torch.vtensor<[3,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,3],f32> -> tensor<?x3xf32>
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[3,?],f32> -> tensor<3x?xf32>
|
||||
// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<?x3xf32>, tensor<3x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor<?x3xf32>, tensor<3x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<?x?xf32> to tensor<?x?xf32>
|
||||
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T4]] : !torch.vtensor<[?,?],f32>
|
||||
|
@ -44,8 +44,8 @@ func.func @torch.aten.mm$basic$dynamic(%arg0: !torch.vtensor<[?,3],f32>, %arg1:
|
|||
// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C2]] : tensor<10x4x5xf32>
|
||||
// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
|
||||
// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
|
||||
// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<10x4x5xf32>, tensor<3xi64>) -> tensor<10x4x5xf32>
|
||||
// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<10x3x4xf32>, tensor<10x4x5xf32>) -> tensor<10x3x5xf32>
|
||||
// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [0, 1, 2] : (tensor<10x4x5xf32>, tensor<3xi64>) -> tensor<10x4x5xf32>
|
||||
// CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<10x3x4xf32>, tensor<10x4x5xf32>) -> tensor<10x3x5xf32>
|
||||
// CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<10x3x5xf32> to tensor<10x3x5xf32>
|
||||
// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<10x3x5xf32> -> !torch.vtensor<[10,3,5],f32>
|
||||
// CHECK: return %[[T12]] : !torch.vtensor<[10,3,5],f32>
|
||||
|
@ -70,8 +70,8 @@ func.func @torch.aten.bmm$basic$static(%arg0: !torch.vtensor<[10,3,4],f32>, %arg
|
|||
// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C2]] : tensor<?x4x?xf32>
|
||||
// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
|
||||
// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
|
||||
// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<?x4x?xf32>, tensor<3xi64>) -> tensor<?x4x?xf32>
|
||||
// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<?x?x4xf32>, tensor<?x4x?xf32>) -> tensor<?x?x?xf32>
|
||||
// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [0, 1, 2] : (tensor<?x4x?xf32>, tensor<3xi64>) -> tensor<?x4x?xf32>
|
||||
// CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<?x?x4xf32>, tensor<?x4x?xf32>) -> tensor<?x?x?xf32>
|
||||
// CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
|
||||
// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
|
||||
// CHECK: return %[[T12]] : !torch.vtensor<[?,?,?],f32>
|
||||
|
@ -96,8 +96,8 @@ func.func @torch.aten.bmm$basic$dynamic(%arg0: !torch.vtensor<[?,?,4],f32>, %arg
|
|||
// CHECK: %[[T6:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<256x120xf32>
|
||||
// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
|
||||
// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
|
||||
// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T0]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x120xf32>, tensor<3xi64>) -> tensor<4x256x120xf32>
|
||||
// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T9]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<4x256x120xf32>, tensor<4x120x256xf32>) -> tensor<4x256x256xf32>
|
||||
// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T0]], %[[T8]], dims = [1, 2] : (tensor<256x120xf32>, tensor<3xi64>) -> tensor<4x256x120xf32>
|
||||
// CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T9]], %[[T1]]) {dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<4x256x120xf32>, tensor<4x120x256xf32>) -> tensor<4x256x256xf32>
|
||||
// CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<4x256x256xf32> to tensor<4x256x256xf32>
|
||||
// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x256x256xf32> -> !torch.vtensor<[4,256,256],f32>
|
||||
// CHECK: return %[[T12]] : !torch.vtensor<[4,256,256],f32>
|
||||
|
@ -122,8 +122,8 @@ func.func @torch.aten.matmul$basic$static(%arg0: !torch.vtensor<[256,120],f32>,
|
|||
// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<256x?xf32>
|
||||
// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
|
||||
// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
|
||||
// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x?xf32>, tensor<3xi64>) -> tensor<4x256x?xf32>
|
||||
// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<4x?x256xf32>, tensor<4x256x?xf32>) -> tensor<4x?x?xf32>
|
||||
// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [1, 2] : (tensor<256x?xf32>, tensor<3xi64>) -> tensor<4x256x?xf32>
|
||||
// CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<4x?x256xf32>, tensor<4x256x?xf32>) -> tensor<4x?x?xf32>
|
||||
// CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<4x?x?xf32> to tensor<4x?x?xf32>
|
||||
// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x?x?xf32> -> !torch.vtensor<[4,?,?],f32>
|
||||
// CHECK: return %[[T12]] : !torch.vtensor<[4,?,?],f32>
|
||||
|
@ -145,8 +145,8 @@ func.func @torch.aten.matmul$basic$dynamic(%arg0: !torch.vtensor<[4,?,256],f32>,
|
|||
// CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor<256xf32>
|
||||
// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64
|
||||
// CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64>
|
||||
// CHECK: %[[T7:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32>
|
||||
// CHECK: %[[T8:.*]] = "mhlo.dot_general"(%[[T0]], %[[T7]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<1x?x256xf32>, tensor<1x256xf32>) -> tensor<1x?xf32>
|
||||
// CHECK: %[[T7:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T6]], dims = [1] : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32>
|
||||
// CHECK: %[[T8:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T7]]) {dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<1x?x256xf32>, tensor<1x256xf32>) -> tensor<1x?xf32>
|
||||
// CHECK: %[[T9:.*]] = tensor.cast %[[T8]] : tensor<1x?xf32> to tensor<1x?xf32>
|
||||
// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<1x?xf32> -> !torch.vtensor<[1,?],f32>
|
||||
// CHECK: return %[[T10]] : !torch.vtensor<[1,?],f32>
|
||||
|
@ -168,8 +168,8 @@ func.func @torch.aten.matmul$3dx1d(%arg0: !torch.vtensor<[1,?,256],f32>, %arg1:
|
|||
// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0_0]] : tensor<256xf32>
|
||||
// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64
|
||||
// CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64>
|
||||
// CHECK: %[[T7:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T0]], %[[T6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, tensor<2xi64>) -> tensor<?x256xf32>
|
||||
// CHECK: %[[T8:.*]] = "mhlo.dot_general"(%[[T7]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [1]>} : (tensor<?x256xf32>, tensor<?x256x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T7:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T0]], %[[T6]], dims = [1] : (tensor<256xf32>, tensor<2xi64>) -> tensor<?x256xf32>
|
||||
// CHECK: %[[T8:.*]] = "stablehlo.dot_general"(%[[T7]], %[[T1]]) {dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [1]>} : (tensor<?x256xf32>, tensor<?x256x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[T9:.*]] = tensor.cast %[[T8]] : tensor<?x?xf32> to tensor<?x?xf32>
|
||||
// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T10]] : !torch.vtensor<[?,?],f32>
|
||||
|
@ -184,7 +184,7 @@ func.func @torch.aten.matmul$1dx3d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor
|
|||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256],f32>) -> !torch.vtensor<[?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,256],f32> -> tensor<?x256xf32>
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32>
|
||||
// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<?x256xf32>, tensor<256xf32>) -> tensor<?xf32>
|
||||
// CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor<?x256xf32>, tensor<256xf32>) -> tensor<?xf32>
|
||||
// CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<?xf32> to tensor<?xf32>
|
||||
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?xf32> -> !torch.vtensor<[?],f32>
|
||||
// CHECK: return %[[T4]] : !torch.vtensor<[?],f32>
|
||||
|
@ -199,7 +199,7 @@ func.func @torch.aten.matmul$2dx1d(%arg0: !torch.vtensor<[?,256],f32>, %arg1: !t
|
|||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32>
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256,?],f32> -> tensor<256x?xf32>
|
||||
// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<256xf32>, tensor<256x?xf32>) -> tensor<?xf32>
|
||||
// CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor<256xf32>, tensor<256x?xf32>) -> tensor<?xf32>
|
||||
// CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<?xf32> to tensor<?xf32>
|
||||
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?xf32> -> !torch.vtensor<[?],f32>
|
||||
// CHECK: return %[[T4]] : !torch.vtensor<[?],f32>
|
||||
|
@ -214,7 +214,7 @@ func.func @torch.aten.matmul$1dx2d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor
|
|||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256],f32>) -> !torch.vtensor<[],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32>
|
||||
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32>
|
||||
// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<256xf32>, tensor<256xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor<256xf32>, tensor<256xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<f32> to tensor<f32>
|
||||
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<f32> -> !torch.vtensor<[],f32>
|
||||
// CHECK: return %[[T4]] : !torch.vtensor<[],f32>
|
||||
|
@ -228,7 +228,7 @@ func.func @torch.aten.matmul$1dx1d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor
|
|||
// CHECK-LABEL: func.func @torch.aten.matmul$proj(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,256],f32>) -> !torch.vtensor<[?,?,256],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,256],f32> -> tensor<?x?x256xf32>
|
||||
// CHECK: %[[T1:.*]] = mhlo.constant dense<1.000000e+00> : tensor<256x256xf32>
|
||||
// CHECK: %[[T1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<256x256xf32>
|
||||
// CHECK: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x256xf32>
|
||||
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
|
||||
|
@ -239,8 +239,8 @@ func.func @torch.aten.matmul$1dx1d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor
|
|||
// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<256x256xf32>
|
||||
// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
|
||||
// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
|
||||
// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x256xf32>, tensor<3xi64>) -> tensor<?x256x256xf32>
|
||||
// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<?x?x256xf32>, tensor<?x256x256xf32>) -> tensor<?x?x256xf32>
|
||||
// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [1, 2] : (tensor<256x256xf32>, tensor<3xi64>) -> tensor<?x256x256xf32>
|
||||
// CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<?x?x256xf32>, tensor<?x256x256xf32>) -> tensor<?x?x256xf32>
|
||||
// CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<?x?x256xf32> to tensor<?x?x256xf32>
|
||||
// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<?x?x256xf32> -> !torch.vtensor<[?,?,256],f32>
|
||||
// CHECK: return %[[T12]] : !torch.vtensor<[?,?,256],f32>
|
||||
|
@ -255,8 +255,8 @@ func.func @torch.aten.matmul$proj(%arg0: !torch.vtensor<[?,?,256],f32>) -> !torc
|
|||
// CHECK-LABEL: func.func @torch.aten.mm$proj(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,256],f32>) -> !torch.vtensor<[?,256],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,256],f32> -> tensor<?x256xf32>
|
||||
// CHECK: %[[T1:.*]] = mhlo.constant dense<1.000000e+00> : tensor<256x256xf32>
|
||||
// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<?x256xf32>, tensor<256x256xf32>) -> tensor<?x256xf32>
|
||||
// CHECK: %[[T1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<256x256xf32>
|
||||
// CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor<?x256xf32>, tensor<256x256xf32>) -> tensor<?x256xf32>
|
||||
// CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<?x256xf32> to tensor<?x256xf32>
|
||||
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x256xf32> -> !torch.vtensor<[?,256],f32>
|
||||
// CHECK: return %[[T4]] : !torch.vtensor<[?,256],f32>
|
||||
|
@ -284,7 +284,7 @@ func.func @torch.aten.mm$proj(%arg0: !torch.vtensor<[?,256],f32>) -> !torch.vten
|
|||
// CHECK: %[[T_11:.*]] = torch.prim.ListConstruct %[[T_7]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T_12:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
// CHECK: %[[T_13:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[T_14:.*]] = mhlo.convolution(%[[T_0]], %[[T_1]])
|
||||
// CHECK: %[[T_14:.*]] = stablehlo.convolution(%[[T_0]], %[[T_1]])
|
||||
// CHECK-SAME{LITERAL}: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 1], pad = [[4, 4], [2, 2]], rhs_dilate = [3, 1]} {batch_group_count = 1 : i64, feature_group_count = 3 : i64} : (tensor<?x?x?x?xf32>, tensor<?x?x3x3xf32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[T_15:.*]] = torch_c.from_builtin_tensor %[[T_14]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
|
||||
// CHECK: return %[[T_15]] : !torch.vtensor<[?,?,?,?],f32>
|
||||
|
@ -321,14 +321,14 @@ func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !
|
|||
// CHECK: %[[T_6:.*]] = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T_7:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
// CHECK: %false = torch.constant.bool false
|
||||
// CHECK: %[[T_8:.*]] = mhlo.convolution(%[[T_0]], %[[T_1]])
|
||||
// CHECK: %[[T_8:.*]] = stablehlo.convolution(%[[T_0]], %[[T_1]])
|
||||
// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 1], pad = [[4, 4], [2, 2]], rhs_dilate = [3, 1]} {batch_group_count = 1 : i64, feature_group_count = 3 : i64} : (tensor<?x?x?x?xf32>, tensor<?x?x3x3xf32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
|
||||
// CHECK: %[[T_9:.*]] = tensor.dim %[[T_2]], %[[IDX_0]] : tensor<?xf32>
|
||||
// CHECK: %[[T_10:.*]] = arith.index_cast %[[T_9]] : index to i64
|
||||
// CHECK: %[[VAL_0:.*]] = arith.constant 1 : i64
|
||||
// CHECK: %[[T_11:.*]] = tensor.from_elements %[[T_10]], %[[VAL_0]], %[[VAL_0]] : tensor<3xi64>
|
||||
// CHECK: %[[T_12:.*]] = mhlo.dynamic_reshape %[[T_2]], %[[T_11]] : (tensor<?xf32>, tensor<3xi64>) -> tensor<?x1x1xf32>
|
||||
// CHECK: %[[T_12:.*]] = stablehlo.dynamic_reshape %[[T_2]], %[[T_11]] : (tensor<?xf32>, tensor<3xi64>) -> tensor<?x1x1xf32>
|
||||
// CHECK: %[[T_13:.*]] = chlo.broadcast_add %[[T_8]], %[[T_12]] : (tensor<?x?x?x?xf32>, tensor<?x1x1xf32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[T_14:.*]] = torch_c.from_builtin_tensor %[[T_13]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
|
||||
// CHECK: return %[[T_14]] : !torch.vtensor<[?,?,?,?],f32>
|
||||
|
@ -360,8 +360,8 @@ func.func @torch.aten.convolution$bias(%arg0: !torch.vtensor<[?,?,?,?],f32>, %ar
|
|||
// CHECK: %[[T_2:.*]] = torch_c.to_i64 %int1
|
||||
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T_5:.*]] = "mhlo.reverse"(%[[T_1]]) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x4x3x3xf32>) -> tensor<2x4x3x3xf32>
|
||||
// CHECK: %[[T_6:.*]] = mhlo.convolution(%[[T_0]], %[[T_5]])
|
||||
// CHECK: %[[T_5:.*]] = stablehlo.reverse %[[T_1]], dims = [2, 3] : tensor<2x4x3x3xf32>
|
||||
// CHECK: %[[T_6:.*]] = stablehlo.convolution(%[[T_0]], %[[T_5]])
|
||||
// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<2x4x3x3xf32>) -> tensor<1x4x9x9xf32>
|
||||
// CHECK: %[[T_7:.*]] = torch_c.from_builtin_tensor %[[T_6]] : tensor<1x4x9x9xf32> -> !torch.vtensor<[1,4,9,9],f32>
|
||||
// CHECK: return %[[T_7]] : !torch.vtensor<[1,4,9,9],f32>
|
||||
|
@ -392,8 +392,8 @@ func.func @torch.aten.convolution$transposed_basic(%arg0: !torch.vtensor<[1,2,7,
|
|||
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T_6:.*]] = "mhlo.reverse"(%1) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x4x3x3xf32>) -> tensor<2x4x3x3xf32>
|
||||
// CHECK: %[[T_7:.*]] = mhlo.convolution(%[[T_0]], %[[T_6]])
|
||||
// CHECK: %[[T_6:.*]] = stablehlo.reverse %1, dims = [2, 3] : tensor<2x4x3x3xf32>
|
||||
// CHECK: %[[T_7:.*]] = stablehlo.convolution(%[[T_0]], %[[T_6]])
|
||||
// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<2x4x3x3xf32>) -> tensor<1x4x15x15xf32>
|
||||
// CHECK: %[[T_8:.*]] = torch_c.from_builtin_tensor %[[T_7]] : tensor<1x4x15x15xf32> -> !torch.vtensor<[1,4,15,15],f32>
|
||||
// CHECK: return %[[T_8]] : !torch.vtensor<[1,4,15,15],f32>
|
||||
|
@ -426,11 +426,11 @@ func.func @torch.aten.convolution$transposed_stride(%arg0: !torch.vtensor<[1,2,7
|
|||
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T_6:.*]] = "mhlo.reverse"(%[[T_1]]) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x4x3x3xf32>) -> tensor<2x4x3x3xf32>
|
||||
// CHECK: %[[T_7:.*]] = mhlo.convolution(%[[T_0]], %[[T_6]])
|
||||
// CHECK: %[[T_6:.*]] = stablehlo.reverse %[[T_1]], dims = [2, 3] : tensor<2x4x3x3xf32>
|
||||
// CHECK: %[[T_7:.*]] = stablehlo.convolution(%[[T_0]], %[[T_6]])
|
||||
// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<2x4x3x3xf32>) -> tensor<1x4x15x15xf32>
|
||||
// CHECK: %[[T_8:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[T_9:.*]] = "mhlo.pad"(%[[T_7]], %[[T_8]]) {edge_padding_high = dense<[0, 0, 1, 1]> : vector<4xi64>, edge_padding_low = dense<0> : vector<4xi64>, interior_padding = dense<0> : vector<4xi64>} : (tensor<1x4x15x15xf32>, tensor<f32>) -> tensor<1x4x16x16xf32>
|
||||
// CHECK: %[[T_8:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[T_9:.*]] = stablehlo.pad %[[T_7]], %[[T_8]], low = [0, 0, 0, 0], high = [0, 0, 1, 1], interior = [0, 0, 0, 0] : (tensor<1x4x15x15xf32>, tensor<f32>) -> tensor<1x4x16x16xf32>
|
||||
// CHECK: %[[T_10:.*]] = torch_c.from_builtin_tensor %[[T_9:.*]] : tensor<1x4x16x16xf32> -> !torch.vtensor<[1,4,16,16],f32>
|
||||
// CHECK: return %[[T_10]] : !torch.vtensor<[1,4,16,16],f32>
|
||||
func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor<[1,2,7,7],f32>, %arg1: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,16,16],f32> {
|
||||
|
@ -462,7 +462,7 @@ func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor
|
|||
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T_6:.*]] = "mhlo.reverse"(%1) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x2x3x3xf32>) -> tensor<2x2x3x3xf32>
|
||||
// CHECK: %[[T_6:.*]] = stablehlo.reverse %1, dims = [2, 3] : tensor<2x2x3x3xf32>
|
||||
// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
|
||||
// CHECK: %[[T_7:.*]] = tensor.dim %[[T_6]], %[[IDX_0]] : tensor<2x2x3x3xf32>
|
||||
// CHECK: %[[T_8:.*]] = arith.index_cast %[[T_7]] : index to i64
|
||||
|
@ -479,11 +479,11 @@ func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor
|
|||
// CHECK: %[[T_15:.*]] = arith.divsi %[[T_8]], %[[T_24]] : i64
|
||||
// CHECK: %[[T_16:.*]] = arith.muli %[[T_10]], %[[T_24]] : i64
|
||||
// CHECK: %[[T_17:.*]] = tensor.from_elements %[[T_24]], %[[T_15]], %[[T_10]], %[[T_12]], %[[T_14]] : tensor<5xi64>
|
||||
// CHECK: %[[T_18:.*]] = mhlo.dynamic_reshape %[[T_6]], %[[T_17]] : (tensor<2x2x3x3xf32>, tensor<5xi64>) -> tensor<2x1x2x3x3xf32>
|
||||
// CHECK: %[[T_19:.*]] = "mhlo.transpose"(%[[T_18]]) {permutation = dense<[1, 0, 2, 3, 4]> : tensor<5xi64>} : (tensor<2x1x2x3x3xf32>) -> tensor<1x2x2x3x3xf32>
|
||||
// CHECK: %[[T_18:.*]] = stablehlo.dynamic_reshape %[[T_6]], %[[T_17]] : (tensor<2x2x3x3xf32>, tensor<5xi64>) -> tensor<2x1x2x3x3xf32>
|
||||
// CHECK: %[[T_19:.*]] = stablehlo.transpose %[[T_18]], dims = [1, 0, 2, 3, 4] : (tensor<2x1x2x3x3xf32>) -> tensor<1x2x2x3x3xf32>
|
||||
// CHECK: %[[T_20:.*]] = tensor.from_elements %[[T_15]], %[[T_16]], %[[T_12]], %[[T_14]] : tensor<4xi64>
|
||||
// CHECK: %[[T_21:.*]] = mhlo.dynamic_reshape %[[T_19]], %[[T_20]] : (tensor<1x2x2x3x3xf32>, tensor<4xi64>) -> tensor<1x4x3x3xf32>
|
||||
// CHECK: %[[T_22:.*]] = mhlo.convolution(%[[T_0]], %[[T_21]])
|
||||
// CHECK: %[[T_21:.*]] = stablehlo.dynamic_reshape %[[T_19]], %[[T_20]] : (tensor<1x2x2x3x3xf32>, tensor<4xi64>) -> tensor<1x4x3x3xf32>
|
||||
// CHECK: %[[T_22:.*]] = stablehlo.convolution(%[[T_0]], %[[T_21]])
|
||||
// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 2 : i64} : (tensor<1x2x7x7xf32>, tensor<1x4x3x3xf32>) -> tensor<1x4x15x15xf32>
|
||||
// CHECK: %[[T_23:.*]] = torch_c.from_builtin_tensor %[[T_22]] : tensor<1x4x15x15xf32> -> !torch.vtensor<[1,4,15,15],f32>
|
||||
// CHECK: return %[[T_23]] : !torch.vtensor<[1,4,15,15],f32>
|
||||
|
|
|
@ -1,2 +1,2 @@
|
|||
if not config.enable_mhlo:
|
||||
if not config.enable_stablehlo:
|
||||
config.unsupported = True
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s
|
||||
// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s
|
||||
|
||||
// -----
|
||||
|
||||
|
@ -13,11 +13,11 @@
|
|||
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VAL_6:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
|
||||
// CHECK: %[[VAL_7:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_6]]) ({
|
||||
// CHECK: %[[VAL_6:.*]] = stablehlo.constant dense<-3.40282347E+38> : tensor<f32>
|
||||
// CHECK: %[[VAL_7:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_6]]) ({
|
||||
// CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
|
||||
// CHECK: %[[VAL_10:.*]] = mhlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
|
||||
// CHECK: mhlo.return %[[VAL_10]] : tensor<f32>
|
||||
// CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
|
||||
// CHECK: stablehlo.return %[[VAL_10]] : tensor<f32>
|
||||
// CHECK: }) {padding = dense<0> : tensor<4x2xi64>, window_dilations = dense<[1, 1, 2, 1]> : tensor<4xi64>, window_dimensions = dense<[1, 1, 2, 2]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
|
||||
// CHECK: return %[[VAL_11]] : !torch.vtensor<[?,?,?,?],f32>
|
||||
|
@ -45,11 +45,11 @@ func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
|
|||
// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
|
||||
// CHECK: %[[VAL_6:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
|
||||
// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<-3.40282347E+38> : tensor<f32>
|
||||
// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
|
||||
// CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
|
||||
// CHECK: %[[VAL_10:.*]] = mhlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
|
||||
// CHECK: mhlo.return %[[VAL_10]] : tensor<f32>
|
||||
// CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
|
||||
// CHECK: stablehlo.return %[[VAL_10]] : tensor<f32>
|
||||
// CHECK: })
|
||||
// CHECK-SAME{LITERAL}: {padding = dense<[[0, 0], [0, 0], [2, 2], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<[1, 1, 2, 1]> : tensor<4xi64>, window_dimensions = dense<[1, 1, 2, 2]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
|
||||
|
@ -80,7 +80,7 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -
|
|||
// CHECK: %[[T2:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT0]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T4:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T5:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
|
||||
// CHECK: %[[T5:.*]] = stablehlo.constant dense<-3.40282347E+38> : tensor<f32>
|
||||
// CHECK: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x?xf32>
|
||||
// CHECK: %[[T6:.*]] = arith.index_cast %[[DIM]] : index to i64
|
||||
|
@ -93,18 +93,18 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -
|
|||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T6]], %[[T7]], %[[T8]] : tensor<3xi64>
|
||||
// CHECK: %[[T9:.*]] = arith.muli %[[T8]], %[[T7]] : i64
|
||||
// CHECK: %[[FROM_ELEMENTS_2:.*]] = tensor.from_elements %[[T6]], %[[T9]] : tensor<2xi64>
|
||||
// CHECK: %[[T10:.*]] = "mhlo.dynamic_iota"(%[[FROM_ELEMENTS_2]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64>
|
||||
// CHECK: %[[T11:.*]] = mhlo.dynamic_reshape %[[T10]], %[[FROM_ELEMENTS]] : (tensor<?x?xi64>, tensor<3xi64>) -> tensor<?x?x?xi64>
|
||||
// CHECK: %[[T12:.*]] = mhlo.constant dense<0> : tensor<i64>
|
||||
// CHECK: %[[T13:.*]]:2 = "mhlo.reduce_window"(%[[T0]], %[[T11]], %[[T5]], %[[T12]]) ({
|
||||
// CHECK: %[[T10:.*]] = stablehlo.dynamic_iota %[[FROM_ELEMENTS_2]], dim = 1 : (tensor<2xi64>) -> tensor<?x?xi64>
|
||||
// CHECK: %[[T11:.*]] = stablehlo.dynamic_reshape %[[T10]], %[[FROM_ELEMENTS]] : (tensor<?x?xi64>, tensor<3xi64>) -> tensor<?x?x?xi64>
|
||||
// CHECK: %[[T12:.*]] = stablehlo.constant dense<0> : tensor<i64>
|
||||
// CHECK: %[[T13:.*]]:2 = "stablehlo.reduce_window"(%[[T0]], %[[T11]], %[[T5]], %[[T12]]) ({
|
||||
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<i64>, %[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<i64>):
|
||||
// CHECK: %[[T16:.*]] = mhlo.compare GE, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: %[[T17:.*]] = mhlo.select %[[T16]], %[[ARG1]], %[[ARG3]] : tensor<i1>, tensor<f32>
|
||||
// CHECK: %[[T18:.*]] = mhlo.compare EQ, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: %[[T19:.*]] = mhlo.minimum %[[ARG2]], %[[ARG4]] : tensor<i64>
|
||||
// CHECK: %[[T20:.*]] = mhlo.select %[[T16]], %[[ARG2]], %[[ARG4]] : tensor<i1>, tensor<i64>
|
||||
// CHECK: %[[T21:.*]] = mhlo.select %[[T18]], %[[T19]], %[[T20]] : tensor<i1>, tensor<i64>
|
||||
// CHECK: mhlo.return %[[T17]], %[[T21]] : tensor<f32>, tensor<i64>
|
||||
// CHECK: %[[T16:.*]] = stablehlo.compare GE, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: %[[T17:.*]] = stablehlo.select %[[T16]], %[[ARG1]], %[[ARG3]] : tensor<i1>, tensor<f32>
|
||||
// CHECK: %[[T18:.*]] = stablehlo.compare EQ, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: %[[T19:.*]] = stablehlo.minimum %[[ARG2]], %[[ARG4]] : tensor<i64>
|
||||
// CHECK: %[[T20:.*]] = stablehlo.select %[[T16]], %[[ARG2]], %[[ARG4]] : tensor<i1>, tensor<i64>
|
||||
// CHECK: %[[T21:.*]] = stablehlo.select %[[T18]], %[[T19]], %[[T20]] : tensor<i1>, tensor<i64>
|
||||
// CHECK: stablehlo.return %[[T17]], %[[T21]] : tensor<f32>, tensor<i64>
|
||||
// CHECK: }) {padding = dense<0> : tensor<3x2xi64>, window_dilations = dense<1> : tensor<3xi64>, window_dimensions = dense<[1, 3, 3]> : tensor<3xi64>, window_strides = dense<[1, 2, 2]> : tensor<3xi64>} : (tensor<?x?x?xf32>, tensor<?x?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?x?x?xf32>, tensor<?x?x?xi64>)
|
||||
// CHECK: %[[T14:.*]] = torch_c.from_builtin_tensor %[[T13]]#0 : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
|
||||
// CHECK: %[[T15:.*]] = torch_c.from_builtin_tensor %[[T13]]#1 : tensor<?x?x?xi64> -> !torch.vtensor<[?,?,?],si64>
|
||||
|
@ -136,13 +136,13 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32>
|
|||
// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[VAL_6:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
|
||||
// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
|
||||
// CHECK: ^bb0(%[[IVAL_0:.*]]: tensor<f32>, %[[IVAL_1:.*]]: tensor<f32>):
|
||||
// CHECK: %[[IVAL_2:.*]] = mhlo.add %[[IVAL_0]], %[[IVAL_1]] : tensor<f32>
|
||||
// CHECK: mhlo.return %[[IVAL_2]] : tensor<f32>
|
||||
// CHECK: %[[IVAL_2:.*]] = stablehlo.add %[[IVAL_0]], %[[IVAL_1]] : tensor<f32>
|
||||
// CHECK: stablehlo.return %[[IVAL_2]] : tensor<f32>
|
||||
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
|
||||
// CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : index to i64
|
||||
|
@ -156,14 +156,14 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32>
|
|||
// CHECK: %[[VAL_14:.*]] = tensor.dim %[[VAL_1]], %[[IDX_3]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[VAL_15:.*]] = arith.index_cast %[[VAL_14]] : index to i64
|
||||
// CHECK: %[[VAL_16:.*]] = tensor.from_elements %[[VAL_9]], %[[VAL_11]], %[[VAL_13]], %[[VAL_15]] : tensor<4xi64>
|
||||
// CHECK: %[[VAL_17:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_7]], %[[VAL_16]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[VAL_18:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[VAL_19:.*]] = "mhlo.reduce_window"(%[[VAL_17]], %[[VAL_18]]) ({
|
||||
// CHECK: %[[VAL_17:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_16]], dims = [] : (tensor<f32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[VAL_18:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[VAL_19:.*]] = "stablehlo.reduce_window"(%[[VAL_17]], %[[VAL_18]]) ({
|
||||
// CHECK: ^bb0(%[[IVAL_3:.*]]: tensor<f32>, %[[IVAL_4:.*]]: tensor<f32>):
|
||||
// CHECK: %[[IVAL_5:.*]] = mhlo.add %[[IVAL_3]], %[[IVAL_4]] : tensor<f32>
|
||||
// CHECK: mhlo.return %[[IVAL_5]] : tensor<f32>
|
||||
// CHECK: %[[IVAL_5:.*]] = stablehlo.add %[[IVAL_3]], %[[IVAL_4]] : tensor<f32>
|
||||
// CHECK: stablehlo.return %[[IVAL_5]] : tensor<f32>
|
||||
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[VAL_20:.*]] = mhlo.divide %[[VAL_6]], %[[VAL_19]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[VAL_20:.*]] = stablehlo.divide %[[VAL_6]], %[[VAL_19]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
|
||||
// CHECK: return %[[VAL_21]] : !torch.vtensor<[?,?,?,?],f32>
|
||||
func.func @torch.aten.avg_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
|
||||
|
@ -193,14 +193,14 @@ func.func @torch.aten.avg_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
|
|||
// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T2:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T4:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[T5:.*]] = "mhlo.reduce_window"(%[[T0]], %[[T4]]) ({
|
||||
// CHECK: %[[T4:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[T5:.*]] = "stablehlo.reduce_window"(%[[T0]], %[[T4]]) ({
|
||||
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
|
||||
// CHECK: %[[T10:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
|
||||
// CHECK: mhlo.return %[[T10]] : tensor<f32>
|
||||
// CHECK: %[[T10:.*]] = stablehlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
|
||||
// CHECK: stablehlo.return %[[T10]] : tensor<f32>
|
||||
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[T6:.*]] = mhlo.constant dense<9> : tensor<i64>
|
||||
// CHECK: %[[T7:.*]] = mhlo.convert %[[T6]] : (tensor<i64>) -> tensor<f32>
|
||||
// CHECK: %[[T6:.*]] = stablehlo.constant dense<9> : tensor<i64>
|
||||
// CHECK: %[[T7:.*]] = stablehlo.convert %[[T6]] : (tensor<i64>) -> tensor<f32>
|
||||
// CHECK: %[[T8:.*]] = chlo.broadcast_divide %[[T5]], %[[T7]] : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
|
||||
// CHECK: return %[[T9]] : !torch.vtensor<[?,?,?,?],f32>
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s
|
||||
// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.slice.strided$slice_like(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
|
||||
|
@ -42,7 +42,7 @@
|
|||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T10]], %[[C0_I64_5]], %[[C0_I64_5]] : tensor<3xi64>
|
||||
// CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T21]], %[[T18]], %[[T19]] : tensor<3xi64>
|
||||
// CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64>
|
||||
// CHECK: %[[T22:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32>
|
||||
// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32>
|
||||
// CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
|
||||
// CHECK: return %[[T23]] : !torch.vtensor<[?,?,?],f32>
|
||||
func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
|
||||
|
@ -97,7 +97,7 @@ func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32
|
|||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T10]], %[[C0_I64_5]], %[[C0_I64_5]] : tensor<3xi64>
|
||||
// CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T21]], %[[T18]], %[[T19]] : tensor<3xi64>
|
||||
// CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64>
|
||||
// CHECK: %[[T22:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<2x65x256xf32>
|
||||
// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<2x65x256xf32>
|
||||
// CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<2x65x256xf32> -> !torch.vtensor<[2,65,256],f32>
|
||||
// CHECK: return %[[T23]] : !torch.vtensor<[2,65,256],f32>
|
||||
func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[2,65,256],f32> {
|
||||
|
@ -152,7 +152,7 @@ func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,6
|
|||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_5]], %[[T10]], %[[C0_I64_5]] : tensor<3xi64>
|
||||
// CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T17]], %[[T21]], %[[T19]] : tensor<3xi64>
|
||||
// CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64>
|
||||
// CHECK: %[[T22:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x1x?xf32>
|
||||
// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x1x?xf32>
|
||||
// CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<?x1x?xf32> -> !torch.vtensor<[?,1,?],f32>
|
||||
// CHECK: return %[[T23]] : !torch.vtensor<[?,1,?],f32>
|
||||
func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,1,?],f32> {
|
||||
|
@ -207,7 +207,7 @@ func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>)
|
|||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_5]], %[[T10]], %[[C0_I64_5]] : tensor<3xi64>
|
||||
// CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T17]], %[[T21]], %[[T19]] : tensor<3xi64>
|
||||
// CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64>
|
||||
// CHECK: %[[T22:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x1x256xf32>
|
||||
// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x1x256xf32>
|
||||
// CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<4x1x256xf32> -> !torch.vtensor<[4,1,256],f32>
|
||||
// CHECK: return %[[T23]] : !torch.vtensor<[4,1,256],f32>
|
||||
func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,1,256],f32> {
|
||||
|
@ -247,7 +247,7 @@ func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,2
|
|||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_4]], %[[C0_I64]], %[[C0_I64_4]] : tensor<3xi64>
|
||||
// CHECK: %[[FROM_ELEMENTS_5:.*]] = tensor.from_elements %[[T3]], %[[T7]], %[[T5]] : tensor<3xi64>
|
||||
// CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64>
|
||||
// CHECK: %[[T8:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS]]_6 : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32>
|
||||
// CHECK: %[[T8:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS]]_6 : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32>
|
||||
// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
|
||||
// CHECK: return %[[T9]] : !torch.vtensor<[?,?,?],f32>
|
||||
func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
|
||||
|
@ -287,7 +287,7 @@ func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>)
|
|||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_4]], %[[C0_I64]], %[[C0_I64_4]] : tensor<3xi64>
|
||||
// CHECK: %[[FROM_ELEMENTS_5:.*]] = tensor.from_elements %[[T3]], %[[T7]], %[[T5]] : tensor<3xi64>
|
||||
// CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64>
|
||||
// CHECK: %[[T8:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS]]_6 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x33x256xf32>
|
||||
// CHECK: %[[T8:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS]]_6 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x33x256xf32>
|
||||
// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<4x33x256xf32> -> !torch.vtensor<[4,33,256],f32>
|
||||
// CHECK: return %[[T9]] : !torch.vtensor<[4,33,256],f32>
|
||||
func.func @torch.aten.slice.none.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,33,256],f32> {
|
||||
|
@ -313,8 +313,8 @@ func.func @torch.aten.slice.none.static$slice_like(%arg0: !torch.vtensor<[4,65,2
|
|||
// CHECK: %[[T5:.*]] = arith.muli %[[T4]], %[[T3]] : i64
|
||||
// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : i64 to index
|
||||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]], %[[T3]] : tensor<2xi64>
|
||||
// CHECK: %[[T7:.*]] = mhlo.compute_reshape_shape %[[T6]], %[[FROM_ELEMENTS]] : (index, tensor<2xi64>) -> tensor<2xi64>
|
||||
// CHECK: %[[T8:.*]] = mhlo.dynamic_reshape %[[T0]], %[[T7]] : (tensor<?x?x?x?xf32>, tensor<2xi64>) -> tensor<?x224xf32>
|
||||
// CHECK: %[[T7:.*]] = stablehlo.compute_reshape_shape %[[T6]], %[[FROM_ELEMENTS]] : (index, tensor<2xi64>) -> tensor<2xi64>
|
||||
// CHECK: %[[T8:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[T7]] : (tensor<?x?x?x?xf32>, tensor<2xi64>) -> tensor<?x224xf32>
|
||||
// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x224xf32> -> !torch.vtensor<[?,224],f32>
|
||||
// CHECK: return %[[T9]] : !torch.vtensor<[?,224],f32>
|
||||
func.func @torch.aten.view$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,224],f32> {
|
||||
|
@ -346,8 +346,8 @@ func.func @torch.aten.view$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
|
|||
// CHECK: %[[T9:.*]] = arith.muli %[[T8]], %[[T5]] : i64
|
||||
// CHECK: %[[T10:.*]] = arith.index_cast %[[T9]] : i64 to index
|
||||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]], %[[T3]], %[[T4]], %[[T5]] : tensor<4xi64>
|
||||
// CHECK: %[[T11:.*]] = mhlo.compute_reshape_shape %[[T10]], %[[FROM_ELEMENTS]] : (index, tensor<4xi64>) -> tensor<4xi64>
|
||||
// CHECK: %[[T12:.*]] = mhlo.dynamic_reshape %[[T0]], %[[T11]] : (tensor<?x?x?x?x?xf32>, tensor<4xi64>) -> tensor<?x120x4x64xf32>
|
||||
// CHECK: %[[T11:.*]] = stablehlo.compute_reshape_shape %[[T10]], %[[FROM_ELEMENTS]] : (index, tensor<4xi64>) -> tensor<4xi64>
|
||||
// CHECK: %[[T12:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[T11]] : (tensor<?x?x?x?x?xf32>, tensor<4xi64>) -> tensor<?x120x4x64xf32>
|
||||
// CHECK: %[[T13:.*]] = torch_c.from_builtin_tensor %[[T12]] : tensor<?x120x4x64xf32> -> !torch.vtensor<[?,120,4,64],f32>
|
||||
// CHECK: return %[[T13]] : !torch.vtensor<[?,120,4,64],f32>
|
||||
func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.vtensor<[?,120,4,64],f32> {
|
||||
|
@ -367,7 +367,7 @@ func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !
|
|||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[],f32> -> tensor<f32>
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T2:.*]] = mhlo.reshape %[[T0]] : (tensor<f32>) -> tensor<1xf32>
|
||||
// CHECK: %[[T2:.*]] = stablehlo.reshape %[[T0]] : (tensor<f32>) -> tensor<1xf32>
|
||||
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
|
||||
// CHECK: return %[[T3]] : !torch.vtensor<[1],f32>
|
||||
func.func @torch.aten.view$to_rank1(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
|
||||
|
@ -383,7 +383,7 @@ func.func @torch.aten.view$to_rank1(%arg0: !torch.vtensor<[],f32>) -> !torch.vte
|
|||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1],f32>) -> !torch.vtensor<[],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1],f32> -> tensor<1xf32>
|
||||
// CHECK: %[[T1:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
// CHECK: %[[T2:.*]] = mhlo.reshape %[[T0]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T2:.*]] = stablehlo.reshape %[[T0]] : (tensor<1xf32>) -> tensor<f32>
|
||||
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<f32> -> !torch.vtensor<[],f32>
|
||||
// CHECK: return %[[T3]] : !torch.vtensor<[],f32>
|
||||
func.func @torch.aten.view$to_rank0(%arg0: !torch.vtensor<[1],f32>) -> !torch.vtensor<[],f32> {
|
||||
|
@ -425,7 +425,7 @@ func.func @torch.aten.squeeze.dim$0$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32
|
|||
// CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<?x1x?x1x?xf32>
|
||||
// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64
|
||||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<4xi64>
|
||||
// CHECK: %[[T5:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x1x?x1x?xf32>, tensor<4xi64>) -> tensor<?x?x1x?xf32>
|
||||
// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x1x?x1x?xf32>, tensor<4xi64>) -> tensor<?x?x1x?xf32>
|
||||
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?x1x?xf32> -> !torch.vtensor<[?,?,1,?],f32>
|
||||
// CHECK: return %[[T6]] : !torch.vtensor<[?,?,1,?],f32>
|
||||
func.func @torch.aten.squeeze.dim$1(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,?,1,?],f32> {
|
||||
|
@ -453,7 +453,7 @@ func.func @torch.aten.squeeze.dim$1(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !
|
|||
// CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<?x1x?x1x?xf32>
|
||||
// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64
|
||||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<4xi64>
|
||||
// CHECK: %[[T5:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x1x?x1x?xf32>, tensor<4xi64>) -> tensor<?x1x?x?xf32>
|
||||
// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x1x?x1x?xf32>, tensor<4xi64>) -> tensor<?x1x?x?xf32>
|
||||
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x1x?x?xf32> -> !torch.vtensor<[?,1,?,?],f32>
|
||||
// CHECK: return %[[T6]] : !torch.vtensor<[?,1,?,?],f32>
|
||||
func.func @torch.aten.squeeze.dim$from_end(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,1,?,?],f32> {
|
||||
|
@ -477,7 +477,7 @@ func.func @torch.aten.squeeze.dim$from_end(%arg0: !torch.vtensor<[?,1,?,1,?],f32
|
|||
// CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<2x1x2x1x2xf32>
|
||||
// CHECK: %[[T3:.*]] = arith.index_cast %[[DIM_1]] : index to i64
|
||||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]] : tensor<3xi64>
|
||||
// CHECK: %[[T4:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<2x1x2x1x2xf32>, tensor<3xi64>) -> tensor<2x2x2xf32>
|
||||
// CHECK: %[[T4:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<2x1x2x1x2xf32>, tensor<3xi64>) -> tensor<2x2x2xf32>
|
||||
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<2x2x2xf32> -> !torch.vtensor<[2,2,2],f32>
|
||||
// CHECK: return %[[T5]] : !torch.vtensor<[2,2,2],f32>
|
||||
func.func @torch.aten.squeeze$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,2,2],f32> {
|
||||
|
@ -505,7 +505,7 @@ func.func @torch.aten.squeeze$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) ->
|
|||
// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64
|
||||
// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
|
||||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<5xi64>
|
||||
// CHECK: %[[T5:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<1x?x?x?x?xf32>
|
||||
// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<1x?x?x?x?xf32>
|
||||
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<1x?x?x?x?xf32> -> !torch.vtensor<[1,?,?,?,?],f32>
|
||||
// CHECK: return %[[T6]] : !torch.vtensor<[1,?,?,?,?],f32>
|
||||
func.func @torch.aten.unsqueeze$dim$0(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1,?,?,?,?],f32> {
|
||||
|
@ -534,7 +534,7 @@ func.func @torch.aten.unsqueeze$dim$0(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !
|
|||
// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64
|
||||
// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
|
||||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[C1_I64]], %[[T2]], %[[T3]], %[[T4]] : tensor<5xi64>
|
||||
// CHECK: %[[T5:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<?x1x?x?x?xf32>
|
||||
// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<?x1x?x?x?xf32>
|
||||
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x1x?x?x?xf32> -> !torch.vtensor<[?,1,?,?,?],f32>
|
||||
// CHECK: return %[[T6]] : !torch.vtensor<[?,1,?,?,?],f32>
|
||||
func.func @torch.aten.unsqueeze$dim$1(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,1,?,?,?],f32> {
|
||||
|
@ -563,7 +563,7 @@ func.func @torch.aten.unsqueeze$dim$1(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !
|
|||
// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64
|
||||
// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
|
||||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]], %[[C1_I64]], %[[T4]] : tensor<5xi64>
|
||||
// CHECK: %[[T5:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<?x?x?x1x?xf32>
|
||||
// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<?x?x?x1x?xf32>
|
||||
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?x?x1x?xf32> -> !torch.vtensor<[?,?,?,1,?],f32>
|
||||
// CHECK: return %[[T6]] : !torch.vtensor<[?,?,?,1,?],f32>
|
||||
func.func @torch.aten.unsqueeze$from_end(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,1,?],f32> {
|
||||
|
|
|
@ -17,7 +17,7 @@ config.llvm_exe_ext = "@EXEEXT@"
|
|||
config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@"
|
||||
config.python_executable = "@Python3_EXECUTABLE@"
|
||||
config.enable_jit_ir_importer = @TORCH_MLIR_ENABLE_JIT_IR_IMPORTER@
|
||||
config.enable_mhlo = @TORCH_MLIR_ENABLE_MHLO@
|
||||
config.enable_stablehlo = @TORCH_MLIR_ENABLE_STABLEHLO@
|
||||
|
||||
import lit.llvm
|
||||
lit.llvm.initialize(lit_config, config)
|
||||
|
|
|
@ -268,7 +268,7 @@ gentbl_cc_library(
|
|||
(
|
||||
[
|
||||
"-gen-pass-decls",
|
||||
"-DTORCH_MLIR_ENABLE_MHLO",
|
||||
"-DTORCH_MLIR_ENABLE_STABLEHLO",
|
||||
],
|
||||
"include/torch-mlir/Conversion/Passes.h.inc",
|
||||
),
|
||||
|
@ -434,13 +434,13 @@ cc_library(
|
|||
)
|
||||
|
||||
cc_library(
|
||||
name = "TorchMLIRTorchToMhlo",
|
||||
name = "TorchMLIRTorchToStablehlo",
|
||||
srcs = glob([
|
||||
"lib/Conversion/*.h",
|
||||
"lib/Conversion/TorchToMhlo/*.h",
|
||||
"lib/Conversion/TorchToMhlo/*.cpp",
|
||||
"lib/Conversion/TorchToStablehlo/*.h",
|
||||
"lib/Conversion/TorchToStablehlo/*.cpp",
|
||||
]),
|
||||
hdrs = glob(["include/torch-mlir/Conversion/TorchToMhlo/*.h"]),
|
||||
hdrs = glob(["include/torch-mlir/Conversion/TorchToStablehlo/*.h"]),
|
||||
strip_include_prefix = "include",
|
||||
deps = [
|
||||
":TorchMLIRConversionPassesIncGen",
|
||||
|
@ -465,8 +465,8 @@ cc_library(
|
|||
":TorchMLIRTorchConversionToMLProgram",
|
||||
":TorchMLIRTorchToArith",
|
||||
":TorchMLIRTorchToLinalg",
|
||||
":TorchMLIRTorchToMhlo",
|
||||
":TorchMLIRTorchToSCF",
|
||||
":TorchMLIRTorchToStablehlo",
|
||||
":TorchMLIRTorchToTMTensor",
|
||||
":TorchMLIRTorchToTosa",
|
||||
],
|
||||
|
@ -489,8 +489,8 @@ cc_library(
|
|||
":TorchMLIRTorchPasses",
|
||||
":TorchMLIRTorchToArith",
|
||||
":TorchMLIRTorchToLinalg",
|
||||
":TorchMLIRTorchToMhlo",
|
||||
":TorchMLIRTorchToSCF",
|
||||
":TorchMLIRTorchToStablehlo",
|
||||
":TorchMLIRTorchToTMTensor",
|
||||
":TorchMLIRTorchToTosa",
|
||||
"@llvm-project//mlir:ConversionPasses",
|
||||
|
|
|
@ -23,7 +23,7 @@ expand_template(
|
|||
# All disabled, but required to substituted because they are not in quotes.
|
||||
"@MLIR_ENABLE_BINDINGS_PYTHON@": "0",
|
||||
"@TORCH_MLIR_ENABLE_JIT_IR_IMPORTER@": "0",
|
||||
"@TORCH_MLIR_ENABLE_MHLO@": "0",
|
||||
"@TORCH_MLIR_ENABLE_STABLEHLO@": "0",
|
||||
},
|
||||
template = "lit.site.cfg.py.in",
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue