mirror of https://github.com/llvm/torch-mlir
mhlo: migrate conversion to stablehlo (#1840)
This patch replaces all MHLO operations with their StableHLO counterparts and adds a validation pass to ensure that no MHLO operations remain before all StableHLO operations are translated to the MHLO dialect for further lowering to the Linalg dialect. This patch also updates all lit tests so that they refer to the `convert-torch-to-stablehlo` pass and check for StableHLO operations.
parent ed9d8d1fb7
commit 711646d095
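The commit touches build options, documentation, the Python API, the e2e test suite, and the conversion passes. A quick informal summary of the user-visible renames, collected from the diff below (an illustrative sketch only, not an API of the project):

```python
# Old name -> new name, as introduced by this patch (gathered from the hunks below).
RENAMES = {
    "CMake option":    ("TORCH_MLIR_ENABLE_MHLO", "TORCH_MLIR_ENABLE_STABLEHLO"),
    "output type":     ("torch_mlir.OutputType.MHLO", "torch_mlir.OutputType.STABLEHLO"),
    "e2e test config": ("--config=mhlo", "--config=stablehlo"),
    "conversion pass": ("convert-torch-to-mhlo", "convert-torch-to-stablehlo"),
    "verifier pass":   ("torch-verify-mhlo-backend-contract",
                        "torch-verify-stablehlo-backend-contract"),
}
for what, (old, new) in RENAMES.items():
    print(f"{what}: {old} -> {new}")
```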
@@ -113,7 +113,7 @@ jobs:
 -DLLVM_USE_HOST_TOOLS=ON \
 -DLLVM_ENABLE_ZSTD=OFF \
 -DMLIR_ENABLE_BINDINGS_PYTHON=ON \
--DTORCH_MLIR_ENABLE_MHLO=OFF \
+-DTORCH_MLIR_ENABLE_STABLEHLO=OFF \
 -DTORCH_MLIR_ENABLE_LTC=OFF \
 -DTORCH_MLIR_USE_INSTALLED_PYTORCH="${{ matrix.torch-binary }}" \
 -DMACOSX_DEPLOYMENT_TARGET=12.0 \

@@ -36,9 +36,9 @@ macro(torch_mlir_add_llvm_external_project name identifier location)
 set(LLVM_EXTERNAL_PROJECTS ${LLVM_EXTERNAL_PROJECTS} CACHE STRING "" FORCE)
 endmacro()

-option(TORCH_MLIR_ENABLE_MHLO "Add mhlo dialect" ON)
-if(TORCH_MLIR_ENABLE_MHLO)
-add_definitions(-DTORCH_MLIR_ENABLE_MHLO)
+option(TORCH_MLIR_ENABLE_STABLEHLO "Add stablehlo dialect" ON)
+if(TORCH_MLIR_ENABLE_STABLEHLO)
+add_definitions(-DTORCH_MLIR_ENABLE_STABLEHLO)
 endif()

 option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF)

@@ -128,8 +128,8 @@ else()
 set(MLIR_INCLUDE_DIRS "${MLIR_INCLUDE_DIR};${MLIR_GENERATED_INCLUDE_DIR}")
 endif()

-if (TORCH_MLIR_ENABLE_MHLO)
-set(MHLO_BUILD_EMBEDDED ON)
+if (TORCH_MLIR_ENABLE_STABLEHLO)
+set(STABLEHLO_BUILD_EMBEDDED ON)
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo
 ${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo
 EXCLUDE_FROM_ALL)

@@ -267,8 +267,8 @@ function test_in_tree() {
 echo ":::: Run Linalg e2e integration tests"
 python -m e2e_testing.main --config=linalg -v

-echo ":::: Run MHLO e2e integration tests"
-python -m e2e_testing.main --config=mhlo -v
+echo ":::: Run StableHLO e2e integration tests"
+python -m e2e_testing.main --config=stablehlo -v

 echo ":::: Run TOSA e2e integration tests"
 python -m e2e_testing.main --config=tosa -v
@@ -30,14 +30,14 @@ it to various target dialects of interest to the MLIR ecosystem (various

 - Linalg-on-Tensors (+ `arith`, `tensor`, etc.)
 - [TOSA](https://mlir.llvm.org/docs/Dialects/TOSA/)
-- [MHLO](https://github.com/tensorflow/mlir-hlo)
+- [StableHLO](https://github.com/openxla/stablehlo)

 The terms "frontend" and "backend" are highly overloaded in any compiler
 project, but frequently in Torch-MLIR this is the meaning that they have.
 Sometimes "frontend" can mean something even further up the stack, such as
 something in PyTorch itself. When there is ambiguity we will refer to this as
 "at the PyTorch level". Similarly, "backend" can sometimes refer to something
-sitting below Linalg-on-Tensors, TOSA, or MHLO.
+sitting below Linalg-on-Tensors, TOSA, or StableHLO.

 ## The `torch` dialect

@@ -118,8 +118,8 @@ See [satisfiesBackendContract](https://github.com/llvm/torch-mlir/blob/114f48e96

 The backend contract is a normalized form of the `torch` dialect with a set of
 properties that make it easy to lower into various forms such as
-Linalg-on-Tensors, TOSA, MHLO, or other forms that we don't provide out of the
-box. The primary guarantees that we provide Torch-MLIR's backends are:
+Linalg-on-Tensors, TOSA, StableHLO, or other forms that we don't provide out of
+the box. The primary guarantees that we provide Torch-MLIR's backends are:

 - All tensors have been converted to value semantics.
 - All tensors have at least a known number of dimensions (i.e. rank), and

@@ -270,7 +270,7 @@ lower it to the requirements of each backend. The 3 backends are:
 - [`linalg`](https://mlir.llvm.org/docs/Dialects/Linalg/) on tensors (+ `arith`,
 `tensor`, etc.)
 - [TOSA](https://mlir.llvm.org/docs/Dialects/TOSA/)
-- [MHLO](https://github.com/tensorflow/mlir-hlo)
+- [StableHLO](https://github.com/openxla/stablehlo)

 ### The Linalg Backend (Linalg-on-Tensors)

@@ -297,15 +297,15 @@ many users (especially "hardware" or "hardware-adjacent" folks). Some of its cha
 - It is extremely solid with static shapes (and many of its users only care
 about static shapes, so that's fine).

-### The MHLO Backend
+### The StableHLO Backend

-Code: https://github.com/llvm/torch-mlir/tree/main/lib/Conversion/TorchToMhlo
+Code: https://github.com/llvm/torch-mlir/tree/main/lib/Conversion/TorchToStablehlo

-The MHLO backend was the third backend that we added, and it offers a reasonable
-blend of the benefits of the other two.
+The StableHLO backend was the third backend that we added, and it offers a
+reasonable blend of the benefits of the other two.
 - It is a coarse-grained named-op approach.
 - It has a pretty clear spec for most of the ops (with a bit of mental
-translation and hoping that MHLO is the same as HLO):
+translation and hoping that StableHLO is the same as HLO):
 https://www.tensorflow.org/xla/operation_semantics
 - It functionally supports dynamic shapes (though not as coherent and consistent
 as Linalg-on-Tensors, and the dynamic shape support falls outside the

@@ -317,7 +317,7 @@ blend of the benefits of the other two.
 example, TOSA limits (for highly considered reasons) the number of dimensions
 that certain operators can handle to 1D-4D, when from a purely algebraic
 perspective there isn't a good reason to not be more general. Similarly, more
-general forms of reduction and scatter also fall into MHLO nicely while
+general forms of reduction and scatter also fall into StableHLO nicely while
 TOSA's principles tend to bias it away from that.

 ### Backend Implementation

@@ -433,8 +433,9 @@ filling in some corners missing upstream and
 to pull together upstream functionality into a working system.

 The RefBackend accepts Linalg-on-Tensors as input. It mainly just bufferizes the
-ops and lowers them to loops. Note that TOSA and MHLO support lowering to
-Linalg-on-Tensors, so all our end-to-end testing bottoms out on RefBackend.
+ops and lowers them to loops. Note that TOSA and StableHLO (via MHLO) support
+lowering to Linalg-on-Tensors, so all our end-to-end testing bottoms out on
+RefBackend.

 The RefBackend is absolutely not suitable for any production use case. It leaks
 memory, doesn't support any error handling, performs no optimizations, and

@@ -34,7 +34,7 @@ and Clang's
 - Eric Kunze (@eric-k256)
 - Suraj Sudhir (@sjarus)

-### TorchToMHLO
+### TorchToStablehlo

 - Tianyo Kwok (@tanyokwok)
 - Ziheng Jiang (@ZihengJiang)

@@ -139,7 +139,7 @@ Ex:
 module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
 ```

-Currently, `output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `MHLO`.
+Currently, `output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `STABLEHLO`.

 ## Jupyter

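The `output_type` list above is the knob this commit renames at the Python API level. A minimal sketch of exercising the new value through the same `torch_mlir.compile` call shown in that documentation; it mirrors the ResNet example added elsewhere in this patch, and the enum spelling `OutputType.STABLEHLO` is taken from that example (downloading pretrained torchvision weights is assumed to be acceptable here):

```python
import torch
import torchvision.models as models
import torch_mlir

# Same resnet18 module as in the documentation example above.
resnet18 = models.resnet18(pretrained=True).eval()

# After this patch, STABLEHLO replaces MHLO among the accepted output types.
module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224),
                            output_type=torch_mlir.OutputType.STABLEHLO)
print(str(module)[:500])  # the emitted module now contains stablehlo.* ops
```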
@@ -46,7 +46,7 @@ the ecosystem are:

 - The frontend work required to lower TorchScript to the backend contract.
 - The irregular support surface area of the large number of PyTorch ops across
-the Linalg, TOSA, and MHLO backends.
+the Linalg, TOSA, and StableHLO backends.

 Most of this document describes long-term ecosystem changes that will address
 these, drastically improving Torch-MLIR's ability to meet its goals.

@@ -108,7 +108,7 @@ more advanced).
 ### Refactoring the backend

 Today in Torch-MLIR, we support 3 backends out of the box: Linalg-on-Tensors,
-TOSA, and MHLO. These backends take IR in the backend contract form (see
+TOSA, and StableHLO. These backends take IR in the backend contract form (see
 [architecture.md](architecture.md)) and lowers them to the respective dialects.
 Today, each backend is implemented completely independently. This leads to
 duplication and irregularity across the backends.

@@ -120,12 +120,10 @@ lowering of so many ops across backends. Additionally, there are 3
 forward-looking efforts that intersect with this effort:

 - [StableHLO](https://github.com/openxla/stablehlo) - this is a dialect
-initially forked from MHLO which intends to create a stable support surface
-area for what today is our "at head" dependency on MHLO. MHLO is a fairly
-complete op set, so it is very attractive to have "almost all" models
-bottleneck through a stable interface like StableHLO. StableHLO is currently
-under relatively early development, but already delivers on many of the goals
-of stability.
+initially forked from MHLO. MHLO is a fairly complete op set, so it is very
+attractive to have "almost all" models bottleneck through a stable interface
+like StableHLO. StableHLO is currently under relatively early development,
+but already delivers on many of the goals of stability.
 - [TCP](https://github.com/llvm/torch-mlir/issues/1366) - this is a dialect
 which could serve a role very similar to MHLO, while providing community
 ownership. TCP is still in early planning phases, but there is strong

@@ -16,7 +16,7 @@ from torch_mlir_e2e_test.registry import GLOBAL_TEST_REGISTRY
 from torch_mlir_e2e_test.configs import (
 LazyTensorCoreTestConfig,
 LinalgOnTensorsBackendTestConfig,
-MhloBackendTestConfig,
+StablehloBackendTestConfig,
 NativeTorchTestConfig,
 TorchScriptTestConfig,
 TosaBackendTestConfig,

@@ -24,17 +24,17 @@ from torch_mlir_e2e_test.configs import (
 )

 from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
-from torch_mlir_e2e_test.mhlo_backends.linalg_on_tensors import LinalgOnTensorsMhloBackend
+from torch_mlir_e2e_test.stablehlo_backends.linalg_on_tensors import LinalgOnTensorsStablehloBackend
 from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend

-from .xfail_sets import LINALG_XFAIL_SET, MHLO_PASS_SET, TOSA_PASS_SET, LTC_XFAIL_SET, TORCHDYNAMO_XFAIL_SET
+from .xfail_sets import LINALG_XFAIL_SET, STABLEHLO_PASS_SET, TOSA_PASS_SET, LTC_XFAIL_SET, TORCHDYNAMO_XFAIL_SET

 # Import tests to register them in the global registry.
 from torch_mlir_e2e_test.test_suite import register_all_tests
 register_all_tests()

 def _get_argparse():
-config_choices = ["native_torch", "torchscript", "linalg", "mhlo", "tosa", "lazy_tensor_core", "torchdynamo"]
+config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "tosa", "lazy_tensor_core", "torchdynamo"]
 parser = argparse.ArgumentParser(description="Run torchscript e2e tests.")
 parser.add_argument("-c", "--config",
 choices=config_choices,

@@ -42,7 +42,7 @@ def _get_argparse():
 help=f"""
 Meaning of options:
 "linalg": run through torch-mlir"s default Linalg-on-Tensors backend.
-"mhlo": run through torch-mlir"s default MHLO backend.
+"stablehlo": run through torch-mlir"s default StableHLO backend.
 "tosa": run through torch-mlir"s default TOSA backend.
 "native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
 "torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).

@@ -80,9 +80,9 @@ def main():
 if args.config == "tosa":
 config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend())
 xfail_set = all_test_unique_names - TOSA_PASS_SET
-if args.config == "mhlo":
-config = MhloBackendTestConfig(LinalgOnTensorsMhloBackend())
-xfail_set = all_test_unique_names - MHLO_PASS_SET
+if args.config == "stablehlo":
+config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend())
+xfail_set = all_test_unique_names - STABLEHLO_PASS_SET
 elif args.config == "native_torch":
 config = NativeTorchTestConfig()
 xfail_set = {}

@@ -87,8 +87,10 @@ TORCHDYNAMO_XFAIL_SET = {
 "StdCorrectionKeepDimModule_basic",
 }

-MHLO_PASS_SET = {
+STABLEHLO_PASS_SET = {
 "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
+"AddSizeIntModule_basic",
+"AddSizeIntNegDimModule_basic",
 "ArangeDtypeFloatModule_basic",
 "ArangeDtypeIntModule_basic",
 "ArangeFalsePinMemoryModule_basic",

@@ -103,6 +105,7 @@ MHLO_PASS_SET = {
 "ArangeStartStepFloatModule_basic",
 "ArangeStartStepIntModule_basic",
 "ArangeZeroElementOutputModule_basic",
+"BatchMlpLayerModule_basic",
 "BmmModule_basic",
 "BroadcastToModule_basic",
 "BroadcastToSameRankStaticModule_basic",

@@ -124,12 +127,15 @@ MHLO_PASS_SET = {
 "ElementwiseClampMinModule_basic",
 "ElementwiseClampMaxModule_basic",
 "ElementwiseExpModule_basic",
+"ElementwiseFlattenBroadcastModule_basic",
+"ElementwiseLeakyReluModule_basic",
 "ElementwiseLogModule_basic",
 "ElementwiseNegModule_basic",
 "ElementwiseRsqrtModule_basic",
 "ElementwiseSigmoidModule_basic",
 "ElementwiseSqrtModule_basic",
 "ElementwiseUnaryModule_basic",
+"ElementwiseUnsqueezeBroadcastModule_basic",
 "ElementwiseUnsqueezeNegDimsModule_basic",
 "ElementwiseToDtypeF32ToI64Module_basic",
 "ElementwiseAddModule_basic",

@@ -198,6 +204,8 @@ MHLO_PASS_SET = {
 "Gather2DInputModdule_basic",
 "GatherRandomIndexModule_basic",
 "GeluBackwardModule_basic",
+"HardswishModule_basic",
+"HardswishRandomModule_basic",
 "HardTanhIntModule_basic",
 "HardTanhModule_basic",
 "HardsigmoidModule_basic",

@@ -220,6 +228,8 @@ MHLO_PASS_SET = {
 "MeanDynamicSizesModule_basic",
 "MeanLargeInputModule_basic",
 "MeanModule_basic",
+"Mlp1LayerModule_basic",
+"Mlp2LayerModule_basic",
 "MmTanhModule_basic",
 "Mv_basic",
 "NativeLayerNormModule4D_basic",

@@ -251,6 +261,8 @@ MHLO_PASS_SET = {
 "LiftFreshCopyModule_basic",
 "Mlp2LayerModuleNoBias_basic",
 "NumelModule_basic",
+"SiluModule_basic",
+"SquareModule_basic",
 "SqueezeModule_allUnitDim",
 "SqueezeDimModule_unitDim",
 "ViewCollapseOnesMiddleModule_basic",

@@ -420,6 +432,7 @@ MHLO_PASS_SET = {
 "UnsafeViewDynamicExpandModule_basic",
 "AtenRoundIntModule_basic",
 "TestF16Return_basic",
+"_LogSoftmaxModuleStable_basic",
 }

 # Write the TOSA set as a "passing" set as it is very early in development

@@ -1,14 +0,0 @@
-import torch
-import torchvision.models as models
-import torch_mlir
-
-model = models.resnet18(pretrained=True)
-model.eval()
-data = torch.randn(2,3,200,200)
-out_mhlo_mlir_path = "./resnet18_mhlo.mlir"
-
-module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=False)
-with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf:
-outf.write(str(module))
-
-print(f"MHLO IR of resent18 successfully written into {out_mhlo_mlir_path}")

@@ -0,0 +1,14 @@
+import torch
+import torchvision.models as models
+import torch_mlir
+
+model = models.resnet18(pretrained=True)
+model.eval()
+data = torch.randn(2,3,200,200)
+out_stablehlo_mlir_path = "./resnet18_stablehlo.mlir"
+
+module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=False)
+with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
+outf.write(str(module))
+
+print(f"StableHLO IR of resent18 successfully written into {out_stablehlo_mlir_path}")
@@ -15,10 +15,10 @@ class BertTinyWrapper(torch.nn.Module):
 model = BertTinyWrapper()
 model.eval()
 data = torch.randint(30522, (2, 128))
-out_mhlo_mlir_path = "./bert_tiny_mhlo.mlir"
+out_stablehlo_mlir_path = "./bert_tiny_stablehlo.mlir"

-module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=True)
-with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf:
+module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=True)
+with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
 outf.write(str(module))

-print(f"MHLO IR of tiny bert successfully written into {out_mhlo_mlir_path}")
+print(f"StableHLO IR of tiny bert successfully written into {out_stablehlo_mlir_path}")

@@ -1,6 +1,6 @@
 set(LLVM_TARGET_DEFINITIONS Passes.td)
-if(TORCH_MLIR_ENABLE_MHLO)
-mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_MHLO)
+if(TORCH_MLIR_ENABLE_STABLEHLO)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO)
 else()
 mlir_tablegen(Passes.h.inc -gen-pass-decls)
 endif()

@@ -133,13 +133,13 @@ def ConvertTorchConversionToMLProgram : Pass<"convert-torch-conversion-to-mlprog
 let constructor = "mlir::torch::createConvertTorchConversionToMLProgramPass()";
 }

-#ifdef TORCH_MLIR_ENABLE_MHLO
-def ConvertTorchToMhlo : Pass<"convert-torch-to-mhlo", "func::FuncOp"> {
-let summary = "Convert Torch ops to MHLO ops";
+#ifdef TORCH_MLIR_ENABLE_STABLEHLO
+def ConvertTorchToStablehlo : Pass<"convert-torch-to-stablehlo", "func::FuncOp"> {
+let summary = "Convert Torch ops to Stablehlo ops";
 let description = [{
-Convert Torch ops to mhlo ops.
+Convert Torch ops to Stablehlo ops.
 }];
-let constructor = "mlir::torch::createConvertTorchToMhloPass()";
+let constructor = "mlir::torch::createConvertTorchToStablehloPass()";

 // Specify any options.
 let options = [

@@ -7,8 +7,8 @@
 //
 //===----------------------------------------------------------------------===//

-#ifndef TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H
-#define TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H
+#ifndef TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H
+#define TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H

 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Pass/Pass.h"

@@ -16,10 +16,11 @@

 namespace mlir {
 namespace torch {
-std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToMhloPass();
 std::unique_ptr<OperationPass<func::FuncOp>>
-createConvertTorchToMhloPass(bool enableStaticShape, bool enableI32Index);
+createConvertTorchToStablehloPass();
+std::unique_ptr<OperationPass<func::FuncOp>>
+createConvertTorchToStablehloPass(bool enableStaticShape, bool enableI32Index);
 } // namespace torch
 } // namespace mlir

-#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H
+#endif // TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H

@@ -1,6 +1,6 @@
 set(LLVM_TARGET_DEFINITIONS Passes.td)
-if(TORCH_MLIR_ENABLE_MHLO)
-mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_MHLO)
+if(TORCH_MLIR_ENABLE_STABLEHLO)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO)
 else()
 mlir_tablegen(Passes.h.inc -gen-pass-decls)
 endif()

@@ -30,10 +30,10 @@ void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm);
 /// TOSA backend contract.
 void createTorchBackendToTosaBackendPipeline(OpPassManager &pm);

-// Do not register the torch-to-mhlo pipeline if mhlo target is disabled
-#ifdef TORCH_MLIR_ENABLE_MHLO
-struct MhloBackendPipelineOptions
-: public PassPipelineOptions<MhloBackendPipelineOptions> {
+// Do not register the stablehlo options if the stablehlo target is disabled
+#ifdef TORCH_MLIR_ENABLE_STABLEHLO
+struct StablehloBackendPipelineOptions
+: public PassPipelineOptions<StablehloBackendPipelineOptions> {
 Option<bool> enableStaticShape{
 *this, "enable-static-shape",
 llvm::cl::desc("Enable static shape conversion."), llvm::cl::init(false)};

@@ -46,9 +46,10 @@ struct MhloBackendPipelineOptions
 llvm::cl::init(false)};
 };

-void createTorchBackendToMhloBackendPipeline(
-OpPassManager &pm, const MhloBackendPipelineOptions &options);
-std::unique_ptr<OperationPass<ModuleOp>> createVerifyMhloBackendContractPass();
+void createTorchBackendToStablehloBackendPipeline(
+OpPassManager &pm, const StablehloBackendPipelineOptions &options);
+std::unique_ptr<OperationPass<ModuleOp>>
+createVerifyStablehloBackendContractPass();
 #endif

 std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass();

@@ -42,10 +42,10 @@ def VerifyTosaBackendContract : Pass<"torch-verify-tosa-backend-contract", "Modu
 let constructor = "mlir::torch::TorchConversion::createVerifyTosaBackendContractPass()";
 }

-#ifdef TORCH_MLIR_ENABLE_MHLO
-def VerifyMhloBackendContract : Pass<"torch-verify-mhlo-backend-contract", "ModuleOp"> {
-let summary = "Verifies conformity to the mhlo backend contract";
-let constructor = "mlir::torch::TorchConversion::createVerifyMhloBackendContractPass()";
+#ifdef TORCH_MLIR_ENABLE_STABLEHLO
+def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contract", "ModuleOp"> {
+let summary = "Verifies conformity to the stablehlo backend contract";
+let constructor = "mlir::torch::TorchConversion::createVerifyStablehloBackendContractPass()";
 }
-#endif // TORCH_MLIR_ENABLE_MHLO
+#endif // TORCH_MLIR_ENABLE_STABLEHLO
 #endif // TORCHMLIR_TORCHCONVERSION_PASSES

@@ -3,13 +3,7 @@ add_subdirectory(Conversion)
 add_subdirectory(Dialect)
 add_subdirectory(RefBackend)

-add_mlir_library(TorchMLIRInitAll
-InitAll.cpp
-
-LINK_COMPONENTS
-Core
-
-LINK_LIBS PUBLIC
+set(LinkedLibs
 MLIRFuncDialect
 MLIRIR
 MLIRSupport

@@ -27,4 +21,22 @@ add_mlir_library(TorchMLIRInitAll
 TorchMLIRRefBackend
 )

+if(TORCH_MLIR_ENABLE_STABLEHLO)
+list(APPEND LinkedLibs
+MhloPasses
+MhloToLinalg
+StablehloToMhlo
+)
+endif()
+
+add_mlir_library(TorchMLIRInitAll
+InitAll.cpp
+
+LINK_COMPONENTS
+Core
+
+LINK_LIBS PUBLIC
+${LinkedLibs}
+)
+
 torch_mlir_target_includes(TorchMLIRInitAll)

@@ -2,8 +2,8 @@ add_subdirectory(TorchToLinalg)
 add_subdirectory(TorchToSCF)
 add_subdirectory(TorchToArith)
 add_subdirectory(TorchToTosa)
-if(TORCH_MLIR_ENABLE_MHLO)
-add_subdirectory(TorchToMhlo)
+if(TORCH_MLIR_ENABLE_STABLEHLO)
+add_subdirectory(TorchToStablehlo)
 endif()
 add_subdirectory(TorchToTMTensor)
 add_subdirectory(TorchConversionToMLProgram)

@@ -17,10 +17,8 @@ set(linked_libs TorchMLIRTorchToLinalg
 TorchMLIRTorchToTMTensor
 TorchMLIRTorchConversionToMLProgram
 TorchMLIRConversionUtils)
-if(TORCH_MLIR_ENABLE_MHLO)
-list(APPEND linked_libs
-MhloPasses
-TorchMLIRTorchToMhlo)
+if(TORCH_MLIR_ENABLE_STABLEHLO)
+list(APPEND linked_libs TorchMLIRTorchToStablehlo)
 endif()

 add_mlir_library(TorchMLIRConversionPasses

@@ -9,15 +9,15 @@

 #include "torch-mlir/Conversion/Passes.h"

-#ifdef TORCH_MLIR_ENABLE_MHLO
-#include "mhlo/transforms/passes.h"
+#ifdef TORCH_MLIR_ENABLE_STABLEHLO
+#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
 #include "transforms/passes.h"
-#endif // TORCH_MLIR_ENABLE_MHLO
+#endif // TORCH_MLIR_ENABLE_STABLEHLO

 #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
 #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
 #include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
 #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
-#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
 #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
 #include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"

@@ -32,12 +32,4 @@ namespace {

 void mlir::torch::registerConversionPasses() {
 ::registerPasses();
-#ifdef TORCH_MLIR_ENABLE_MHLO
-::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
-return mlir::mhlo::createLegalizeHloToLinalgPass();
-});
-::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
-return mlir::mhlo::createSymbolicShapeOptimizationPass();
-});
-#endif // TORCH_MLIR_ENABLE_MHLO
 }

@@ -1,35 +0,0 @@
-add_mlir_conversion_library(TorchMLIRTorchToMhlo
-TorchToMhlo.cpp
-MhloLegalizeUtils.cpp
-Basic.cpp
-Gather.cpp
-Linear.cpp
-ViewLike.cpp
-Reduction.cpp
-Pooling.cpp
-
-ADDITIONAL_HEADER_DIRS
-${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToMhlo
-
-DEPENDS
-MhloDialect
-MhloToLinalg
-MLIRMhloPassIncGen
-LMHLOTransformsPassIncGen
-TorchMLIRConversionPassIncGen
-
-LINK_COMPONENTS
-Core
-
-LINK_LIBS PUBLIC
-MLIRIR
-MLIRPass
-MhloDialect
-MhloToLinalg
-MLIRBufferTransforms
-StablehloOps
-TorchMLIRTorchDialect
-TorchMLIRConversionUtils
-)
-
-torch_mlir_target_includes(TorchMLIRTorchToMhlo)

@@ -1,74 +0,0 @@
-//===------------------------------------------------------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-// Also available under a BSD-style license. See LICENSE.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H
-#define TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H
-
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir {
-namespace torch {
-namespace torch_to_mhlo {
-
-struct TorchToMhloOptions {
-bool enableStaticShape = false;
-size_t dimSizeIndexBits = 64;
-};
-
-template <typename AtenOpT>
-class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
-public:
-using OpAdaptor = typename AtenOpT::Adaptor;
-ConvertAtenOp(TypeConverter &typeConverter, MLIRContext *context,
-const TorchToMhloOptions &options)
-: OpConversionPattern<AtenOpT>(typeConverter, context) {
-this->options = options;
-}
-LogicalResult
-matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
-ConversionPatternRewriter &rewriter) const override {
-return rewriter.notifyMatchFailure(op, "haven't been implemented");
-}
-const TorchToMhloOptions &getOptions() const { return options; }
-
-private:
-TorchToMhloOptions options;
-};
-
-void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter,
-RewritePatternSet &patterns,
-ConversionTarget &target,
-const TorchToMhloOptions &options);
-void populateViewLikeOpPatternsAndLegality(TypeConverter &typeConverter,
-RewritePatternSet &patterns,
-ConversionTarget &target,
-const TorchToMhloOptions &options);
-void populateGatherOpPatternsAndLegality(TypeConverter &typeConverter,
-RewritePatternSet &patterns,
-ConversionTarget &target,
-const TorchToMhloOptions &options);
-void populateReductionOpPatternsAndLegality(TypeConverter &typeConverter,
-RewritePatternSet &patterns,
-ConversionTarget &target,
-const TorchToMhloOptions &options);
-void populateLinearOpPatternsAndLegality(TypeConverter &typeConverter,
-RewritePatternSet &patterns,
-ConversionTarget &target,
-const TorchToMhloOptions &options);
-
-void populatePoolingOpPatternsAndLegality(TypeConverter &typeConverter,
-RewritePatternSet &patterns,
-ConversionTarget &target,
-const TorchToMhloOptions &options);
-
-} // namespace torch_to_mhlo
-} // namespace torch
-} // namespace mlir
-
-#endif // TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H

@@ -7,15 +7,16 @@
 //
 //===----------------------------------------------------------------------===//

-#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
+#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"

 #include "../PassDetail.h"
-#include "./MhloLegalizeUtils.h"
-#include "./PopulatePatterns.h"
-#include "mhlo/IR/hlo_ops.h"
+#include "PopulatePatterns.h"
+#include "StablehloLegalizeUtils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "stablehlo/dialect/ChloOps.h"
+#include "stablehlo/dialect/StablehloOps.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

@@ -29,7 +30,7 @@
 using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
-using namespace mlir::torch::torch_to_mhlo;
+using namespace mlir::torch::torch_to_stablehlo;

 LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
 mlir::Value &self, mlir::Value &other,

@@ -43,16 +44,16 @@ LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
 if (selfRank > otherRank) {
 auto unsqueezeDims =
 llvm::to_vector<4>(llvm::seq<int64_t>(0, selfRank - otherRank));
-auto unsqueezeInfo = mhlo::unsqueezeTensor(rewriter, op, other,
+auto unsqueezeInfo = hlo::unsqueezeTensor(rewriter, op, other,
 unsqueezeDims, dimSizeIndexBits);
 if (failed(unsqueezeInfo))
 return failure();
 other = *unsqueezeInfo;
 } else if (otherRank > selfRank) {
 auto unsqueezeDims =
 llvm::to_vector<4>(llvm::seq<int64_t>(0, otherRank - selfRank));
-auto unsqueezeInfo = mhlo::unsqueezeTensor(rewriter, op, self,
-unsqueezeDims, dimSizeIndexBits);
+auto unsqueezeInfo = hlo::unsqueezeTensor(rewriter, op, self, unsqueezeDims,
+dimSizeIndexBits);
 if (failed(unsqueezeInfo))
 return failure();
 self = *unsqueezeInfo;

@@ -78,7 +79,8 @@ static FailureOr<Value> getMaxValueOfDtype(Operation *op, Type elementType,
 constType,
 APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
 /*negative=*/false));
-return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
+return rewriter
+.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
 .getResult();
 }
 if (elementType.isa<mlir::IntegerType>()) {

@@ -91,7 +93,8 @@ static FailureOr<Value> getMaxValueOfDtype(Operation *op, Type elementType,
 constAttr = SplatElementsAttr::get(
 constType, APInt::getSignedMaxValue(integerType.getWidth()));
 }
-return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
+return rewriter
+.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
 .getResult();
 }
 return failure();

@@ -105,7 +108,8 @@ static FailureOr<Value> getMinValueOfDtype(Operation *op, Type elementType,
 constType,
 APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
 /*negative=*/true));
-return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
+return rewriter
+.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
 .getResult();
 }
 if (elementType.isa<mlir::IntegerType>()) {

@@ -118,7 +122,8 @@ static FailureOr<Value> getMinValueOfDtype(Operation *op, Type elementType,
 constAttr = SplatElementsAttr::get(
 constType, APInt::getSignedMinValue(integerType.getWidth()));
 }
-return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
+return rewriter
+.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
 .getResult();
 }
 return failure();

@@ -126,7 +131,7 @@ static FailureOr<Value> getMinValueOfDtype(Operation *op, Type elementType,

 // These legalizations are for unary ops.
 namespace {
-template <typename AtenOpT, typename MhloOpT>
+template <typename AtenOpT, typename StablehloOpT>
 class ConvertAtenUnaryOp : public OpConversionPattern<AtenOpT> {
 public:
 using OpConversionPattern<AtenOpT>::OpConversionPattern;

@@ -137,13 +142,13 @@ public:
 Value self = adaptor.getSelf();
 auto selfType = self.getType().cast<TensorType>();
 if (!selfType) {
-return op.emitError("only Tensor types supported in MHLO");
+return op.emitError("only Tensor types supported in StableHLO");
 }
 auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
 ->convertType(op.getType())
 .template cast<TensorType>();
-self = mhlo::promoteType(rewriter, self, outType);
-rewriter.replaceOpWithNewOp<MhloOpT>(op, outType, self);
+self = hlo::promoteType(rewriter, self, outType);
+rewriter.replaceOpWithNewOp<StablehloOpT>(op, outType, self);
 return success();
 }
 };

@@ -152,7 +157,7 @@ public:
 // These legalizations are for unary ops with only for floating point datatypes.
 // There is no supported quantized integer mode for these.
 namespace {
-template <typename AtenOpT, typename MhloOpT>
+template <typename AtenOpT, typename StablehloOpT>
 class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern<AtenOpT> {
 public:
 using OpConversionPattern<AtenOpT>::OpConversionPattern;

@@ -164,10 +169,10 @@ public:
 auto selfTy = self.getType().cast<TensorType>();

 if (!selfTy)
-return op.emitError("only Tensor types supported in MHLO");
+return op.emitError("only Tensor types supported in StableHLO");

 if (selfTy.getElementType().isa<mlir::FloatType>()) {
-rewriter.replaceOpWithNewOp<MhloOpT>(
+rewriter.replaceOpWithNewOp<StablehloOpT>(
 op,
 OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
 op.getType()),

@@ -198,7 +203,7 @@ public:
 .template dyn_cast<TensorType>();

 if (!outType)
-return op.emitError("only Tensor types supported in MHLO");
+return op.emitError("only Tensor types supported in StableHLO");

 Type outElemTy = outType.getElementType();
 if (!outElemTy.isIntOrFloat())

@@ -216,9 +221,9 @@ public:

 SmallVector<int32_t> values(size, fillVal);
 auto constOp =
-mhlo::getConstTensor<int32_t>(rewriter, op, values, shape).value();
+hlo::getConstTensor<int32_t>(rewriter, op, values, shape).value();

-rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(op, outType, constOp);
+rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(op, outType, constOp);
 return success();
 }
 };

@@ -247,8 +252,8 @@ public:
 ->convertType(op.getType())
 .template cast<TensorType>();

-lhs = mhlo::promoteType(rewriter, lhs, outTy);
-rhs = mhlo::promoteType(rewriter, rhs, outTy);
+lhs = hlo::promoteType(rewriter, lhs, outTy);
+rhs = hlo::promoteType(rewriter, rhs, outTy);

 rewriter.replaceOpWithNewOp<ChloOpT>(op, outTy, lhs, rhs,
 /*broadcast_attr*/ nullptr);

@@ -274,7 +279,7 @@ public:
 RankedTensorType rhsType = rhs.getType().dyn_cast<RankedTensorType>();

 if (!lhsType)
-return op.emitError("only Tensor types supported in MHLO");
+return op.emitError("only Tensor types supported in StableHLO");

 TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
 ->convertType(op.getType())

@@ -287,18 +292,19 @@ public:
 }

 if (!rhsType) {
-rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getOther(), outElemTy);
+rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
+outElemTy);
 if (isa<AtenRsubScalarOp>(op)) {
 std::swap(lhs, rhs);
 }
 }

-lhs = mhlo::promoteType(rewriter, lhs, outType);
-rhs = mhlo::promoteType(rewriter, rhs, outType);
+lhs = hlo::promoteType(rewriter, lhs, outType);
+rhs = hlo::promoteType(rewriter, rhs, outType);

 if (!skipMultiplyAlpha(op.getAlpha())) {
-Value alpha =
-mhlo::scalarToMhloTensor(rewriter, op, adaptor.getAlpha(), outElemTy);
+Value alpha = hlo::scalarToStablehloTensor(rewriter, op,
+adaptor.getAlpha(), outElemTy);
 DenseIntElementsAttr bcastDimensions;
 rhs = rewriter.create<chlo::BroadcastMulOp>(op->getLoc(), rhs, alpha,
 bcastDimensions);

@@ -328,7 +334,7 @@ public:
 TensorType rhsType = rhs.getType().dyn_cast<TensorType>();

 if (!lhsType)
-return op.emitError("only Tensor types supported in MHLO");
+return op.emitError("only Tensor types supported in StableHLO");

 auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
 ->convertType(op.getType())

@@ -343,11 +349,12 @@ public:
 if (std::is_same<AtenOpT, AtenSquareOp>()) {
 rhs = lhs;
 } else if (!rhsType) {
-rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getOther(), outElemTy);
+rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
+outElemTy);
 }
 DenseIntElementsAttr bcastDimensions;
-lhs = mhlo::promoteType(rewriter, lhs, outType);
-rhs = mhlo::promoteType(rewriter, rhs, outType);
+lhs = hlo::promoteType(rewriter, lhs, outType);
+rhs = hlo::promoteType(rewriter, rhs, outType);
 auto loc = op.getLoc();
 Value result =
 rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);

@@ -368,15 +375,15 @@ public:
 if (roundingMode == "trunc") {
 // "trunc" - rounds the results of the division towards zero. Equivalent
 // to C-style integer division.
-auto sign = rewriter.create<mhlo::SignOp>(loc, result);
-auto abs = rewriter.create<mhlo::AbsOp>(loc, result);
-auto floor = rewriter.create<mhlo::FloorOp>(loc, abs);
-result = rewriter.create<mhlo::MulOp>(loc, sign, floor).getResult();
+auto sign = rewriter.create<stablehlo::SignOp>(loc, result);
+auto abs = rewriter.create<stablehlo::AbsOp>(loc, result);
+auto floor = rewriter.create<stablehlo::FloorOp>(loc, abs);
+result = rewriter.create<stablehlo::MulOp>(loc, sign, floor).getResult();
 }
 if (roundingMode == "floor") {
 // "floor" - rounds the results of the division down. Equivalent to
 // floor division in Python (the // operator)
-result = rewriter.create<mhlo::FloorOp>(loc, result).getResult();
+result = rewriter.create<stablehlo::FloorOp>(loc, result).getResult();
 }
 rewriter.replaceOp(op, result);
 return success();

@@ -401,7 +408,7 @@ public:
 RankedTensorType rhsTy = rhs.getType().dyn_cast<RankedTensorType>();

 if (!lhsTy)
-return op.emitError("only Tensor types supported in MHLO");
+return op.emitError("only Tensor types supported in StableHLO");

 RankedTensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
 ->convertType(op.getType())

@@ -414,11 +421,12 @@ public:
 }

 if (!rhsTy) {
-rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getOther(), lhsElemTy);
+rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
+lhsElemTy);
 }

 // TODO: what is the PyTorch default type promotion?
-rhs = mhlo::promoteType(rewriter, rhs, lhsTy);
+rhs = hlo::promoteType(rewriter, rhs, lhsTy);

 chlo::ComparisonTypeAttr compareTypeAttr;
 chlo::ComparisonDirectionAttr compareDirectionAttr;

@ -485,8 +493,8 @@ public:
|
||||||
TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||||
->convertType(op.getType())
|
->convertType(op.getType())
|
||||||
.template cast<TensorType>();
|
.template cast<TensorType>();
|
||||||
Value lhs = mhlo::promoteType(rewriter, adaptor.getSelf(), outType);
|
Value lhs = hlo::promoteType(rewriter, adaptor.getSelf(), outType);
|
||||||
Value rhs = mhlo::promoteType(rewriter, adaptor.getOther(), outType);
|
Value rhs = hlo::promoteType(rewriter, adaptor.getOther(), outType);
|
||||||
|
|
||||||
DenseIntElementsAttr bcastDimensions;
|
DenseIntElementsAttr bcastDimensions;
|
||||||
rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
|
rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
|
||||||
|
@ -537,8 +545,8 @@ public:
|
||||||
RankedTensorType::get({static_cast<long int>(permValues.size())},
|
RankedTensorType::get({static_cast<long int>(permValues.size())},
|
||||||
rewriter.getI64Type()),
|
rewriter.getI64Type()),
|
||||||
permValues);
|
permValues);
|
||||||
rewriter.replaceOpWithNewOp<mhlo::TransposeOp>(op, outType, self,
|
rewriter.replaceOpWithNewOp<stablehlo::TransposeOp>(op, outType, self,
|
||||||
permutation);
|
permutation);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -552,7 +560,7 @@ LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
auto outType =
|
auto outType =
|
||||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||||
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(op, outType, self);
|
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(op, outType, self);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -573,7 +581,8 @@ LogicalResult ConvertAtenOp<AtenSizeIntOp>::matchAndRewrite(
|
||||||
} else {
|
} else {
|
||||||
Value inputRank = rewriter.create<arith::ConstantOp>(
|
Value inputRank = rewriter.create<arith::ConstantOp>(
|
||||||
op.getLoc(), rewriter.getI64IntegerAttr(selfType.getRank()));
|
op.getLoc(), rewriter.getI64IntegerAttr(selfType.getRank()));
|
||||||
dim = toPositiveDimDynamic(rewriter, op.getLoc(), adaptor.getDim(), inputRank);
|
dim = toPositiveDimDynamic(rewriter, op.getLoc(), adaptor.getDim(),
|
||||||
|
inputRank);
|
||||||
dim = rewriter.create<arith::IndexCastOp>(op.getLoc(),
|
dim = rewriter.create<arith::IndexCastOp>(op.getLoc(),
|
||||||
rewriter.getIndexType(), dim);
|
rewriter.getIndexType(), dim);
|
||||||
}
|
}
|
||||||
|
@ -589,9 +598,8 @@ LogicalResult ConvertAtenOp<AtenSizeIntOp>::matchAndRewrite(
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
|
LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
|
||||||
AtenWhereSelfOp op,
|
AtenWhereSelfOp op, OpAdaptor adaptor,
|
||||||
OpAdaptor adaptor,
|
ConversionPatternRewriter &rewriter) const {
|
||||||
ConversionPatternRewriter& rewriter) const {
|
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
Value cond = adaptor.getCondition();
|
Value cond = adaptor.getCondition();
|
||||||
Value other = adaptor.getOther();
|
Value other = adaptor.getOther();
|
||||||
|
@ -605,8 +613,7 @@ LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
|
||||||
return op.emitError("failed broadcast other and condition ranks");
|
return op.emitError("failed broadcast other and condition ranks");
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<chlo::BroadcastSelectOp>(
|
rewriter.replaceOpWithNewOp<chlo::BroadcastSelectOp>(
|
||||||
op,
|
op, getTypeConverter()->convertType(op.getType()),
|
||||||
getTypeConverter()->convertType(op.getType()),
|
|
||||||
ArrayRef<Value>{cond, self, other});
|
ArrayRef<Value>{cond, self, other});
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -623,7 +630,7 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
|
||||||
.cast<RankedTensorType>();
|
.cast<RankedTensorType>();
|
||||||
|
|
||||||
if (options.enableStaticShape && selfTy.hasStaticShape()) {
|
if (options.enableStaticShape && selfTy.hasStaticShape()) {
|
||||||
Value bcastOp = mhlo::promoteAndBroadcast(rewriter, self, outType);
|
Value bcastOp = hlo::promoteAndBroadcast(rewriter, self, outType);
|
||||||
rewriter.replaceOp(op, bcastOp);
|
rewriter.replaceOp(op, bcastOp);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -670,7 +677,7 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
|
||||||
op->getLoc(), ValueRange{bcastShapeVec});
|
op->getLoc(), ValueRange{bcastShapeVec});
|
||||||
auto dimensionNumbers =
|
auto dimensionNumbers =
|
||||||
llvm::to_vector<4>(llvm::seq<int64_t>(leadingRank, totalRank));
|
llvm::to_vector<4>(llvm::seq<int64_t>(leadingRank, totalRank));
|
||||||
rewriter.replaceOpWithNewOp<mhlo::DynamicBroadcastInDimOp>(
|
rewriter.replaceOpWithNewOp<stablehlo::DynamicBroadcastInDimOp>(
|
||||||
op, outType, self, bcastShapeTensor,
|
op, outType, self, bcastShapeTensor,
|
||||||
rewriter.getI64TensorAttr(dimensionNumbers));
|
rewriter.getI64TensorAttr(dimensionNumbers));
|
||||||
}
|
}
|
||||||
|
@ -708,8 +715,8 @@ LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
|
||||||
RankedTensorType::get({static_cast<long int>(permValues.size())},
|
RankedTensorType::get({static_cast<long int>(permValues.size())},
|
||||||
rewriter.getI64Type()),
|
rewriter.getI64Type()),
|
||||||
permValues);
|
permValues);
|
||||||
rewriter.replaceOpWithNewOp<mhlo::TransposeOp>(op, outType, self,
|
rewriter.replaceOpWithNewOp<stablehlo::TransposeOp>(op, outType, self,
|
||||||
permutation);
|
permutation);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -721,7 +728,7 @@ LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
auto selfTy = self.getType().cast<TensorType>();
|
auto selfTy = self.getType().cast<TensorType>();
|
||||||
if (selfTy && selfTy.getElementType().isa<mlir::FloatType>()) {
|
if (selfTy && selfTy.getElementType().isa<mlir::FloatType>()) {
|
||||||
rewriter.replaceOpWithNewOp<mhlo::TanhOp>(
|
rewriter.replaceOpWithNewOp<stablehlo::TanhOp>(
|
||||||
op, getTypeConverter()->convertType(op.getType()), self);
|
op, getTypeConverter()->convertType(op.getType()), self);
|
||||||
return success();
|
return success();
|
||||||
} else {
|
} else {
|
||||||
|
@ -751,16 +758,16 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
|
||||||
elements.mapValues(builtinTensorElemTy, [&](const APInt &v) {
|
elements.mapValues(builtinTensorElemTy, [&](const APInt &v) {
|
||||||
return APInt(bitWidth, v.getSExtValue());
|
return APInt(bitWidth, v.getSExtValue());
|
||||||
});
|
});
|
||||||
rewriter.replaceOpWithNewOp<mhlo::ConstantOp>(op, resultType, valueAttr);
|
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(op, resultType,
|
||||||
|
valueAttr);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<mhlo::ConstantOp>(op, resultType,
|
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(op, resultType,
|
||||||
adaptor.getValue());
|
adaptor.getValue());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// AtenReciprocalOp
|
// AtenReciprocalOp
|
||||||
// Reciprocal(x) = Div(1, x)
|
// Reciprocal(x) = Div(1, x)
|
||||||
template <>
|
template <>
|
||||||
|
@ -777,7 +784,7 @@ LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(
|
||||||
}
|
}
|
||||||
|
|
||||||
Value oneTensor = chlo::getConstantLike(rewriter, op->getLoc(), 1, input);
|
Value oneTensor = chlo::getConstantLike(rewriter, op->getLoc(), 1, input);
|
||||||
rewriter.replaceOpWithNewOp<mhlo::DivOp>(op, outTy, oneTensor, input);
|
rewriter.replaceOpWithNewOp<stablehlo::DivOp>(op, outTy, oneTensor, input);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -790,9 +797,9 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
|
||||||
->convertType(op->getResult(0).getType())
|
->convertType(op->getResult(0).getType())
|
||||||
.cast<RankedTensorType>();
|
.cast<RankedTensorType>();
|
||||||
auto outputElemType = outputType.getElementType();
|
auto outputElemType = outputType.getElementType();
|
||||||
Value mhloTensor =
|
Value stablehloTensor = hlo::scalarToStablehloTensor(
|
||||||
mhlo::scalarToMhloTensor(rewriter, op, adaptor.getA(), outputElemType);
|
rewriter, op, adaptor.getA(), outputElemType);
|
||||||
rewriter.replaceOp(op, mhloTensor);
|
rewriter.replaceOp(op, stablehloTensor);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -815,7 +822,6 @@ LogicalResult ConvertAtenOp<AtenContiguousOp>::matchAndRewrite(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// AtenReluOp
|
// AtenReluOp
|
||||||
// Relu(x) = Max(0, x)
|
// Relu(x) = Max(0, x)
|
||||||
template <>
|
template <>
|
||||||
|
@ -836,11 +842,10 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
|
||||||
APFloat::getZero(lhsElemTy.cast<mlir::FloatType>().getFloatSemantics(),
|
APFloat::getZero(lhsElemTy.cast<mlir::FloatType>().getFloatSemantics(),
|
||||||
false),
|
false),
|
||||||
lhs);
|
lhs);
|
||||||
rewriter.replaceOpWithNewOp<mhlo::MaxOp>(op, lhs, zeroTensor);
|
rewriter.replaceOpWithNewOp<stablehlo::MaxOp>(op, lhs, zeroTensor);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||

 // Convert a Aten::GELU to HLO
 // Gelu(x) = x * 1/2 * [1 + erf(x/(sqrt(2)))]
 template <>

@@ -857,12 +862,12 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
   Value one = chlo::getConstantLike(rewriter, loc, 1.0, input);
   Value two = chlo::getConstantLike(rewriter, loc, 2.0, input);
   Value half = chlo::getConstantLike(rewriter, loc, 0.5, input);
-  auto rsqrtTwo = rewriter.create<mlir::mhlo::RsqrtOp>(loc, two);
-  auto erfElement = rewriter.create<mhlo::MulOp>(loc, input, rsqrtTwo);
+  auto rsqrtTwo = rewriter.create<mlir::stablehlo::RsqrtOp>(loc, two);
+  auto erfElement = rewriter.create<stablehlo::MulOp>(loc, input, rsqrtTwo);
   auto erf = rewriter.create<mlir::chlo::ErfOp>(loc, erfElement);
-  auto erfAdd = rewriter.create<mhlo::AddOp>(loc, erf, one);
-  auto halfMul = rewriter.create<mhlo::MulOp>(loc, erfAdd, half);
-  rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, input, halfMul);
+  auto erfAdd = rewriter.create<stablehlo::AddOp>(loc, erf, one);
+  auto halfMul = rewriter.create<stablehlo::MulOp>(loc, erfAdd, half);
+  rewriter.replaceOpWithNewOp<stablehlo::MulOp>(op, input, halfMul);
   return success();
 }
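For reference, the exact (erf-based) GELU that this sequence assembles, with rsqrt(2.0) supplying the 1/sqrt(2) factor (math only, in LaTeX):

\mathrm{GELU}(x) \;=\; x\,\Phi(x) \;=\; x \cdot \tfrac{1}{2}\Bigl(1 + \operatorname{erf}\bigl(x \cdot \tfrac{1}{\sqrt{2}}\bigr)\Bigr)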
@ -881,7 +886,6 @@ LogicalResult ConvertAtenOp<AtenErfOp>::matchAndRewrite(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// AtenBatchNormOp
|
// AtenBatchNormOp
|
||||||
template <>
|
template <>
|
||||||
LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
||||||
|
@ -919,28 +923,28 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
||||||
Value channelShape = rewriter.create<tensor::FromElementsOp>(
|
Value channelShape = rewriter.create<tensor::FromElementsOp>(
|
||||||
op->getLoc(), ValueRange{channelDim});
|
op->getLoc(), ValueRange{channelDim});
|
||||||
if (failed(checkNotNone(rewriter, op, weight))) {
|
if (failed(checkNotNone(rewriter, op, weight))) {
|
||||||
weight = mhlo::getConstantOfShape(
|
weight = hlo::getConstantOfShape(
|
||||||
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1),
|
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1),
|
||||||
channelShape,
|
channelShape,
|
||||||
RankedTensorType::get({inputTy.getShape()[1]},
|
RankedTensorType::get({inputTy.getShape()[1]},
|
||||||
inputTy.getElementType()));
|
inputTy.getElementType()));
|
||||||
}
|
}
|
||||||
if (failed(checkNotNone(rewriter, op, bias))) {
|
if (failed(checkNotNone(rewriter, op, bias))) {
|
||||||
bias = mhlo::getConstantOfShape(
|
bias = hlo::getConstantOfShape(
|
||||||
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0),
|
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0),
|
||||||
channelShape,
|
channelShape,
|
||||||
RankedTensorType::get({inputTy.getShape()[1]},
|
RankedTensorType::get({inputTy.getShape()[1]},
|
||||||
inputTy.getElementType()));
|
inputTy.getElementType()));
|
||||||
}
|
}
|
||||||
if (failed(checkNotNone(rewriter, op, runningVar))) {
|
if (failed(checkNotNone(rewriter, op, runningVar))) {
|
||||||
runningVar = mhlo::getConstantOfShape(
|
runningVar = hlo::getConstantOfShape(
|
||||||
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1),
|
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1),
|
||||||
channelShape,
|
channelShape,
|
||||||
RankedTensorType::get({inputTy.getShape()[1]},
|
RankedTensorType::get({inputTy.getShape()[1]},
|
||||||
inputTy.getElementType()));
|
inputTy.getElementType()));
|
||||||
}
|
}
|
||||||
if (failed(checkNotNone(rewriter, op, runningMean))) {
|
if (failed(checkNotNone(rewriter, op, runningMean))) {
|
||||||
runningMean = mhlo::getConstantOfShape(
|
runningMean = hlo::getConstantOfShape(
|
||||||
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0),
|
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0),
|
||||||
channelShape,
|
channelShape,
|
||||||
RankedTensorType::get({inputTy.getShape()[1]},
|
RankedTensorType::get({inputTy.getShape()[1]},
|
||||||
|
@ -983,10 +987,11 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
||||||
Type outputTy = getTypeConverter()->convertType(op.getType());
|
Type outputTy = getTypeConverter()->convertType(op.getType());
|
||||||
Type batchMeanOrVarTy =
|
Type batchMeanOrVarTy =
|
||||||
RankedTensorType::get(weightTy.getShape(), inputTy.getElementType());
|
RankedTensorType::get(weightTy.getShape(), inputTy.getElementType());
|
||||||
auto batchNormTrainingResult = rewriter.create<mhlo::BatchNormTrainingOp>(
|
auto batchNormTrainingResult =
|
||||||
op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input,
|
rewriter.create<stablehlo::BatchNormTrainingOp>(
|
||||||
weight, bias, rewriter.getF32FloatAttr(eps),
|
op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input,
|
||||||
rewriter.getI64IntegerAttr(1));
|
weight, bias, rewriter.getF32FloatAttr(eps),
|
||||||
|
rewriter.getI64IntegerAttr(1));
|
||||||
rewriter.replaceOp(op, batchNormTrainingResult.getResult(0));
|
rewriter.replaceOp(op, batchNormTrainingResult.getResult(0));
|
||||||
return success();
|
return success();
|
||||||
} else {
|
} else {
|
||||||
|
@ -995,10 +1000,11 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
||||||
inputTy.getShape().end()};
|
inputTy.getShape().end()};
|
||||||
castShape[1] = weightTy.getShape()[0];
|
castShape[1] = weightTy.getShape()[0];
|
||||||
auto castTy = RankedTensorType::get(castShape, inputTy.getElementType());
|
auto castTy = RankedTensorType::get(castShape, inputTy.getElementType());
|
||||||
// Feature counts must match among operands of mhlo::BatchNormInferenceOp.
|
// Feature counts must match among operands of
|
||||||
|
// stablehlo::BatchNormInferenceOp.
|
||||||
Value inputCasted =
|
Value inputCasted =
|
||||||
rewriter.create<tensor::CastOp>(op.getLoc(), castTy, input);
|
rewriter.create<tensor::CastOp>(op.getLoc(), castTy, input);
|
||||||
Value output = rewriter.create<mhlo::BatchNormInferenceOp>(
|
Value output = rewriter.create<stablehlo::BatchNormInferenceOp>(
|
||||||
op.getLoc(), inputCasted.getType(), inputCasted, weight, bias,
|
op.getLoc(), inputCasted.getType(), inputCasted, weight, bias,
|
||||||
runningMean, runningVar,
|
runningMean, runningVar,
|
||||||
// 'epsilon' must satisfy constraint: 32-bit float attribute.
|
// 'epsilon' must satisfy constraint: 32-bit float attribute.
|
||||||
|
@ -1008,7 +1014,6 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// AtenNativeLayerNormOp
|
// AtenNativeLayerNormOp
|
||||||
template <>
|
template <>
|
||||||
LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
||||||
|
@ -1076,21 +1081,21 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
||||||
}
|
}
|
||||||
SmallVector<int64_t> inputFlattenShape{1, numFeatureDimSize,
|
SmallVector<int64_t> inputFlattenShape{1, numFeatureDimSize,
|
||||||
numEmbeddingDimSize};
|
numEmbeddingDimSize};
|
||||||
SmallVector<int64_t> meanOrVarMhloOutShape{numFeatureDimSize};
|
SmallVector<int64_t> meanOrVarStablehloOutShape{numFeatureDimSize};
|
||||||
|
|
||||||
auto mhloBatchNormOutTy =
|
auto stablehloBatchNormOutTy =
|
||||||
RankedTensorType::get(inputFlattenShape, inputTy.getElementType());
|
RankedTensorType::get(inputFlattenShape, inputTy.getElementType());
|
||||||
auto mhloBathNormOutMeanOrVarTy =
|
auto stablehloBathNormOutMeanOrVarTy = RankedTensorType::get(
|
||||||
RankedTensorType::get(meanOrVarMhloOutShape, inputTy.getElementType());
|
meanOrVarStablehloOutShape, inputTy.getElementType());
|
||||||
|
|
||||||
// Reshape input
|
// Reshape input
|
||||||
auto mhloInput = rewriter.create<mhlo::DynamicReshapeOp>(
|
auto stablehloInput = rewriter.create<stablehlo::DynamicReshapeOp>(
|
||||||
op->getLoc(), mhloBatchNormOutTy, input,
|
op->getLoc(), stablehloBatchNormOutTy, input,
|
||||||
mhlo::getConstTensor(rewriter, op, llvm::ArrayRef(inputFlattenShape),
|
hlo::getConstTensor(rewriter, op, llvm::ArrayRef(inputFlattenShape),
|
||||||
{static_cast<int64_t>(inputFlattenShape.size())})
|
{static_cast<int64_t>(inputFlattenShape.size())})
|
||||||
.value());
|
.value());
|
||||||
|
|
||||||
// Generate "scale" and "offset" Value for mhlo.BatchNormTrainingOp.
|
// Generate "scale" and "offset" Value for stablehlo.BatchNormTrainingOp.
|
||||||
SmallVector<APFloat> zeroConstVec(
|
SmallVector<APFloat> zeroConstVec(
|
||||||
numFeatureDimSize, APFloat::getZero(inputTy.getElementType()
|
numFeatureDimSize, APFloat::getZero(inputTy.getElementType()
|
||||||
.cast<mlir::FloatType>()
|
.cast<mlir::FloatType>()
|
||||||
|
@ -1103,16 +1108,18 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
||||||
auto oneOrZeroConstType =
|
auto oneOrZeroConstType =
|
||||||
RankedTensorType::get({numFeatureDimSize}, inputTy.getElementType());
|
RankedTensorType::get({numFeatureDimSize}, inputTy.getElementType());
|
||||||
|
|
||||||
Value scale = rewriter.create<mhlo::ConstantOp>(
|
Value scale = rewriter.create<stablehlo::ConstantOp>(
|
||||||
op->getLoc(), oneOrZeroConstType,
|
op->getLoc(), oneOrZeroConstType,
|
||||||
DenseElementsAttr::get(oneOrZeroConstType, oneConstVec));
|
DenseElementsAttr::get(oneOrZeroConstType, oneConstVec));
|
||||||
Value offset = rewriter.create<mhlo::ConstantOp>(
|
Value offset = rewriter.create<stablehlo::ConstantOp>(
|
||||||
op->getLoc(), oneOrZeroConstType,
|
op->getLoc(), oneOrZeroConstType,
|
||||||
DenseElementsAttr::get(oneOrZeroConstType, zeroConstVec));
|
DenseElementsAttr::get(oneOrZeroConstType, zeroConstVec));
|
||||||
auto batchNormTrainingResult = rewriter.create<mhlo::BatchNormTrainingOp>(
|
auto batchNormTrainingResult =
|
||||||
op->getLoc(), mhloBatchNormOutTy, mhloBathNormOutMeanOrVarTy,
|
rewriter.create<stablehlo::BatchNormTrainingOp>(
|
||||||
mhloBathNormOutMeanOrVarTy, mhloInput, scale, offset,
|
op->getLoc(), stablehloBatchNormOutTy,
|
||||||
rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(1));
|
stablehloBathNormOutMeanOrVarTy, stablehloBathNormOutMeanOrVarTy,
|
||||||
|
stablehloInput, scale, offset, rewriter.getF32FloatAttr(eps),
|
||||||
|
rewriter.getI64IntegerAttr(1));
|
||||||
|
|
||||||
// Reshape back
|
// Reshape back
|
||||||
auto outputTy =
|
auto outputTy =
|
||||||
|
@ -1120,36 +1127,35 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
||||||
auto outputMeanOrVarTy =
|
auto outputMeanOrVarTy =
|
||||||
getTypeConverter()->convertType(op.getType(1)).cast<RankedTensorType>();
|
getTypeConverter()->convertType(op.getType(1)).cast<RankedTensorType>();
|
||||||
|
|
||||||
auto output = rewriter.create<mhlo::DynamicReshapeOp>(
|
auto output = rewriter.create<stablehlo::DynamicReshapeOp>(
|
||||||
op->getLoc(), outputTy, batchNormTrainingResult.getResult(0),
|
op->getLoc(), outputTy, batchNormTrainingResult.getResult(0),
|
||||||
mhlo::getConstTensor(rewriter, op, outputTy.getShape(),
|
hlo::getConstTensor(rewriter, op, outputTy.getShape(),
|
||||||
{static_cast<int64_t>(outputTy.getShape().size())})
|
{static_cast<int64_t>(outputTy.getShape().size())})
|
||||||
.value());
|
.value());
|
||||||
auto mean = rewriter.create<mhlo::DynamicReshapeOp>(
|
auto mean = rewriter.create<stablehlo::DynamicReshapeOp>(
|
||||||
op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(1),
|
op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(1),
|
||||||
mhlo::getConstTensor(
|
hlo::getConstTensor(
|
||||||
rewriter, op, outputMeanOrVarTy.getShape(),
|
rewriter, op, outputMeanOrVarTy.getShape(),
|
||||||
{static_cast<int64_t>(outputMeanOrVarTy.getShape().size())})
|
{static_cast<int64_t>(outputMeanOrVarTy.getShape().size())})
|
||||||
.value());
|
.value());
|
||||||
auto var = rewriter.create<mhlo::DynamicReshapeOp>(
|
auto var = rewriter.create<stablehlo::DynamicReshapeOp>(
|
||||||
op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(2),
|
op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(2),
|
||||||
mhlo::getConstTensor(
|
hlo::getConstTensor(
|
||||||
rewriter, op, outputMeanOrVarTy.getShape(),
|
rewriter, op, outputMeanOrVarTy.getShape(),
|
||||||
{static_cast<int64_t>(outputMeanOrVarTy.getShape().size())})
|
{static_cast<int64_t>(outputMeanOrVarTy.getShape().size())})
|
||||||
.value());
|
.value());
|
||||||
|
|
||||||
// Apply affine transform: output x weight + bias [element-wise]
|
// Apply affine transform: output x weight + bias [element-wise]
|
||||||
auto bcastedWeight = mhlo::promoteAndBroadcast(rewriter, weight, outputTy);
|
auto bcastedWeight = hlo::promoteAndBroadcast(rewriter, weight, outputTy);
|
||||||
auto bcastedBias = mhlo::promoteAndBroadcast(rewriter, bias, outputTy);
|
auto bcastedBias = hlo::promoteAndBroadcast(rewriter, bias, outputTy);
|
||||||
auto outputMulWeight =
|
auto outputMulWeight =
|
||||||
rewriter.create<mhlo::MulOp>(op->getLoc(), output, bcastedWeight);
|
rewriter.create<stablehlo::MulOp>(op->getLoc(), output, bcastedWeight);
|
||||||
auto finalOuput =
|
auto finalOuput = rewriter.create<stablehlo::AddOp>(
|
||||||
rewriter.create<mhlo::AddOp>(op->getLoc(), outputMulWeight, bcastedBias);
|
op->getLoc(), outputMulWeight, bcastedBias);
|
||||||
rewriter.replaceOp(op, {finalOuput, mean, var});
|
rewriter.replaceOp(op, {finalOuput, mean, var});
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
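For reference, the math the flatten-then-BatchNormTraining trick above implements (LaTeX sketch): the input is reshaped to [1, F, E] so that each of the F non-normalized positions becomes one batch-norm feature, normalized over its E embedding elements with unit scale and zero offset; gamma and beta are the weight and bias applied element-wise afterwards:

\mathrm{LN}(x)_{f,e} \;=\; \frac{x_{f,e} - \mu_f}{\sqrt{\sigma_f^{2} + \epsilon}}\,\gamma_e + \beta_e,
\qquad
\mu_f = \tfrac{1}{E}\sum_{e} x_{f,e},
\quad
\sigma_f^{2} = \tfrac{1}{E}\sum_{e} \bigl(x_{f,e}-\mu_f\bigr)^{2}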
||||||
|
|
||||||
// AtenCatOp
|
// AtenCatOp
|
||||||
template <>
|
template <>
|
||||||
LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
|
LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
|
||||||
|
@ -1173,11 +1179,11 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
|
||||||
|
|
||||||
// Promote type
|
// Promote type
|
||||||
for (auto &v : builtinTensors) {
|
for (auto &v : builtinTensors) {
|
||||||
v = mhlo::promoteType(rewriter, v, outType);
|
v = hlo::promoteType(rewriter, v, outType);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t posDim = toPositiveDim(dim, outType.getRank());
|
size_t posDim = toPositiveDim(dim, outType.getRank());
|
||||||
rewriter.replaceOpWithNewOp<mhlo::ConcatenateOp>(
|
rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(
|
||||||
op, outType, ValueRange(builtinTensors), posDim);
|
op, outType, ValueRange(builtinTensors), posDim);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -1225,7 +1231,8 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "this op should be folded as its `min` and `max` both are none");
|
op, "this op should be folded as its `min` and `max` both are none");
|
||||||
} else if (failed(checkNotNone(rewriter, op, minValue))) {
|
} else if (failed(checkNotNone(rewriter, op, minValue))) {
|
||||||
maxValue = mhlo::scalarToMhloTensor(rewriter, op, maxValue, inputElemType);
|
maxValue =
|
||||||
|
hlo::scalarToStablehloTensor(rewriter, op, maxValue, inputElemType);
|
||||||
auto minInfo = getMinValueOfDtype(op, inputElemType, rewriter);
|
auto minInfo = getMinValueOfDtype(op, inputElemType, rewriter);
|
||||||
if (failed(minInfo)) {
|
if (failed(minInfo)) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -1233,7 +1240,8 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
|
||||||
}
|
}
|
||||||
minValue = *minInfo;
|
minValue = *minInfo;
|
||||||
} else if (failed(checkNotNone(rewriter, op, maxValue))) {
|
} else if (failed(checkNotNone(rewriter, op, maxValue))) {
|
||||||
minValue = mhlo::scalarToMhloTensor(rewriter, op, minValue, inputElemType);
|
minValue =
|
||||||
|
hlo::scalarToStablehloTensor(rewriter, op, minValue, inputElemType);
|
||||||
auto maxInfo = getMaxValueOfDtype(op, inputElemType, rewriter);
|
auto maxInfo = getMaxValueOfDtype(op, inputElemType, rewriter);
|
||||||
if (failed(maxInfo)) {
|
if (failed(maxInfo)) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -1241,10 +1249,13 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
|
||||||
}
|
}
|
||||||
maxValue = *maxInfo;
|
maxValue = *maxInfo;
|
||||||
} else {
|
} else {
|
||||||
minValue = mhlo::scalarToMhloTensor(rewriter, op, minValue, inputElemType);
|
minValue =
|
||||||
maxValue = mhlo::scalarToMhloTensor(rewriter, op, maxValue, inputElemType);
|
hlo::scalarToStablehloTensor(rewriter, op, minValue, inputElemType);
|
||||||
|
maxValue =
|
||||||
|
hlo::scalarToStablehloTensor(rewriter, op, maxValue, inputElemType);
|
||||||
}
|
}
|
||||||
rewriter.replaceOpWithNewOp<mhlo::ClampOp>(op, minValue, input, maxValue);
|
rewriter.replaceOpWithNewOp<stablehlo::ClampOp>(op, minValue, input,
|
||||||
|
maxValue);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1266,24 +1277,27 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
|
||||||
op, "unimplemented: only int or float dtype supported");
|
op, "unimplemented: only int or float dtype supported");
|
||||||
}
|
}
|
||||||
|
|
||||||
Value start = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getStart(), dtype);
|
Value start =
|
||||||
Value end = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getEnd(), dtype);
|
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getStart(), dtype);
|
||||||
Value step = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getStep(), dtype);
|
Value end =
|
||||||
|
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getEnd(), dtype);
|
||||||
|
Value step =
|
||||||
|
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getStep(), dtype);
|
||||||
|
|
||||||
// Get length of the 1-d output tensor
|
// Get length of the 1-d output tensor
|
||||||
Value subOut = rewriter.create<mhlo::SubtractOp>(loc, end, start);
|
Value subOut = rewriter.create<stablehlo::SubtractOp>(loc, end, start);
|
||||||
Value divOut = rewriter.create<mhlo::DivOp>(loc, subOut, step);
|
Value divOut = rewriter.create<stablehlo::DivOp>(loc, subOut, step);
|
||||||
|
|
||||||
Value resultLength = rewriter.create<mhlo::ReshapeOp>(
|
Value resultLength = rewriter.create<stablehlo::ReshapeOp>(
|
||||||
loc, RankedTensorType::get({1}, dtype), divOut);
|
loc, RankedTensorType::get({1}, dtype), divOut);
|
||||||
if (dtype.isa<mlir::FloatType>()) {
|
if (dtype.isa<mlir::FloatType>()) {
|
||||||
resultLength = rewriter.create<mhlo::CeilOp>(loc, resultLength);
|
resultLength = rewriter.create<stablehlo::CeilOp>(loc, resultLength);
|
||||||
resultLength = rewriter.create<mhlo::ConvertOp>(
|
resultLength = rewriter.create<stablehlo::ConvertOp>(
|
||||||
loc, RankedTensorType::get({1}, rewriter.getI64Type()), resultLength);
|
loc, RankedTensorType::get({1}, rewriter.getI64Type()), resultLength);
|
||||||
}
|
}
|
||||||
|
|
||||||
Value window =
|
Value window =
|
||||||
rewriter.create<mhlo::DynamicIotaOp>(loc, outType, resultLength, 0);
|
rewriter.create<stablehlo::DynamicIotaOp>(loc, outType, resultLength, 0);
|
||||||
DenseIntElementsAttr broadcastDimensions;
|
DenseIntElementsAttr broadcastDimensions;
|
||||||
Value mulOut = rewriter.create<chlo::BroadcastMulOp>(loc, window, step,
|
Value mulOut = rewriter.create<chlo::BroadcastMulOp>(loc, window, step,
|
||||||
broadcastDimensions);
|
broadcastDimensions);
|
||||||
|
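The subtract/divide/ceil sequence above is the usual arange length rule (LaTeX sketch; for integer dtypes the ceil step is skipped and the division truncates):

\mathrm{len} \;=\; \Bigl\lceil \frac{\mathrm{end} - \mathrm{start}}{\mathrm{step}} \Bigr\rceil,
\qquad
\mathrm{out}[i] \;=\; \mathrm{start} + i \cdot \mathrm{step},
\quad i = 0,\dots,\mathrm{len}-1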
@ -1298,9 +1312,8 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
auto outType = this->getTypeConverter()
|
auto outType =
|
||||||
->convertType(op.getType())
|
this->getTypeConverter()->convertType(op.getType()).cast<TensorType>();
|
||||||
.cast<TensorType>();
|
|
||||||
if (!outType) {
|
if (!outType) {
|
||||||
return op.emitError("only tensor type is supported");
|
return op.emitError("only tensor type is supported");
|
||||||
}
|
}
|
||||||
|
@ -1320,26 +1333,27 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
|
||||||
Value negHalf = chlo::getConstantLike(rewriter, loc, -0.5, input);
|
Value negHalf = chlo::getConstantLike(rewriter, loc, -0.5, input);
|
||||||
|
|
||||||
// Compute
|
// Compute
|
||||||
Value kBeta0 = rewriter.create<mhlo::MulOp>(loc, outType, kAlpha, cstAlpha0);
|
Value kBeta0 =
|
||||||
Value kBeta = rewriter.create<mhlo::MulOp>(loc, outType, kBeta0, half);
|
rewriter.create<stablehlo::MulOp>(loc, outType, kAlpha, cstAlpha0);
|
||||||
Value erfArg =
|
Value kBeta = rewriter.create<stablehlo::MulOp>(loc, outType, kBeta0, half);
|
||||||
rewriter.create<mhlo::MulOp>(loc, outType, kAlpha, adaptor.getSelf());
|
Value erfArg = rewriter.create<stablehlo::MulOp>(loc, outType, kAlpha,
|
||||||
|
adaptor.getSelf());
|
||||||
Value erf = rewriter.create<mlir::chlo::ErfOp>(loc, outType, erfArg);
|
Value erf = rewriter.create<mlir::chlo::ErfOp>(loc, outType, erfArg);
|
||||||
Value erfAdd = rewriter.create<mhlo::AddOp>(loc, outType, erf, one);
|
Value erfAdd = rewriter.create<stablehlo::AddOp>(loc, outType, erf, one);
|
||||||
Value cdf = rewriter.create<mhlo::MulOp>(loc, outType, erfAdd, half);
|
Value cdf = rewriter.create<stablehlo::MulOp>(loc, outType, erfAdd, half);
|
||||||
Value inputSquared = rewriter.create<mhlo::MulOp>(
|
Value inputSquared = rewriter.create<stablehlo::MulOp>(
|
||||||
loc, outType, adaptor.getSelf(), adaptor.getSelf());
|
loc, outType, adaptor.getSelf(), adaptor.getSelf());
|
||||||
Value negHalfInputSquared =
|
Value negHalfInputSquared =
|
||||||
rewriter.create<mhlo::MulOp>(loc, outType, inputSquared, negHalf);
|
rewriter.create<stablehlo::MulOp>(loc, outType, inputSquared, negHalf);
|
||||||
Value expRes =
|
Value expRes =
|
||||||
rewriter.create<mhlo::ExpOp>(loc, outType, negHalfInputSquared);
|
rewriter.create<stablehlo::ExpOp>(loc, outType, negHalfInputSquared);
|
||||||
Value pdf = rewriter.create<mhlo::MulOp>(loc, outType, kBeta, expRes);
|
Value pdf = rewriter.create<stablehlo::MulOp>(loc, outType, kBeta, expRes);
|
||||||
Value pdfTimesInput =
|
Value pdfTimesInput =
|
||||||
rewriter.create<mhlo::MulOp>(loc, outType, pdf, adaptor.getSelf());
|
rewriter.create<stablehlo::MulOp>(loc, outType, pdf, adaptor.getSelf());
|
||||||
Value pdfTimesInputAddCdf =
|
Value pdfTimesInputAddCdf =
|
||||||
rewriter.create<mhlo::AddOp>(loc, outType, pdfTimesInput, cdf);
|
rewriter.create<stablehlo::AddOp>(loc, outType, pdfTimesInput, cdf);
|
||||||
rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, outType, adaptor.getGradOutput(),
|
rewriter.replaceOpWithNewOp<stablehlo::MulOp>(
|
||||||
pdfTimesInputAddCdf);
|
op, outType, adaptor.getGradOutput(), pdfTimesInputAddCdf);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
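The cdf/pdf values above assemble the analytic GELU derivative. Assuming cstAlpha0 holds 2/sqrt(pi) (its definition lies outside this hunk), kBeta reduces to 1/sqrt(2*pi) and the final multiply computes gradOutput times (LaTeX sketch):

\frac{d}{dx}\,\mathrm{GELU}(x)
 \;=\; \Phi(x) + x\,\varphi(x)
 \;=\; \underbrace{\tfrac{1}{2}\bigl(1 + \operatorname{erf}(x/\sqrt{2})\bigr)}_{\text{cdf}}
 \;+\; x \cdot \underbrace{\tfrac{1}{\sqrt{2\pi}}\,e^{-x^{2}/2}}_{\text{pdf}}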
|
@ -1366,9 +1380,9 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
|
void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
||||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||||
ConversionTarget &target, const TorchToMhloOptions &options) {
|
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
||||||
MLIRContext *context = patterns.getContext();
|
MLIRContext *context = patterns.getContext();
|
||||||
|
|
||||||
target.addIllegalOp<AtenTransposeIntOp>();
|
target.addIllegalOp<AtenTransposeIntOp>();
|
||||||
|
@ -1376,23 +1390,24 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
|
||||||
target.addIllegalOp<RuntimeAssertOp>();
|
target.addIllegalOp<RuntimeAssertOp>();
|
||||||
patterns.add<ConvertRuntimeAssertOp>(typeConverter, context);
|
patterns.add<ConvertRuntimeAssertOp>(typeConverter, context);
|
||||||
|
|
||||||
#define INSERT_UNARY_PATTERN(AtenOp, MhloOp) \
|
#define INSERT_UNARY_PATTERN(AtenOp, StablehloOp) \
|
||||||
target.addIllegalOp<AtenOp>(); \
|
target.addIllegalOp<AtenOp>(); \
|
||||||
patterns.add<ConvertAtenUnaryOp<AtenOp, MhloOp>>(typeConverter, context)
|
patterns.add<ConvertAtenUnaryOp<AtenOp, StablehloOp>>(typeConverter, context)
|
||||||
INSERT_UNARY_PATTERN(AtenCloneOp, mhlo::CopyOp);
|
INSERT_UNARY_PATTERN(AtenCloneOp, stablehlo::ConvertOp);
|
||||||
INSERT_UNARY_PATTERN(AtenNegOp, mhlo::NegOp);
|
INSERT_UNARY_PATTERN(AtenNegOp, stablehlo::NegOp);
|
||||||
INSERT_UNARY_PATTERN(AtenLogicalNotOp, mhlo::NotOp);
|
INSERT_UNARY_PATTERN(AtenLogicalNotOp, stablehlo::NotOp);
|
||||||
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, mhlo::NotOp);
|
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, stablehlo::NotOp);
|
||||||
#undef INSERT_UNARY_PATTERN
|
#undef INSERT_UNARY_PATTERN
|
||||||
|
|
||||||
#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, MhloOp) \
|
#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, StablehloOp) \
|
||||||
target.addIllegalOp<AtenOp>(); \
|
target.addIllegalOp<AtenOp>(); \
|
||||||
patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, MhloOp>>(typeConverter, context)
|
patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, StablehloOp>>(typeConverter, \
|
||||||
INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, mhlo::LogOp);
|
context)
|
||||||
INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, mhlo::ExpOp);
|
INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, stablehlo::LogOp);
|
||||||
INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, mhlo::SqrtOp);
|
INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, stablehlo::ExpOp);
|
||||||
INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, mhlo::RsqrtOp);
|
INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, stablehlo::SqrtOp);
|
||||||
INSERT_UNARY_FPONLY_PATTERN(AtenSigmoidOp, mhlo::LogisticOp);
|
INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, stablehlo::RsqrtOp);
|
||||||
|
INSERT_UNARY_FPONLY_PATTERN(AtenSigmoidOp, stablehlo::LogisticOp);
|
||||||
#undef INSERT_UNARY_FPONLY_PATTERN
|
#undef INSERT_UNARY_FPONLY_PATTERN
|
||||||
|
|
||||||
#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
|
#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
|
||||||
|
@ -1482,10 +1497,10 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
|
||||||
INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
|
INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
|
||||||
#undef INSERT_ATENOP_PATTERN
|
#undef INSERT_ATENOP_PATTERN
|
||||||
|
|
||||||
#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, MhloOp) \
|
#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \
|
||||||
target.addIllegalOp<AtenOp>(); \
|
target.addIllegalOp<AtenOp>(); \
|
||||||
patterns.add<ConvertAtenBinaryBroadcastOp<AtenOp, MhloOp>>(typeConverter, \
|
patterns.add<ConvertAtenBinaryBroadcastOp<AtenOp, StablehloOp>>( \
|
||||||
context)
|
typeConverter, context)
|
||||||
INSERT_BINARY_BROADCAST_PATTERN(AtenMaximumOp, chlo::BroadcastMaxOp);
|
INSERT_BINARY_BROADCAST_PATTERN(AtenMaximumOp, chlo::BroadcastMaxOp);
|
||||||
INSERT_BINARY_BROADCAST_PATTERN(AtenMinimumOp, chlo::BroadcastMinOp);
|
INSERT_BINARY_BROADCAST_PATTERN(AtenMinimumOp, chlo::BroadcastMinOp);
|
||||||
INSERT_BINARY_BROADCAST_PATTERN(Aten__And__TensorOp, chlo::BroadcastAndOp);
|
INSERT_BINARY_BROADCAST_PATTERN(Aten__And__TensorOp, chlo::BroadcastAndOp);
@@ -0,0 +1,29 @@
+add_mlir_conversion_library(TorchMLIRTorchToStablehlo
+  TorchToStablehlo.cpp
+  StablehloLegalizeUtils.cpp
+  Basic.cpp
+  Gather.cpp
+  Linear.cpp
+  ViewLike.cpp
+  Reduction.cpp
+  Pooling.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToStablehlo
+
+  DEPENDS
+  TorchMLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRPass
+  MLIRBufferTransforms
+  StablehloOps
+  TorchMLIRTorchDialect
+  TorchMLIRConversionUtils
+)
+
+torch_mlir_target_includes(TorchMLIRTorchToStablehlo)

@@ -7,14 +7,15 @@
 //
 //===----------------------------------------------------------------------===//

-#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
+#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"

 #include "../PassDetail.h"
-#include "./MhloLegalizeUtils.h"
-#include "./PopulatePatterns.h"
-#include "mhlo/IR/hlo_ops.h"
+#include "PopulatePatterns.h"
+#include "StablehloLegalizeUtils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "stablehlo/dialect/StablehloOps.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

@@ -24,7 +25,7 @@
 using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
-using namespace mlir::torch::torch_to_mhlo;
+using namespace mlir::torch::torch_to_stablehlo;

 namespace {
Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
|
Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
|
||||||
|
@ -69,7 +70,7 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
|
||||||
SmallVector<int64_t, 4> startIndexMap(1, axis);
|
SmallVector<int64_t, 4> startIndexMap(1, axis);
|
||||||
// indexVecDim
|
// indexVecDim
|
||||||
int64_t indexVecDim = indicesRank;
|
int64_t indexVecDim = indicesRank;
|
||||||
auto dimsAttr = mhlo::GatherDimensionNumbersAttr::get(
|
auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get(
|
||||||
rewriter.getContext(),
|
rewriter.getContext(),
|
||||||
/*offsetDims=*/offsetDims,
|
/*offsetDims=*/offsetDims,
|
||||||
/*collapsedSliceDims=*/collapsedSliceDims,
|
/*collapsedSliceDims=*/collapsedSliceDims,
|
||||||
|
@ -91,17 +92,18 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
|
||||||
auto outputTy =
|
auto outputTy =
|
||||||
RankedTensorType::get(outputShape, inputRankTy.getElementType());
|
RankedTensorType::get(outputShape, inputRankTy.getElementType());
|
||||||
return rewriter
|
return rewriter
|
||||||
.create<mhlo::DynamicGatherOp>(loc, outputTy, input, indices,
|
.create<stablehlo::DynamicGatherOp>(loc, outputTy, input, indices,
|
||||||
sliceSizesTensor, dimsAttr)
|
sliceSizesTensor, dimsAttr)
|
||||||
.getResult();
|
.getResult();
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html
|
// Ref:
|
||||||
|
// https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html
|
||||||
// padding_idx (int, optional)
|
// padding_idx (int, optional)
|
||||||
// – If specified, the entries at padding_idx do not contribute to the gradient;
|
// – If specified, the entries at padding_idx do not contribute to the
|
||||||
// therefore, the embedding vector at padding_idx is not updated during training,
|
// gradient; therefore, the embedding vector at padding_idx is not updated
|
||||||
// i.e. it remains as a fixed “pad”.
|
// during training, i.e. it remains as a fixed “pad”.
|
||||||
// scale_grad_by_freq (boolean, optional)
|
// scale_grad_by_freq (boolean, optional)
|
||||||
// – If given, this will scale gradients by the inverse of frequency of the
|
// – If given, this will scale gradients by the inverse of frequency of the
|
||||||
// words in the mini-batch. Default False.
|
// words in the mini-batch. Default False.
|
||||||
|
@ -139,7 +141,7 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
|
||||||
|
|
||||||
Value output = gatherTensorAlongSingleAxis(
|
Value output = gatherTensorAlongSingleAxis(
|
||||||
rewriter, op, weight, adaptor.getIndices(), 0, options.dimSizeIndexBits);
|
rewriter, op, weight, adaptor.getIndices(), 0, options.dimSizeIndexBits);
|
||||||
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
|
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(
|
||||||
op, getTypeConverter()->convertType(op.getType()), output);
|
op, getTypeConverter()->convertType(op.getType()), output);
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
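Functionally, the gather emitted for AtenEmbeddingOp is a row lookup into the weight table. A hypothetical plain-C++ model of the result it produces (names invented for illustration, not part of the patch):

#include <cstdint>
#include <vector>

// out[b][s] = weight[indices[b][s]]: gather along axis 0 of the table,
// keeping the full embedding row (slice size 1 on axis 0, full extent on axis 1).
std::vector<std::vector<std::vector<float>>>
embeddingGather(const std::vector<std::vector<float>> &weight,
                const std::vector<std::vector<int64_t>> &indices) {
  std::vector<std::vector<std::vector<float>>> out;
  for (const auto &row : indices) {
    out.emplace_back();
    for (int64_t idx : row)
      out.back().push_back(weight[idx]);
  }
  return out;
}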
|
@ -161,7 +163,7 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
|
||||||
Value output = gatherTensorAlongSingleAxis(
|
Value output = gatherTensorAlongSingleAxis(
|
||||||
rewriter, op, self, adaptor.getIndex(), dim, options.dimSizeIndexBits);
|
rewriter, op, self, adaptor.getIndex(), dim, options.dimSizeIndexBits);
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
|
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(
|
||||||
op, getTypeConverter()->convertType(op.getType()), output);
|
op, getTypeConverter()->convertType(op.getType()), output);
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
|
@ -200,7 +202,7 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
|
||||||
|
|
||||||
auto options = getOptions();
|
auto options = getOptions();
|
||||||
auto indexShapeInfo =
|
auto indexShapeInfo =
|
||||||
mhlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits);
|
hlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits);
|
||||||
if (failed(indexShapeInfo)) {
|
if (failed(indexShapeInfo)) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "failed to get dim sizes of `index` param");
|
op, "failed to get dim sizes of `index` param");
|
||||||
|
@ -223,15 +225,15 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
|
||||||
SmallVector<Value> toConcat;
|
SmallVector<Value> toConcat;
|
||||||
for (int64_t i = 0; i < inputType.getRank(); ++i) {
|
for (int64_t i = 0; i < inputType.getRank(); ++i) {
|
||||||
if (i == dim) {
|
if (i == dim) {
|
||||||
toConcat.push_back(rewriter.create<mhlo::DynamicReshapeOp>(
|
toConcat.push_back(rewriter.create<stablehlo::DynamicReshapeOp>(
|
||||||
loc, toConcatIndexType, index, toConcatIndexShape));
|
loc, toConcatIndexType, index, toConcatIndexShape));
|
||||||
} else {
|
} else {
|
||||||
toConcat.push_back(rewriter.create<mhlo::DynamicIotaOp>(
|
toConcat.push_back(rewriter.create<stablehlo::DynamicIotaOp>(
|
||||||
loc, toConcatIndexType, toConcatIndexShape,
|
loc, toConcatIndexType, toConcatIndexShape,
|
||||||
rewriter.getI64IntegerAttr(i)));
|
rewriter.getI64IntegerAttr(i)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto gatherIndicies = rewriter.create<mhlo::ConcatenateOp>(
|
auto gatherIndicies = rewriter.create<stablehlo::ConcatenateOp>(
|
||||||
loc, toConcat, static_cast<uint64_t>(inputType.getRank()));
|
loc, toConcat, static_cast<uint64_t>(inputType.getRank()));
|
||||||
SmallVector<int64_t> sliceSizes(inputType.getRank(), 1);
|
SmallVector<int64_t> sliceSizes(inputType.getRank(), 1);
|
||||||
|
|
||||||
|
@ -243,22 +245,22 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
|
||||||
startIndexMap.push_back(i);
|
startIndexMap.push_back(i);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto dimsAttr = mhlo::GatherDimensionNumbersAttr::get(
|
auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get(
|
||||||
rewriter.getContext(),
|
rewriter.getContext(),
|
||||||
/*offsetDims=*/{},
|
/*offsetDims=*/{},
|
||||||
/*collapsedSliceDims=*/collapsedDims,
|
/*collapsedSliceDims=*/collapsedDims,
|
||||||
/*startIndexMap=*/startIndexMap,
|
/*startIndexMap=*/startIndexMap,
|
||||||
/*indexVecDim=*/indexVecDim);
|
/*indexVecDim=*/indexVecDim);
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<mhlo::GatherOp>(
|
rewriter.replaceOpWithNewOp<stablehlo::GatherOp>(
|
||||||
op, input, gatherIndicies, dimsAttr,
|
op, input, gatherIndicies, dimsAttr,
|
||||||
rewriter.getI64TensorAttr(sliceSizes));
|
rewriter.getI64TensorAttr(sliceSizes));
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||

-void mlir::torch::torch_to_mhlo::populateGatherOpPatternsAndLegality(
+void mlir::torch::torch_to_stablehlo::populateGatherOpPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, const TorchToMhloOptions &options) {
+    ConversionTarget &target, const TorchToStablehloOptions &options) {
   MLIRContext *context = patterns.getContext();

 #define INSERT_ATENOP_PATTERN(AtenOp) \

@@ -7,15 +7,16 @@
 //
 //===----------------------------------------------------------------------===//

-#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
+#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"

 #include "../PassDetail.h"
-#include "./MhloLegalizeUtils.h"
-#include "./PopulatePatterns.h"
-#include "mhlo/IR/hlo_ops.h"
+#include "PopulatePatterns.h"
+#include "StablehloLegalizeUtils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "stablehlo/dialect/ChloOps.h"
+#include "stablehlo/dialect/StablehloOps.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

@@ -25,7 +26,7 @@
 using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
-using namespace mlir::torch::torch_to_mhlo;
+using namespace mlir::torch::torch_to_stablehlo;

 namespace {
Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
|
Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
|
||||||
|
@ -33,7 +34,7 @@ Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
|
||||||
ArrayRef<int64_t> broadcastDims) {
|
ArrayRef<int64_t> broadcastDims) {
|
||||||
auto tensorTy = tensor.getType().dyn_cast<RankedTensorType>();
|
auto tensorTy = tensor.getType().dyn_cast<RankedTensorType>();
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value mhloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
|
Value stablehloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
|
||||||
|
|
||||||
RankedTensorType outTy =
|
RankedTensorType outTy =
|
||||||
RankedTensorType::get(shape, tensorTy.getElementType());
|
RankedTensorType::get(shape, tensorTy.getElementType());
|
||||||
|
@ -43,8 +44,8 @@ Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
|
||||||
rewriter.getIntegerType(64));
|
rewriter.getIntegerType(64));
|
||||||
auto broadcastAttr = DenseIntElementsAttr::get(attrTy, broadcastDims);
|
auto broadcastAttr = DenseIntElementsAttr::get(attrTy, broadcastDims);
|
||||||
|
|
||||||
auto broadcast = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
|
auto broadcast = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
|
||||||
loc, outTy, tensor, mhloShape, broadcastAttr);
|
loc, outTy, tensor, stablehloShape, broadcastAttr);
|
||||||
return broadcast;
|
return broadcast;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -52,7 +53,7 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input,
|
||||||
ArrayRef<int64_t> inpTransDims) {
|
ArrayRef<int64_t> inpTransDims) {
|
||||||
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
|
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||||
auto rank = inputTy.getRank();
|
auto rank = inputTy.getRank();
|
||||||
auto transDims = mhlo::toPositiveDims(inpTransDims, rank);
|
auto transDims = hlo::toPositiveDims(inpTransDims, rank);
|
||||||
auto inpShape = inputTy.getShape();
|
auto inpShape = inputTy.getShape();
|
||||||
std::vector<int64_t> newShape;
|
std::vector<int64_t> newShape;
|
||||||
newShape.reserve(rank);
|
newShape.reserve(rank);
|
||||||
|
@ -66,8 +67,8 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input,
|
||||||
auto permuteAttr = DenseIntElementsAttr::get(attrTy, transDims);
|
auto permuteAttr = DenseIntElementsAttr::get(attrTy, transDims);
|
||||||
|
|
||||||
auto outTy = RankedTensorType::get(newShape, inputTy.getElementType());
|
auto outTy = RankedTensorType::get(newShape, inputTy.getElementType());
|
||||||
auto result = rewriter.create<mhlo::TransposeOp>(op->getLoc(), outTy, input,
|
auto result = rewriter.create<stablehlo::TransposeOp>(op->getLoc(), outTy,
|
||||||
permuteAttr);
|
input, permuteAttr);
|
||||||
return result.getResult();
|
return result.getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -119,10 +120,12 @@ RankedTensorType castContractingDim(PatternRewriter &rewriter, Operation *op,
|
||||||
 }

   // set result dimensions
-  if (lhsResultDim < static_cast<int64_t>(lhsShape.size()) && lhsResultDim >= 0) {
+  if (lhsResultDim < static_cast<int64_t>(lhsShape.size()) &&
+      lhsResultDim >= 0) {
     outShape.push_back(lhsShape[lhsResultDim]);
   }
-  if (rhsResultDim < static_cast<int64_t>(rhsShape.size()) && rhsResultDim >= 0) {
+  if (rhsResultDim < static_cast<int64_t>(rhsShape.size()) &&
+      rhsResultDim >= 0) {
     outShape.push_back(rhsShape[rhsResultDim]);
   }
   return RankedTensorType::get(outShape, lhsTy.getElementType());
@@ -151,10 +154,10 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
     std::vector<int64_t> newShape(rhsShape.begin(),
                                   rhsShape.begin() + leadingRank);
     newShape.insert(newShape.end(), lhsShape.begin(), lhsShape.end());
-    auto newDimSizes = *mhlo::getDimSizesOfTensor(
-        rewriter, op, rhs, leadingDims, dimSizeIndexBits);
+    auto newDimSizes = *hlo::getDimSizesOfTensor(rewriter, op, rhs, leadingDims,
+                                                 dimSizeIndexBits);
     auto lhsDimSizes =
-        *mhlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits);
+        *hlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits);
     newDimSizes.insert(newDimSizes.end(), lhsDimSizes.begin(),
                        lhsDimSizes.end());
     lhs = getBroadcastTensor(rewriter, op, lhs, newShape, newDimSizes,
@@ -163,10 +166,10 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
     std::vector<int64_t> newShape(lhsShape.begin(),
                                   lhsShape.begin() + leadingRank);
     newShape.insert(newShape.end(), rhsShape.begin(), rhsShape.end());
-    auto newDimSizes = *mhlo::getDimSizesOfTensor(
-        rewriter, op, lhs, leadingDims, dimSizeIndexBits);
+    auto newDimSizes = *hlo::getDimSizesOfTensor(rewriter, op, lhs, leadingDims,
+                                                 dimSizeIndexBits);
     auto rhsDimSizes =
-        *mhlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits);
+        *hlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits);
     newDimSizes.insert(newDimSizes.end(), rhsDimSizes.begin(),
                        rhsDimSizes.end());
     rhs = getBroadcastTensor(rewriter, op, rhs, newShape, newDimSizes,
@@ -218,8 +221,8 @@ public:
     if (lhsRank <= 2 && rhsRank <= 2) {
       auto tensorType =
           ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType());
-      output = rewriter.create<mhlo::DotOp>(op->getLoc(), tensorType, lhs, rhs,
-                                            nullptr);
+      output = rewriter.create<stablehlo::DotOp>(op->getLoc(), tensorType, lhs,
+                                                 rhs, nullptr);
       return success();
     }

@@ -253,8 +256,8 @@ public:
       lhsContractingDim = nBatchDims;
     }

-    mhlo::DotDimensionNumbersAttr dotDimensionNumbers =
-        mhlo::DotDimensionNumbersAttr::get(
+    stablehlo::DotDimensionNumbersAttr dotDimensionNumbers =
+        stablehlo::DotDimensionNumbersAttr::get(
             rewriter.getContext(),
             /*lhsBatchingDimensions=*/batchDims,
             /*rhsBatchingDimensions=*/batchDims,
@@ -264,8 +267,8 @@ public:
         castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim,
                            lhsContractingDim, rhsContractingDim);
     output = rewriter
-                 .create<mhlo::DotGeneralOp>(op->getLoc(), outTy, lhs, rhs,
-                                             dotDimensionNumbers, nullptr)
+                 .create<stablehlo::DotGeneralOp>(op->getLoc(), outTy, lhs, rhs,
+                                                  dotDimensionNumbers, nullptr)
                  .getResult();
     return success();
   }
@@ -312,7 +315,7 @@ public:

     if (!lhsTy || !rhsTy)
       return op.emitError(
-          "only ranked tensor types are supported in MHLO matmul");
+          "only ranked tensor types are supported in StableHLO matmul");

     return success();
   }
@@ -335,7 +338,7 @@ public:

     if (!lhsTy || !rhsTy)
       return op.emitError(
-          "only ranked tensor types are supported in MHLO matmul");
+          "only ranked tensor types are supported in StableHLO matmul");

     auto lhsRank = lhsTy.getRank();
     auto rhsRank = rhsTy.getRank();
@@ -371,7 +374,7 @@ public:

     if (!lhsTy || !rhsTy)
       return op.emitError(
-          "only ranked tensor types are supported in MHLO matmul");
+          "only ranked tensor types are supported in StableHLO matmul");

     auto lhsRank = lhsTy.getRank();
     auto rhsRank = rhsTy.getRank();
@@ -398,10 +401,10 @@ public:
     auto bias = adaptor.getBias();
     auto biasTy = bias.getType();

-    // MHLO does not mandate that elementwise op tensors need to be ranked.
+    // StableHLO does not mandate that elementwise op tensors need to be ranked.
     if (!biasTy.template isa<Torch::NoneType>() &&
         !biasTy.template isa<RankedTensorType>())
-      return op.emitError("only ranked tensor types are supported in MHLO "
+      return op.emitError("only ranked tensor types are supported in StableHLO "
                           "matmul for bias tensor");

     // weight.T
@@ -427,14 +430,14 @@ public:
     auto outTy =
         castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim,
                            lhsContractingDim, rhsContractingDim);
-    mhlo::DotDimensionNumbersAttr dotDimensionNumbers =
-        mhlo::DotDimensionNumbersAttr::get(
+    stablehlo::DotDimensionNumbersAttr dotDimensionNumbers =
+        stablehlo::DotDimensionNumbersAttr::get(
             rewriter.getContext(),
             /*lhsBatchingDimensions=*/batchDims,
             /*rhsBatchingDimensions=*/batchDims,
             /*lhsContractingDimensions=*/{lhsContractingDim},
             /*rhsContractingDimensions=*/{rhsContractingDim});
-    Value matmulOutput = rewriter.create<mhlo::DotGeneralOp>(
+    Value matmulOutput = rewriter.create<stablehlo::DotGeneralOp>(
         op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr);

     Value matmulPlusBias = matmulOutput;
@@ -464,7 +467,7 @@ public:
     auto weightElemTy = weightTy.getElementType();
     auto rank = weightTy.getRank();
     const auto &options = getOptions();
-    SmallVector<Value> weightShapeVec = *mhlo::getDimSizesOfTensor(
+    SmallVector<Value> weightShapeVec = *hlo::getDimSizesOfTensor(
         rewriter, op, weight, options.dimSizeIndexBits);
     auto weightShape = weightTy.getShape();
     SmallVector<int64_t> weightShapeInt(rank);
@@ -488,7 +491,7 @@ public:
     }
     Value weightShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
         op->getLoc(), weightShapeVec);
-    weight = rewriter.create<mhlo::DynamicReshapeOp>(
+    weight = rewriter.create<stablehlo::DynamicReshapeOp>(
         op->getLoc(), RankedTensorType::get(weightShapeInt, weightElemTy),
         weight, weightShapeTensor);

@@ -497,7 +500,7 @@ public:
     for (int64_t i = 0; i <= rank; i++)
       transposeDims[i] = i;
     std::swap(transposeDims[1], transposeDims[0]);
-    weight = rewriter.create<mhlo::TransposeOp>(
+    weight = rewriter.create<stablehlo::TransposeOp>(
         op->getLoc(), weight, rewriter.getI64TensorAttr(transposeDims));

     // 3. [IC//G, G, OC, H, W, ...] => [IC//G, G*OC, H, W, ...]
@@ -509,7 +512,7 @@ public:
     weightShapeVec[1] = OCMulGValue;
     weightShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
         op->getLoc(), weightShapeVec);
-    weight = rewriter.create<mhlo::DynamicReshapeOp>(
+    weight = rewriter.create<stablehlo::DynamicReshapeOp>(
         op->getLoc(), RankedTensorType::get(weightShapeInt, weightElemTy),
         weight, weightShapeTensor);
     return weight;
@@ -544,25 +547,27 @@ public:
     }

     // Prepare for transposed convolution
-    SmallVector<int64_t> mhloStrideVec(nSpatialDims, 1);
-    DenseIntElementsAttr mhloStride = rewriter.getI64TensorAttr(mhloStrideVec);
-    SmallVector<int64_t> mhloPaddingVec(nSpatialDims * 2, 0);
+    SmallVector<int64_t> stablehloStrideVec(nSpatialDims, 1);
+    DenseIntElementsAttr stablehloStride =
+        rewriter.getI64TensorAttr(stablehloStrideVec);
+    SmallVector<int64_t> stablehloPaddingVec(nSpatialDims * 2, 0);
     for (int i = 0; i < nSpatialDims; ++i) {
       int64_t padInt = dilation[i] * (weightShape[i + 2] - 1) - padding[i];
-      mhloPaddingVec[i * 2] = padInt;
-      mhloPaddingVec[i * 2 + 1] = padInt;
+      stablehloPaddingVec[i * 2] = padInt;
+      stablehloPaddingVec[i * 2 + 1] = padInt;
     }
-    DenseIntElementsAttr mhloPadding = DenseIntElementsAttr::get(
+    DenseIntElementsAttr stablehloPadding = DenseIntElementsAttr::get(
         RankedTensorType::get({nSpatialDims, 2}, rewriter.getI64Type()),
-        mhloPaddingVec);
-    SmallVector<int64_t> mhloLhsDilationVec(nSpatialDims);
-    std::copy(stride.begin(), stride.end(), mhloLhsDilationVec.begin());
-    DenseIntElementsAttr mhloLhsDilation =
-        rewriter.getI64TensorAttr(mhloLhsDilationVec);
-    SmallVector<int64_t> mhloRhsDilationVec(nSpatialDims);
-    std::copy(dilation.begin(), dilation.end(), mhloRhsDilationVec.begin());
-    DenseIntElementsAttr mhloRhsDilation =
-        rewriter.getI64TensorAttr(mhloRhsDilationVec);
+        stablehloPaddingVec);
+    SmallVector<int64_t> stablehloLhsDilationVec(nSpatialDims);
+    std::copy(stride.begin(), stride.end(), stablehloLhsDilationVec.begin());
+    DenseIntElementsAttr stablehloLhsDilation =
+        rewriter.getI64TensorAttr(stablehloLhsDilationVec);
+    SmallVector<int64_t> stablehloRhsDilationVec(nSpatialDims);
+    std::copy(dilation.begin(), dilation.end(),
+              stablehloRhsDilationVec.begin());
+    DenseIntElementsAttr stablehloRhsDilation =
+        rewriter.getI64TensorAttr(stablehloRhsDilationVec);

     DenseElementsAttr windowReversal;
     ArrayAttr precisionConfig;
@@ -571,8 +576,8 @@ public:
     for (int i = 0; i < nSpatialDims; ++i) {
       spatialDims.push_back(i + 2);
     }
-    mhlo::ConvDimensionNumbersAttr dimensionNumbers =
-        mhlo::ConvDimensionNumbersAttr::get(
+    stablehlo::ConvDimensionNumbersAttr dimensionNumbers =
+        stablehlo::ConvDimensionNumbersAttr::get(
             /*context=*/rewriter.getContext(), /*inputBatchDimension=*/0,
             /*inputFeatureDimension=*/1,
             /*inputSpatialDimensions=*/spatialDims,
@@ -583,17 +588,18 @@ public:
             /*outputSpatialDimensions=*/spatialDims);

     // Reverse and transpose weight
-    weight = rewriter.create<mhlo::ReverseOp>(
+    weight = rewriter.create<stablehlo::ReverseOp>(
         op->getLoc(), weight, rewriter.getI64TensorAttr(spatialDims));
     if (groups != 1) {
       weight = reshapeConvWeight(rewriter, op, weight, groups);
     }

     // Create transposed convolution
-    auto transposedConvOp = rewriter.create<mhlo::ConvolutionOp>(
-        op->getLoc(), convOutTy, input, weight, mhloStride, mhloPadding,
-        mhloLhsDilation, mhloRhsDilation, windowReversal, dimensionNumbers,
-        static_cast<uint64_t>(groups), 1, precisionConfig);
+    auto transposedConvOp = rewriter.create<stablehlo::ConvolutionOp>(
+        op->getLoc(), convOutTy, input, weight, stablehloStride,
+        stablehloPadding, stablehloLhsDilation, stablehloRhsDilation,
+        windowReversal, dimensionNumbers, static_cast<uint64_t>(groups), 1,
+        precisionConfig);

     // Handle output padding
     if (!needHandleOutputPadding) {
@@ -605,8 +611,8 @@ public:
     std::copy(outputPadding.begin(), outputPadding.end(),
               edgePaddingHighVec.begin() + 2);
     Value paddingValue =
-        mhlo::getConstTensor<float>(rewriter, op, {0.0}, {}).value();
-    paddingValue = mhlo::promoteType(rewriter, paddingValue, inputTy);
+        hlo::getConstTensor<float>(rewriter, op, {0.0}, {}).value();
+    paddingValue = hlo::promoteType(rewriter, paddingValue, inputTy);
     mlir::DenseIntElementsAttr edgePaddingLow =
         rewriter.getI64VectorAttr(edgePaddingLowVec);
     mlir::DenseIntElementsAttr edgePaddingHigh =
@@ -614,7 +620,7 @@ public:
     mlir::DenseIntElementsAttr interiorPadding =
         rewriter.getI64VectorAttr(interiorPaddingVec);

-    auto paddedOutput = rewriter.create<mhlo::PadOp>(
+    auto paddedOutput = rewriter.create<stablehlo::PadOp>(
         op->getLoc(), outType, transposedConvOp, paddingValue, edgePaddingLow,
         edgePaddingHigh, interiorPadding);

@@ -628,22 +634,22 @@ public:
                          ArrayRef<int64_t> dilation, int64_t groups) const {
     int64_t nDims = outType.getRank();

-    // Get mhlo::ConvolutionOp attributes
-    DenseIntElementsAttr mhloWindowStride = DenseIntElementsAttr::get(
+    // Get stablehlo::ConvolutionOp attributes
+    DenseIntElementsAttr stablehloWindowStride = DenseIntElementsAttr::get(
         RankedTensorType::get({static_cast<long int>(stride.size())},
                               rewriter.getI64Type()),
         stride);
-    std::vector<int64_t> mhloPaddingVec;
+    std::vector<int64_t> stablehloPaddingVec;
     for (size_t i = 0; i < padding.size(); i++) {
-      mhloPaddingVec.emplace_back(padding[i]);
-      mhloPaddingVec.emplace_back(padding[i]);
+      stablehloPaddingVec.emplace_back(padding[i]);
+      stablehloPaddingVec.emplace_back(padding[i]);
     }
-    DenseIntElementsAttr mhloPadding = DenseIntElementsAttr::get(
+    DenseIntElementsAttr stablehloPadding = DenseIntElementsAttr::get(
         RankedTensorType::get(
             {static_cast<long int>(padding.size()), static_cast<long int>(2)},
             rewriter.getI64Type()),
-        mhloPaddingVec);
-    DenseIntElementsAttr mhloRhsDilation = DenseIntElementsAttr::get(
+        stablehloPaddingVec);
+    DenseIntElementsAttr stablehloRhsDilation = DenseIntElementsAttr::get(
         RankedTensorType::get({static_cast<long int>(dilation.size())},
                               rewriter.getI64Type()),
         dilation);
@@ -651,8 +657,8 @@ public:
     for (int64_t i = 2; i < nDims; i++) {
       spatialDimensions.emplace_back(i);
     }
-    mhlo::ConvDimensionNumbersAttr dimensionNumbers =
-        mhlo::ConvDimensionNumbersAttr::get(
+    stablehlo::ConvDimensionNumbersAttr dimensionNumbers =
+        stablehlo::ConvDimensionNumbersAttr::get(
             /*context=*/rewriter.getContext(), /*inputBatchDimension=*/0,
             /*inputFeatureDimension=*/1,
             /*inputSpatialDimensions=*/spatialDimensions,
@@ -662,17 +668,18 @@ public:
             /*outputBatchDimension=*/0, /*outputFeatureDimension=*/1,
             /*outputSpatialDimensions=*/spatialDimensions);

-    // mhlo::ConvolutionOp's optional attributes, leave them as default
-    DenseIntElementsAttr mhloLhsDilation;
+    // stablehlo::ConvolutionOp's optional attributes, leave them as default
+    DenseIntElementsAttr stablehloLhsDilation;
     DenseElementsAttr windowReversal;
     ArrayAttr precisionConfig;

-    auto mhloConvOp = rewriter.create<mhlo::ConvolutionOp>(
-        op->getLoc(), outType, input, weight, mhloWindowStride, mhloPadding,
-        mhloLhsDilation, mhloRhsDilation, windowReversal, dimensionNumbers,
-        static_cast<uint64_t>(groups), 1, precisionConfig);
+    auto stablehloConvOp = rewriter.create<stablehlo::ConvolutionOp>(
+        op->getLoc(), outType, input, weight, stablehloWindowStride,
+        stablehloPadding, stablehloLhsDilation, stablehloRhsDilation,
+        windowReversal, dimensionNumbers, static_cast<uint64_t>(groups), 1,
+        precisionConfig);

-    return mhloConvOp.getResult();
+    return stablehloConvOp.getResult();
   }

   LogicalResult
@@ -754,21 +761,22 @@ public:
       }
     }

-    Value mhloConvResult;
+    Value stablehloConvResult;
     if (transposed) {
-      mhloConvResult = convertTransposedConv(
+      stablehloConvResult = convertTransposedConv(
           op, rewriter, outTy, input, weight, stride, padding, dilation,
           outputPadding, groups, needHandleOutputPadding);
     } else {
-      mhloConvResult = convertNormalConv(op, rewriter, outTy, input, weight,
-                                         stride, padding, dilation, groups);
+      stablehloConvResult =
+          convertNormalConv(op, rewriter, outTy, input, weight, stride, padding,
+                            dilation, groups);
     }

     auto bias = adaptor.getBias();

     // No bias provided
     if (failed(checkNotNone(rewriter, op, op.getBias()))) {
-      rewriter.replaceOp(op, mhloConvResult);
+      rewriter.replaceOp(op, stablehloConvResult);
       return success();
     }

@@ -790,21 +798,21 @@ public:
         llvm::to_vector<4>(llvm::seq<int64_t>(-nSpatialDims, 0));

     const auto &options = getOptions();
-    bias = *mhlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims,
+    bias = *hlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims,
                                   options.dimSizeIndexBits);
-    bias = mhlo::promoteType(rewriter, bias, outTy);
+    bias = hlo::promoteType(rewriter, bias, outTy);

     DenseIntElementsAttr bcastDimensions;
-    rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(op, outTy, mhloConvResult,
-                                                      bias, bcastDimensions);
+    rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
+        op, outTy, stablehloConvResult, bias, bcastDimensions);
     return success();
   }
 };
 } // namespace

-void mlir::torch::torch_to_mhlo::populateLinearOpPatternsAndLegality(
+void mlir::torch::torch_to_stablehlo::populateLinearOpPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, const TorchToMhloOptions &options) {
+    ConversionTarget &target, const TorchToStablehloOptions &options) {
   MLIRContext *context = patterns.getContext();

 #define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \
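As a side note on the matmul patterns above: for a rank-3 batched multiply where the left operand is [B, M, K] and the right operand is [B, K, N], the rewritten code boils down to building one DotDimensionNumbersAttr and one stablehlo.dot_general. The snippet below is an illustrative sketch, not code from the patch; lhs, rhs, outTy, op, and rewriter are assumed to be in scope exactly as in the pattern bodies.

  // Illustrative only: batch dim 0; contract lhs dim 2 against rhs dim 1.
  SmallVector<int64_t> batchDims{0};
  auto dotDims = stablehlo::DotDimensionNumbersAttr::get(
      rewriter.getContext(),
      /*lhsBatchingDimensions=*/batchDims,
      /*rhsBatchingDimensions=*/batchDims,
      /*lhsContractingDimensions=*/{2},
      /*rhsContractingDimensions=*/{1});
  // precision_config stays unset (nullptr), matching the converted patterns.
  Value out = rewriter.create<stablehlo::DotGeneralOp>(
      op->getLoc(), outTy, lhs, rhs, dotDims, /*precision_config=*/nullptr);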
@@ -7,15 +7,16 @@
 //
 //===----------------------------------------------------------------------===//

-#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
+#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"

 #include "../PassDetail.h"
-#include "./MhloLegalizeUtils.h"
-#include "./PopulatePatterns.h"
-#include "mhlo/IR/hlo_ops.h"
+#include "PopulatePatterns.h"
+#include "StablehloLegalizeUtils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "stablehlo/dialect/ChloOps.h"
+#include "stablehlo/dialect/StablehloOps.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@@ -28,7 +29,7 @@
 using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
-using namespace mlir::torch::torch_to_mhlo;
+using namespace mlir::torch::torch_to_stablehlo;

 static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
                                                 PatternRewriter &rewriter) {
@@ -40,14 +41,14 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
         constType, {APFloat::getZero(
                        elementTy.cast<mlir::FloatType>().getFloatSemantics(),
                        /*negative=*/false)});
-    return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
-                                             constAttr);
+    return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
+                                                  constAttr);
   } else if (elementTy.isa<mlir::IntegerType>() &&
              elementTy.getIntOrFloatBitWidth() != 8) {
     auto constAttr = DenseElementsAttr::get(
         constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
-    return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
-                                             constAttr);
+    return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
+                                                  constAttr);
   }
 }

@@ -58,15 +59,15 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
         constType, {APFloat::getLargest(
                        elementTy.cast<mlir::FloatType>().getFloatSemantics(),
                        /*negative=*/true)});
-    return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
-                                             constAttr);
+    return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
+                                                  constAttr);
   } else if (elementTy.isa<mlir::IntegerType>() &&
              elementTy.getIntOrFloatBitWidth() != 8) {
     auto constAttr = DenseElementsAttr::get(
         constType,
         {APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())});
-    return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
-                                             constAttr);
+    return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
+                                                  constAttr);
   }
 }
 op->emitError("unimplemented lowering in AtenPoolingOp");
@@ -116,42 +117,43 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dOp>::matchAndRewrite(

   // prepend 1 to kernelSize, stride, dilation until they are of same rank as
   // input
-  SmallVector<int64_t> mhloStride(inputRank, 1);
-  SmallVector<int64_t> mhloDilation(inputRank, 1);
-  SmallVector<int64_t> mhloKernelSize(inputRank, 1);
-  SmallVector<int64_t> mhloPadding(inputRank * 2, 0);
+  SmallVector<int64_t> stablehloStride(inputRank, 1);
+  SmallVector<int64_t> stablehloDilation(inputRank, 1);
+  SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
+  SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
   std::copy(dilation.begin(), dilation.end(),
-            mhloDilation.begin() + inputRank - 2);
-  std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2);
+            stablehloDilation.begin() + inputRank - 2);
+  std::copy(stride.begin(), stride.end(),
+            stablehloStride.begin() + inputRank - 2);
   std::copy(kernelSize.begin(), kernelSize.end(),
-            mhloKernelSize.begin() + inputRank - 2);
+            stablehloKernelSize.begin() + inputRank - 2);

   Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);

-  mhloPadding[mhloPadding.size() - 4] = padding[0];
-  mhloPadding[mhloPadding.size() - 3] = padding[0];
-  mhloPadding[mhloPadding.size() - 2] = padding[1];
-  mhloPadding[mhloPadding.size() - 1] = padding[1];
+  stablehloPadding[stablehloPadding.size() - 4] = padding[0];
+  stablehloPadding[stablehloPadding.size() - 3] = padding[0];
+  stablehloPadding[stablehloPadding.size() - 2] = padding[1];
+  stablehloPadding[stablehloPadding.size() - 1] = padding[1];

   DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
-      RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())},
+      RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
                             rewriter.getI64Type()),
-      mhloKernelSize);
+      stablehloKernelSize);
   DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
-      RankedTensorType::get({static_cast<int64_t>(mhloStride.size())},
+      RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
                             rewriter.getI64Type()),
-      mhloStride);
+      stablehloStride);
   DenseIntElementsAttr baseDilations;
   DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
-      RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())},
+      RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
                             rewriter.getI64Type()),
-      mhloDilation);
+      stablehloDilation);
   DenseIntElementsAttr pad = DenseIntElementsAttr::get(
       RankedTensorType::get(
           {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
           rewriter.getI64Type()),
-      mhloPadding);
-  auto reduceWindowOp = rewriter.create<mhlo::ReduceWindowOp>(
+      stablehloPadding);
+  auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
       op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
       baseDilations, windowDilations, pad);

@@ -168,8 +170,8 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dOp>::matchAndRewrite(
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPointToStart(&block);
     Value result =
-        rewriter.create<mhlo::MaxOp>(op->getLoc(), *firstArg, *secondArg);
-    rewriter.create<mhlo::ReturnOp>(op->getLoc(), result);
+        rewriter.create<stablehlo::MaxOp>(op->getLoc(), *firstArg, *secondArg);
+    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), result);
   }

   rewriter.replaceOp(op, reduceWindowOp.getResults());
@@ -221,45 +223,46 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(

   // prepend 1 to kernelSize, stride, dilation until they are of same rank as
   // input
-  SmallVector<int64_t> mhloStride(inputRank, 1);
-  SmallVector<int64_t> mhloDilation(inputRank, 1);
-  SmallVector<int64_t> mhloKernelSize(inputRank, 1);
-  SmallVector<int64_t> mhloPadding(inputRank * 2, 0);
+  SmallVector<int64_t> stablehloStride(inputRank, 1);
+  SmallVector<int64_t> stablehloDilation(inputRank, 1);
+  SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
+  SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
   std::copy(dilation.begin(), dilation.end(),
-            mhloDilation.begin() + inputRank - 2);
-  std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2);
+            stablehloDilation.begin() + inputRank - 2);
+  std::copy(stride.begin(), stride.end(),
+            stablehloStride.begin() + inputRank - 2);
   std::copy(kernelSize.begin(), kernelSize.end(),
-            mhloKernelSize.begin() + inputRank - 2);
+            stablehloKernelSize.begin() + inputRank - 2);

   Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);

-  mhloPadding[mhloPadding.size() - 4] = padding[0];
-  mhloPadding[mhloPadding.size() - 3] = padding[0];
-  mhloPadding[mhloPadding.size() - 2] = padding[1];
-  mhloPadding[mhloPadding.size() - 1] = padding[1];
+  stablehloPadding[stablehloPadding.size() - 4] = padding[0];
+  stablehloPadding[stablehloPadding.size() - 3] = padding[0];
+  stablehloPadding[stablehloPadding.size() - 2] = padding[1];
+  stablehloPadding[stablehloPadding.size() - 1] = padding[1];

   DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
-      RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())},
+      RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
                             rewriter.getI64Type()),
-      mhloKernelSize);
+      stablehloKernelSize);
   DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
-      RankedTensorType::get({static_cast<int64_t>(mhloStride.size())},
+      RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
                             rewriter.getI64Type()),
-      mhloStride);
+      stablehloStride);
   DenseIntElementsAttr baseDilations;
   DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
-      RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())},
+      RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
                             rewriter.getI64Type()),
-      mhloDilation);
+      stablehloDilation);
   DenseIntElementsAttr pad = DenseIntElementsAttr::get(
       RankedTensorType::get(
           {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
           rewriter.getI64Type()),
-      mhloPadding);
+      stablehloPadding);

   const auto &options = getOptions();
   auto inputShapeInfo =
-      mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
+      hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
   if (failed(inputShapeInfo)) {
     return rewriter.notifyMatchFailure(
         op, "failed to get dimension sizes of the input");
@@ -289,7 +292,7 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(

   auto initIndexTensor =
       rewriter
-          .create<mhlo::DynamicIotaOp>(
+          .create<stablehlo::DynamicIotaOp>(
               op->getLoc(),
               RankedTensorType::get(initIndexShapeForType,
                                     rewriter.getI64Type()),
@@ -298,15 +301,15 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(

   auto indexTensor =
       rewriter
-          .create<mhlo::DynamicReshapeOp>(
+          .create<stablehlo::DynamicReshapeOp>(
               op->getLoc(),
               RankedTensorType::get(inputShape, rewriter.getI64Type()),
               initIndexTensor, inputShapeTensor)
           .getResult();

-  Value initIdx = mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
+  Value initIdx = hlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();

-  auto reduceWindowOp = rewriter.create<mhlo::ReduceWindowOp>(
+  auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
       op->getLoc(), mlir::TypeRange{outValTy, outIdxTy},
       mlir::ValueRange{input, indexTensor}, mlir::ValueRange{initVal, initIdx},
       windowDimensions, windowStrides, baseDilations, windowDilations, pad);
@@ -326,43 +329,43 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
   auto *secondValArg = std::next(firstIdxArg);
   auto *secondIdxArg = std::next(secondValArg);

-  mhlo::ComparisonTypeAttr compareTypeAttr;
+  stablehlo::ComparisonTypeAttr compareTypeAttr;
   if (inputTy.getElementType().isa<mlir::FloatType>()) {
-    compareTypeAttr = mhlo::ComparisonTypeAttr::get(
-        rewriter.getContext(), mhlo::ComparisonType::FLOAT);
+    compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
+        rewriter.getContext(), stablehlo::ComparisonType::FLOAT);
   } else if (inputTy.getElementType().isa<mlir::IntegerType>()) {
-    compareTypeAttr = mhlo::ComparisonTypeAttr::get(
-        rewriter.getContext(), mhlo::ComparisonType::SIGNED);
+    compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
+        rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
   }
-  mhlo::ComparisonDirectionAttr compareGeDirectionAttr =
-      mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
-                                         mhlo::ComparisonDirection::GE);
-  mhlo::ComparisonDirectionAttr compareEqDirectionAttr =
-      mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
-                                         mhlo::ComparisonDirection::EQ);
+  stablehlo::ComparisonDirectionAttr compareGeDirectionAttr =
+      stablehlo::ComparisonDirectionAttr::get(
+          rewriter.getContext(), stablehlo::ComparisonDirection::GE);
+  stablehlo::ComparisonDirectionAttr compareEqDirectionAttr =
+      stablehlo::ComparisonDirectionAttr::get(
+          rewriter.getContext(), stablehlo::ComparisonDirection::EQ);

   {
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPointToStart(&block);

-    Value compareGeResult = rewriter.create<mhlo::CompareOp>(
+    Value compareGeResult = rewriter.create<stablehlo::CompareOp>(
         op->getLoc(), compareResultType, *firstValArg, *secondValArg,
         compareGeDirectionAttr, compareTypeAttr);
-    Value retValResult = rewriter.create<mhlo::SelectOp>(
+    Value retValResult = rewriter.create<stablehlo::SelectOp>(
         op->getLoc(), compareGeResult, *firstValArg, *secondValArg);

     // Get smaller index if compared values are equal.
-    Value compareEqResult = rewriter.create<mhlo::CompareOp>(
+    Value compareEqResult = rewriter.create<stablehlo::CompareOp>(
         op->getLoc(), compareResultType, *firstValArg, *secondValArg,
         compareEqDirectionAttr, compareTypeAttr);
-    Value minIdx =
-        rewriter.create<mhlo::MinOp>(op->getLoc(), *firstIdxArg, *secondIdxArg);
+    Value minIdx = rewriter.create<stablehlo::MinOp>(op->getLoc(), *firstIdxArg,
+                                                     *secondIdxArg);
-    Value idxWithGeVal = rewriter.create<mhlo::SelectOp>(
+    Value idxWithGeVal = rewriter.create<stablehlo::SelectOp>(
         op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg);
-    Value retIdxResult = rewriter.create<mhlo::SelectOp>(
+    Value retIdxResult = rewriter.create<stablehlo::SelectOp>(
         op->getLoc(), compareEqResult, minIdx, idxWithGeVal);

-    rewriter.create<mhlo::ReturnOp>(
+    rewriter.create<stablehlo::ReturnOp>(
         op->getLoc(), mlir::ValueRange{retValResult, retIdxResult});
   }

@@ -419,41 +422,42 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(

   // prepend 1 to kernelSize, stride, dilation until they are of same rank as
   // input
-  SmallVector<int64_t> mhloStride(inputRank, 1);
-  SmallVector<int64_t> mhloDilation(inputRank, 1);
-  SmallVector<int64_t> mhloKernelSize(inputRank, 1);
-  SmallVector<int64_t> mhloPadding(inputRank * 2, 0);
+  SmallVector<int64_t> stablehloStride(inputRank, 1);
+  SmallVector<int64_t> stablehloDilation(inputRank, 1);
+  SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
+  SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);

-  std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2);
+  std::copy(stride.begin(), stride.end(),
+            stablehloStride.begin() + inputRank - 2);
   std::copy(kernelSize.begin(), kernelSize.end(),
-            mhloKernelSize.begin() + inputRank - 2);
-  mhloPadding[mhloPadding.size() - 4] = padding[0];
-  mhloPadding[mhloPadding.size() - 3] = padding[0];
-  mhloPadding[mhloPadding.size() - 2] = padding[1];
-  mhloPadding[mhloPadding.size() - 1] = padding[1];
+            stablehloKernelSize.begin() + inputRank - 2);
+  stablehloPadding[stablehloPadding.size() - 4] = padding[0];
+  stablehloPadding[stablehloPadding.size() - 3] = padding[0];
+  stablehloPadding[stablehloPadding.size() - 2] = padding[1];
+  stablehloPadding[stablehloPadding.size() - 1] = padding[1];

   Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);

   DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
-      RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())},
+      RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
                             rewriter.getI64Type()),
-      mhloKernelSize);
+      stablehloKernelSize);
   DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
-      RankedTensorType::get({static_cast<int64_t>(mhloStride.size())},
+      RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
                             rewriter.getI64Type()),
-      mhloStride);
+      stablehloStride);
   DenseIntElementsAttr baseDilations;
   DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
-      RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())},
+      RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
                             rewriter.getI64Type()),
-      mhloDilation);
+      stablehloDilation);
   DenseIntElementsAttr pad = DenseIntElementsAttr::get(
       RankedTensorType::get(
           {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
           rewriter.getI64Type()),
-      mhloPadding);
+      stablehloPadding);

-  auto reduceWindowSum = rewriter.create<mhlo::ReduceWindowOp>(
+  auto reduceWindowSum = rewriter.create<stablehlo::ReduceWindowOp>(
       op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
       baseDilations, windowDilations, pad);

@@ -471,39 +475,39 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
     rewriter.setInsertionPointToStart(&sumBlock);

     Value sumResult =
-        rewriter.create<mhlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
-    rewriter.create<mhlo::ReturnOp>(op->getLoc(), sumResult);
+        rewriter.create<stablehlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
+    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult);
   }

   // Use kernel size as the divisor
   if (countIncludePad) {
-    Value divisor = mhlo::getConstTensor<int64_t>(
+    Value divisor = hlo::getConstTensor<int64_t>(
                         rewriter, op, {kernelSize[0] * kernelSize[1]}, {})
                         .value();
-    divisor = mhlo::promoteType(rewriter, divisor, outTy);
+    divisor = hlo::promoteType(rewriter, divisor, outTy);
     DenseIntElementsAttr bcastDimensions;
     rewriter.replaceOpWithNewOp<mlir::chlo::BroadcastDivOp>(
         op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions);
     return success();
   }

-  // Use another mhlo.ReduceWindowOp to get the divisor
+  // Use another stablehlo.ReduceWindowOp to get the divisor
   Value windowSizeConst =
-      mhlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
-  windowSizeConst = mhlo::promoteType(rewriter, windowSizeConst, outTy);
+      hlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
+  windowSizeConst = hlo::promoteType(rewriter, windowSizeConst, outTy);
   const auto &options = getOptions();
   auto inputShapeVec =
-      *mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
+      *hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
   auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
       op->getLoc(), inputShapeVec);

-  windowSizeConst = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
+  windowSizeConst = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
       op->getLoc(),
       RankedTensorType::get(inputTy.getShape(), outTy.getElementType()),
       windowSizeConst, inputShapeTensor, rewriter.getI64TensorAttr({}));

   Value zero = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
-  auto reduceWindowSize = rewriter.create<mhlo::ReduceWindowOp>(
+  auto reduceWindowSize = rewriter.create<stablehlo::ReduceWindowOp>(
       op->getLoc(), RankedTensorType::get(outShape, inputElemTy),
       windowSizeConst, zero, windowDimensions, windowStrides, baseDilations,
       windowDilations, pad);
@@ -522,11 +526,11 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
     rewriter.setInsertionPointToStart(&sizeBlock);

     Value sumResult =
-        rewriter.create<mhlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
-    rewriter.create<mhlo::ReturnOp>(op->getLoc(), sumResult);
+        rewriter.create<stablehlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
+    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult);
   }

-  rewriter.replaceOpWithNewOp<mhlo::DivOp>(
+  rewriter.replaceOpWithNewOp<stablehlo::DivOp>(
       op, outTy, reduceWindowSum.getResult(0), reduceWindowSize.getResult(0));
   return success();
 }
@@ -560,33 +564,33 @@ LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite(

   Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);

-  SmallVector<int64_t> mhloKernelSize(inputRank, 1);
-  mhloKernelSize[dim] = inputShape[dim];
-  SmallVector<int64_t> mhloStride(inputRank, 1);
-  SmallVector<int64_t> mhloDilation(inputRank, 1);
-  SmallVector<int64_t> mhloPadding(inputRank * 2, 0);
-  mhloPadding[dim * 2] = inputShape[dim] - 1;
+  SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
+  stablehloKernelSize[dim] = inputShape[dim];
+  SmallVector<int64_t> stablehloStride(inputRank, 1);
+  SmallVector<int64_t> stablehloDilation(inputRank, 1);
+  SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
+  stablehloPadding[dim * 2] = inputShape[dim] - 1;

   DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
-      RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())},
+      RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
                             rewriter.getI64Type()),
-      mhloKernelSize);
+      stablehloKernelSize);
   DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
-      RankedTensorType::get({static_cast<int64_t>(mhloStride.size())},
+      RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
                             rewriter.getI64Type()),
-      mhloStride);
+      stablehloStride);
   DenseIntElementsAttr baseDilations;
   DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
-      RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())},
+      RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
                             rewriter.getI64Type()),
-      mhloDilation);
+      stablehloDilation);
   DenseIntElementsAttr pad = DenseIntElementsAttr::get(
       RankedTensorType::get(
           {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
           rewriter.getI64Type()),
-      mhloPadding);
+      stablehloPadding);

-  auto reduceWindowSum = rewriter.create<mhlo::ReduceWindowOp>(
+  auto reduceWindowSum = rewriter.create<stablehlo::ReduceWindowOp>(
       op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
       baseDilations, windowDilations, pad);

@@ -604,17 +608,17 @@ LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite(
     rewriter.setInsertionPointToStart(&sumBlock);

     Value sumResult =
-        rewriter.create<mhlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
-    rewriter.create<mhlo::ReturnOp>(op->getLoc(), sumResult);
+        rewriter.create<stablehlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
+    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult);
   }

   rewriter.replaceOp(op, reduceWindowSum.getResults());
   return success();
 }

-void mlir::torch::torch_to_mhlo::populatePoolingOpPatternsAndLegality(
+void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, const TorchToMhloOptions &options) {
+    ConversionTarget &target, const TorchToStablehloOptions &options) {
   MLIRContext *context = patterns.getContext();
   target.addIllegalOp<AtenMaxPool2dOp>();
   patterns.add<ConvertAtenOp<AtenMaxPool2dOp>>(typeConverter, context, options);
@@ -0,0 +1,69 @@
+//===------------------------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TORCHMLIR_LIB_CONVERSION_TORCHTOSTABLEHLO_POPULATEPATTERNS_H
+#define TORCHMLIR_LIB_CONVERSION_TORCHTOSTABLEHLO_POPULATEPATTERNS_H
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace torch {
+namespace torch_to_stablehlo {
+
+struct TorchToStablehloOptions {
+  bool enableStaticShape = false;
+  size_t dimSizeIndexBits = 64;
+};
+
+template <typename AtenOpT>
+class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  ConvertAtenOp(TypeConverter &typeConverter, MLIRContext *context,
+                const TorchToStablehloOptions &options)
+      : OpConversionPattern<AtenOpT>(typeConverter, context) {
+    this->options = options;
+  }
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    return rewriter.notifyMatchFailure(op, "haven't been implemented");
+  }
+  const TorchToStablehloOptions &getOptions() const { return options; }
+
+private:
+  TorchToStablehloOptions options;
+};
+
+void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter,
+                                        RewritePatternSet &patterns,
+                                        ConversionTarget &target,
+                                        const TorchToStablehloOptions &options);
+void populateViewLikeOpPatternsAndLegality(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target, const TorchToStablehloOptions &options);
+void populateGatherOpPatternsAndLegality(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target, const TorchToStablehloOptions &options);
+void populateReductionOpPatternsAndLegality(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target, const TorchToStablehloOptions &options);
+void populateLinearOpPatternsAndLegality(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target, const TorchToStablehloOptions &options);
+
+void populatePoolingOpPatternsAndLegality(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target, const TorchToStablehloOptions &options);
+
+} // namespace torch_to_stablehlo
+} // namespace torch
+} // namespace mlir
+
+#endif // TORCHMLIR_LIB_CONVERSION_TORCHTOSTABLEHLO_POPULATEPATTERNS_H
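The header above only declares the shared options struct and the per-category populate hooks; the conversion pass that owns the TypeConverter and ConversionTarget is expected to call them one after another. The snippet below is a simplified sketch of that wiring, not the literal pass implementation from this patch; it assumes `context`, `typeConverter`, and `module` are provided by the surrounding pass.

  // Illustrative driver: one options struct shared by every pattern group.
  torch_to_stablehlo::TorchToStablehloOptions options;
  options.enableStaticShape = false;
  options.dimSizeIndexBits = 64;

  RewritePatternSet patterns(context);
  ConversionTarget target(*context);
  torch_to_stablehlo::populateLinearOpPatternsAndLegality(typeConverter,
                                                          patterns, target,
                                                          options);
  torch_to_stablehlo::populatePoolingOpPatternsAndLegality(typeConverter,
                                                           patterns, target,
                                                           options);
  // Apply the registered patterns; ops left illegal cause the pass to fail.
  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    signalPassFailure();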
@@ -7,14 +7,15 @@
//
//===----------------------------------------------------------------------===//

-#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
+#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"

#include "../PassDetail.h"
-#include "./MhloLegalizeUtils.h"
+#include "PopulatePatterns.h"
-#include "./PopulatePatterns.h"
+#include "StablehloLegalizeUtils.h"
-#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@@ -25,7 +26,7 @@
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
-using namespace mlir::torch::torch_to_mhlo;
+using namespace mlir::torch::torch_to_stablehlo;

static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
                                           PatternRewriter &rewriter) {
@@ -36,14 +37,14 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
          constType, {APFloat::getZero(
                          elementTy.cast<mlir::FloatType>().getFloatSemantics(),
                          /*negative=*/false)});
-      return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
+      return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
                                                    constAttr);
    } else if (elementTy.isa<mlir::IntegerType>() &&
               elementTy.getIntOrFloatBitWidth() != 8) {
      auto constAttr = DenseElementsAttr::get(
          constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
-      return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
+      return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
                                                    constAttr);
    }
  }
@@ -53,15 +54,15 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
          constType, {APFloat::getLargest(
                          elementTy.cast<mlir::FloatType>().getFloatSemantics(),
                          /*negative=*/true)});
-      return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
+      return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
                                                    constAttr);
    } else if (elementTy.isa<mlir::IntegerType>() &&
               elementTy.getIntOrFloatBitWidth() != 8) {
      auto constAttr = DenseElementsAttr::get(
          constType,
          {APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())});
-      return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
+      return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
                                                    constAttr);
    }
  }
@@ -90,9 +91,9 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
    return std::nullopt;
  Value initIndex;
  if (dimSizeIndexBits == 32) {
-    initIndex = mhlo::getConstTensor<int32_t>(rewriter, op, {0}, {}).value();
+    initIndex = hlo::getConstTensor<int32_t>(rewriter, op, {0}, {}).value();
  } else {
-    initIndex = mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
+    initIndex = hlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
  }

  DenseIntElementsAttr dimensions = DenseIntElementsAttr::get(
@@ -100,13 +101,13 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,

  auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
      op->getLoc(), inputShapeVec);
-  auto indexTensor = rewriter.create<mhlo::DynamicIotaOp>(
+  auto indexTensor = rewriter.create<stablehlo::DynamicIotaOp>(
      op->getLoc(),
      RankedTensorType::get(inputShape,
                            rewriter.getIntegerType(dimSizeIndexBits)),
      inputShapeTensor, static_cast<uint64_t>(dim));

-  auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
+  auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
      op->getLoc(), ValueRange{input, indexTensor},
      ValueRange{
          initValue,
@@ -114,7 +115,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
      },
      dimensions);

-  Block &block = mhloReduceOp.getBody().emplaceBlock();
+  Block &block = stablehloReduceOp.getBody().emplaceBlock();

  // Add block arguments
  auto blockValArgumentType =
@@ -133,46 +134,46 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
  auto *secondValArg = std::next(firstIdxArg);
  auto *secondIdxArg = std::next(secondValArg);

-  mhlo::ComparisonTypeAttr compareTypeAttr;
+  stablehlo::ComparisonTypeAttr compareTypeAttr;
  if (inputTy.getElementType().isa<mlir::FloatType>()) {
-    compareTypeAttr = mhlo::ComparisonTypeAttr::get(
-        rewriter.getContext(), mhlo::ComparisonType::FLOAT);
+    compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
+        rewriter.getContext(), stablehlo::ComparisonType::FLOAT);
  } else if (inputTy.getElementType().isa<mlir::IntegerType>()) {
-    compareTypeAttr = mhlo::ComparisonTypeAttr::get(
-        rewriter.getContext(), mhlo::ComparisonType::SIGNED);
+    compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
+        rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
  }
-  mhlo::ComparisonDirectionAttr compareGeDirectionAttr =
-      mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
-                                         mhlo::ComparisonDirection::GE);
-  mhlo::ComparisonDirectionAttr compareEqDirectionAttr =
-      mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
-                                         mhlo::ComparisonDirection::EQ);
+  stablehlo::ComparisonDirectionAttr compareGeDirectionAttr =
+      stablehlo::ComparisonDirectionAttr::get(
+          rewriter.getContext(), stablehlo::ComparisonDirection::GE);
+  stablehlo::ComparisonDirectionAttr compareEqDirectionAttr =
+      stablehlo::ComparisonDirectionAttr::get(
+          rewriter.getContext(), stablehlo::ComparisonDirection::EQ);

  {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&block);

-    Value compareGeResult = rewriter.create<mhlo::CompareOp>(
+    Value compareGeResult = rewriter.create<stablehlo::CompareOp>(
        op->getLoc(), compareResultType, *firstValArg, *secondValArg,
        compareGeDirectionAttr, compareTypeAttr);
-    Value retValResult = rewriter.create<mhlo::SelectOp>(
+    Value retValResult = rewriter.create<stablehlo::SelectOp>(
        op->getLoc(), compareGeResult, *firstValArg, *secondValArg);

    // get smaller index value if compared nums are equal.
-    Value compareEqResult = rewriter.create<mhlo::CompareOp>(
+    Value compareEqResult = rewriter.create<stablehlo::CompareOp>(
        op->getLoc(), compareResultType, *firstValArg, *secondValArg,
        compareEqDirectionAttr, compareTypeAttr);
-    Value minIdx =
-        rewriter.create<mhlo::MinOp>(op->getLoc(), *firstIdxArg, *secondIdxArg);
+    Value minIdx = rewriter.create<stablehlo::MinOp>(op->getLoc(), *firstIdxArg,
+                                                     *secondIdxArg);
-    Value idxWithGeVal = rewriter.create<mhlo::SelectOp>(
+    Value idxWithGeVal = rewriter.create<stablehlo::SelectOp>(
        op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg);
-    Value retIdxResult = rewriter.create<mhlo::SelectOp>(
+    Value retIdxResult = rewriter.create<stablehlo::SelectOp>(
        op->getLoc(), compareEqResult, minIdx, idxWithGeVal);

-    rewriter.create<mhlo::ReturnOp>(
+    rewriter.create<stablehlo::ReturnOp>(
        op->getLoc(), mlir::ValueRange{retValResult, retIdxResult});
  }
-  return mhloReduceOp.getResults();
+  return stablehloReduceOp.getResults();
}

namespace {
@@ -196,7 +197,8 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
  Value input = adaptor.getSelf();
  auto inputTy = input.getType().template cast<RankedTensorType>();
  if (!inputTy) {
-    return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
+    return rewriter.notifyMatchFailure(
+        op, "only Tensor types supported in StableHLO");
  }

  auto inputElemTy = inputTy.getElementType();
@@ -209,7 +211,7 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
      inputElemTy.getIntOrFloatBitWidth() == 8) {
    return rewriter.notifyMatchFailure(
        op, "IntegerType with bitwidth 8 unsupported in convertion from "
-            "AtenArgmaxOp to MHLO");
+            "AtenArgmaxOp to StableHLO");
  }

  int64_t dim;
@@ -228,15 +230,15 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(

  const auto &options = getOptions();
  auto inputShapeInfo =
-      mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
+      hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
  if (failed(inputShapeInfo)) {
    return rewriter.notifyMatchFailure(
        op, "failed to get dimension sizes of the input");
  }
  auto inputShapeVec = *inputShapeInfo;
-  auto mhloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, dim,
-                                       options.dimSizeIndexBits)
+  auto stablehloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec,
+                                            dim, options.dimSizeIndexBits)
                                 .value();

  if (keepDim) {
    auto outShapeVec = inputShapeVec;
@@ -247,13 +249,13 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(

    auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
        op->getLoc(), outShapeVec);
-    rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
-        op, typeConverter->convertType(op.getType()), mhloReduceResults[1],
+    rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
+        op, typeConverter->convertType(op.getType()), stablehloReduceResults[1],
        outShapeTensor);
    return success();
  }

-  rewriter.replaceOp(op, mhloReduceResults[1]);
+  rewriter.replaceOp(op, stablehloReduceResults[1]);
  return success();
}
} // namespace
@@ -267,7 +269,8 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
  Value input = adaptor.getSelf();
  auto inputTy = input.getType().template dyn_cast<RankedTensorType>();
  if (!inputTy) {
-    return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
+    return rewriter.notifyMatchFailure(
+        op, "only Tensor types supported in StableHLO");
  }
  auto inputElemTy = inputTy.getElementType();
  if (!inputElemTy.isIntOrFloat()) {
@@ -279,7 +282,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
      inputElemTy.getIntOrFloatBitWidth() == 8) {
    return rewriter.notifyMatchFailure(
        op, "IntegerType with bitwidth 8 unsupported in convertion from "
-            "AtenMaxDimOp to MHLO");
+            "AtenMaxDimOp to StableHLO");
  }

  RankedTensorType valResultType = getTypeConverter()
@@ -308,15 +311,15 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(

  const auto &options = getOptions();
  auto inputShapeInfo =
-      mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
+      hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
  if (failed(inputShapeInfo)) {
    return rewriter.notifyMatchFailure(
        op, "failed to get dimension sizes of the input");
  }
  auto inputShapeVec = *inputShapeInfo;
-  auto mhloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, dim,
-                                       options.dimSizeIndexBits)
+  auto stablehloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec,
+                                            dim, options.dimSizeIndexBits)
                                 .value();

  if (keepDim) {
    auto outShapeVec = inputShapeVec;
@@ -327,15 +330,21 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
    auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
        op->getLoc(), outShapeVec);

-    auto mhloReduceValueResult = rewriter.create<mhlo::DynamicReshapeOp>(
-        op->getLoc(), valResultType, mhloReduceResults[0], outShapeTensor);
-    auto mhloReduceIndexResult = rewriter.create<mhlo::DynamicReshapeOp>(
-        op->getLoc(), idxResultType, mhloReduceResults[1], outShapeTensor);
-    rewriter.replaceOp(op, {mhloReduceValueResult, mhloReduceIndexResult});
+    auto stablehloReduceValueResult =
+        rewriter.create<stablehlo::DynamicReshapeOp>(
+            op->getLoc(), valResultType, stablehloReduceResults[0],
+            outShapeTensor);
+    auto stablehloReduceIndexResult =
+        rewriter.create<stablehlo::DynamicReshapeOp>(
+            op->getLoc(), idxResultType, stablehloReduceResults[1],
+            outShapeTensor);
+    rewriter.replaceOp(
+        op, {stablehloReduceValueResult, stablehloReduceIndexResult});
    return success();
  }

-  rewriter.replaceOp(op, {mhloReduceResults[0], mhloReduceResults[1]});
+  rewriter.replaceOp(op,
+                     {stablehloReduceResults[0], stablehloReduceResults[1]});
  return success();
}
} // namespace
@@ -352,12 +361,14 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
                   ->convertType(op.getType())
                   .template dyn_cast<RankedTensorType>();
  if (!inputTy) {
-    return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
+    return rewriter.notifyMatchFailure(
+        op, "only Tensor types supported in StableHLO");
  }
  if (inputTy.getElementType() != outTy.getElementType()) {
    // Use output element type as computation type.
    auto dstElemTy = outTy.getElementType();
-    input = rewriter.create<mhlo::ConvertOp>(op->getLoc(), input, dstElemTy);
+    input =
+        rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
    inputTy = input.getType().dyn_cast<RankedTensorType>();
  }
  auto inputElemTy = inputTy.getElementType();
@@ -370,7 +381,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
      inputElemTy.getIntOrFloatBitWidth() == 8) {
    return rewriter.notifyMatchFailure(
        op, "IntegerType with bitwidth 8 unsupported in convertion from "
-            "AtenSumOp to MHLO");
+            "AtenSumOp to StableHLO");
  }

  SmallVector<int64_t> dims;
@@ -379,13 +390,14 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
  }
  Value initValue =
      createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
-  if (!initValue) return failure();
+  if (!initValue)
+    return failure();

  llvm::sort(dims.begin(), dims.end());
-  auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
+  auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
      op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));

-  Block &block = mhloReduceOp.getBody().emplaceBlock();
+  Block &block = stablehloReduceOp.getBody().emplaceBlock();
  auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());

  block.addArgument(blockArgumentTy, op->getLoc());
@@ -397,13 +409,13 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
  {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&block);
-    Value addResult = rewriter.create<mhlo::AddOp>(
+    Value addResult = rewriter.create<stablehlo::AddOp>(
        op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
-    rewriter.create<mhlo::ReturnOp>(op->getLoc(), addResult);
+    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult);
  }

  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy,
-                                              mhloReduceOp.getResults());
+                                              stablehloReduceOp.getResults());
  return success();
}
} // namespace
@@ -417,7 +429,8 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
  Value input = adaptor.getSelf();
  auto inputTy = input.getType().dyn_cast<RankedTensorType>();
  if (!inputTy) {
-    return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
+    return rewriter.notifyMatchFailure(
+        op, "only Tensor types supported in StableHLO");
  }
  auto inputElemTy = inputTy.getElementType();
  if (!inputElemTy.isIntOrFloat()) {
@@ -429,7 +442,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
      inputElemTy.getIntOrFloatBitWidth() == 8) {
    return rewriter.notifyMatchFailure(
        op, "IntegerType with bitwidth 8 unsupported in convertion from "
-            "AtenMaxOp to MHLO");
+            "AtenMaxOp to StableHLO");
  }

  SmallVector<int64_t> dims;
@@ -439,12 +452,13 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(

  Value initValue =
      createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
-  if (!initValue) return failure();
+  if (!initValue)
+    return failure();
  llvm::sort(dims.begin(), dims.end());
-  auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
+  auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
      op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));

-  Block &block = mhloReduceOp.getBody().emplaceBlock();
+  Block &block = stablehloReduceOp.getBody().emplaceBlock();
  auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());

  block.addArgument(blockArgumentTy, op->getLoc());
@@ -456,14 +470,14 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
  {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&block);
-    Value maxResult = rewriter.create<mhlo::MaxOp>(
+    Value maxResult = rewriter.create<stablehlo::MaxOp>(
        op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
-    rewriter.create<mhlo::ReturnOp>(op->getLoc(), maxResult);
+    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), maxResult);
  }

  rewriter.replaceOpWithNewOp<tensor::CastOp>(
      op, getTypeConverter()->convertType(op.getType()),
-      mhloReduceOp.getResults());
+      stablehloReduceOp.getResults());
  return success();
}
} // namespace
@@ -480,12 +494,14 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
                   ->convertType(op.getType())
                   .template dyn_cast<RankedTensorType>();
  if (!inputTy) {
-    return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
+    return rewriter.notifyMatchFailure(
+        op, "only Tensor types supported in StableHLO");
  }
  if (inputTy.getElementType() != outTy.getElementType()) {
    // Use output element type as computation type.
    auto dstElemTy = outTy.getElementType();
-    input = rewriter.create<mhlo::ConvertOp>(op->getLoc(), input, dstElemTy);
+    input =
+        rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
    inputTy = input.getType().dyn_cast<RankedTensorType>();
  }
  auto inputElemTy = inputTy.getElementType();
@@ -499,7 +515,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
      inputElemTy.getIntOrFloatBitWidth() == 8) {
    return rewriter.notifyMatchFailure(
        op, "IntegerType with bitwidth 8 unsupported in convertion from "
-            "AtenSumDimIntListOp to MHLO");
+            "AtenSumDimIntListOp to StableHLO");
  }

  SmallVector<int64_t> inputDims;
@@ -525,13 +541,14 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
  }
  Value initValue =
      createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
-  if (!initValue) return failure();
+  if (!initValue)
+    return failure();

  llvm::sort(dims.begin(), dims.end());
-  auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
+  auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
      op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));

-  Region &region = mhloReduceOp.getBody();
+  Region &region = stablehloReduceOp.getBody();
  Block &block = region.emplaceBlock();
  auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());

@@ -544,15 +561,15 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
  {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&block);
-    Value addResult = rewriter.create<mhlo::AddOp>(
+    Value addResult = rewriter.create<stablehlo::AddOp>(
        op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
-    rewriter.create<mhlo::ReturnOp>(op->getLoc(), addResult);
+    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult);
  }

  if (keepDim) {
    const auto &options = getOptions();
-    auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input,
-                                                  options.dimSizeIndexBits);
+    auto outShapeInfo =
+        hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
    if (failed(outShapeInfo)) {
      return rewriter.notifyMatchFailure(
          op, "failed to get dimension sizes of the input");
@@ -567,26 +584,27 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
    }
    auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
        op->getLoc(), outShapeVec);
-    rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
+    rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
        op, getTypeConverter()->convertType(op.getType()),
-        mhloReduceOp.getResult(0), outShapeTensor);
+        stablehloReduceOp.getResult(0), outShapeTensor);
    return success();
  }
  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy,
-                                              mhloReduceOp.getResults());
+                                              stablehloReduceOp.getResults());
  return success();
}
} // namespace

// AtenFrobeniusNormDimOp
-// aten.frobenius_norm.dim => mhlo.reduce(calculate square sum along given dims)
-// + mhlo.sqrt
+// aten.frobenius_norm.dim => stablehlo.reduce(calculate square sum along given
+// dims)
+// + stablehlo.sqrt
namespace {
template <>
LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
    AtenFrobeniusNormDimOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
-  const TorchToMhloOptions &options = getOptions();
+  const TorchToStablehloOptions &options = getOptions();

  Value input = adaptor.getSelf();
  auto inputType = input.getType().dyn_cast<RankedTensorType>();
@@ -614,7 +632,7 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
    }
  }

  // Sort the dims in ascending order, making the conversion
  // stable with unordered dims.
  std::sort(dims.begin(), dims.end());

@@ -624,14 +642,14 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
        op, "non-const bool `keepdim` is not supported");
  }

-  auto squareOp = rewriter.create<mhlo::MulOp>(op->getLoc(), input, input);
+  auto squareOp = rewriter.create<stablehlo::MulOp>(op->getLoc(), input, input);

  auto initValue = createInitialValueForReduceOp(op, inputElemType, rewriter);
  if (!initValue) {
    return failure();
  }

-  auto reduceOp = rewriter.create<mhlo::ReduceOp>(
+  auto reduceOp = rewriter.create<stablehlo::ReduceOp>(
      op->getLoc(), squareOp.getResult(), initValue,
      rewriter.getI64TensorAttr(dims));

@@ -649,30 +667,32 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&block);

-    auto addResult = rewriter.create<mhlo::AddOp>(op->getLoc(), firstArgument,
-                                                  secondArgument);
-    rewriter.create<mhlo::ReturnOp>(op->getLoc(), addResult.getResult());
+    auto addResult = rewriter.create<stablehlo::AddOp>(
+        op->getLoc(), firstArgument, secondArgument);
+    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult.getResult());
  }

  auto output =
-      rewriter.create<mhlo::SqrtOp>(op->getLoc(), reduceOp.getResult(0));
+      rewriter.create<stablehlo::SqrtOp>(op->getLoc(), reduceOp.getResult(0));

  if (keepDim) {
-    auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
+    auto outShapeInfo =
+        hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
    if (failed(outShapeInfo)) {
      return rewriter.notifyMatchFailure(
          op, "failed to get dimension sizes of the input");
    }
    auto outShapeVec = *outShapeInfo;
    auto one = rewriter.create<mlir::arith::ConstantOp>(
-        op->getLoc(), rewriter.getIntegerAttr(
-                          rewriter.getIntegerType(options.dimSizeIndexBits), 1));
+        op->getLoc(),
+        rewriter.getIntegerAttr(
+            rewriter.getIntegerType(options.dimSizeIndexBits), 1));
    for (int64_t i : dims) {
      outShapeVec[i] = one;
    }
    auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
        op->getLoc(), outShapeVec);
-    rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
+    rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
        op, getTypeConverter()->convertType(op.getType()), output,
        outShapeTensor);
    return success();
@@ -682,9 +702,9 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
  }
} // namespace

-void mlir::torch::torch_to_mhlo::populateReductionOpPatternsAndLegality(
+void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, const TorchToMhloOptions &options) {
+    ConversionTarget &target, const TorchToStablehloOptions &options) {
  MLIRContext *context = patterns.getContext();
#define INSERT_ATEN_REDUCTION_OP_PATTERN(AtenOp) \
  target.addIllegalOp<AtenOp>(); \
@@ -7,11 +7,12 @@
//
//===----------------------------------------------------------------------===//

-#include "./MhloLegalizeUtils.h"
+#include "StablehloLegalizeUtils.h"
-#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
+#include "stablehlo/dialect/StablehloOps.h"
+#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include <numeric>
@@ -21,27 +22,27 @@ using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace mlir {
-namespace mhlo {
+namespace hlo {

// Create a 32-bit float constant operator from a float
-Value getMhloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
+Value getStablehloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
                                       float val) {
  auto const_type = RankedTensorType::get({}, rewriter.getF32Type());
  auto const_attr = DenseElementsAttr::get(const_type, val);

-  auto const_op =
-      rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
+  auto const_op = rewriter.create<stablehlo::ConstantOp>(
+      op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

// Create a 64-bit float constant operator from a double
-Value getMhloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op,
+Value getStablehloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op,
                                       double val) {
  auto const_type = RankedTensorType::get({}, rewriter.getF64Type());
  auto const_attr = DenseElementsAttr::get(const_type, val);

-  auto const_op =
-      rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
+  auto const_op = rewriter.create<stablehlo::ConstantOp>(
+      op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

@@ -65,8 +66,8 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
      RankedTensorType::get(shape, rewriter.getIntegerType(sizeof(T) * 8));
  auto const_attr = DenseElementsAttr::get(const_type, vec);

-  auto const_op =
-      rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
+  auto const_op = rewriter.create<stablehlo::ConstantOp>(
+      op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

@@ -88,8 +89,8 @@ std::optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,
      shape, rewriter.getIntegerType(vec[0].getBitWidth()));
  auto const_attr = DenseElementsAttr::get(const_type, vec);

-  auto const_op =
-      rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
+  auto const_op = rewriter.create<stablehlo::ConstantOp>(
+      op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

@@ -111,8 +112,8 @@ std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
  auto const_type = RankedTensorType::get(shape, rewriter.getF32Type());
  auto const_attr = DenseElementsAttr::get(const_type, vec);

-  auto const_op =
-      rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
+  auto const_op = rewriter.create<stablehlo::ConstantOp>(
+      op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

@@ -133,8 +134,8 @@ std::optional<Value> getConstTensor<double>(PatternRewriter &rewriter,
  auto const_type = RankedTensorType::get(shape, rewriter.getF64Type());
  auto const_attr = DenseElementsAttr::get(const_type, vec);

-  auto const_op =
-      rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
+  auto const_op = rewriter.create<stablehlo::ConstantOp>(
+      op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

@@ -169,18 +170,18 @@ Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
                          T val, Type dtype, llvm::ArrayRef<int64_t> dshape) {
  auto const_type = RankedTensorType::get(dshape, dtype);
  auto const_attr = SplatElementsAttr::get(const_type, val);
-  auto const_op =
-      rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
+  auto const_op = rewriter.create<stablehlo::ConstantOp>(
+      op->getLoc(), const_type, const_attr);
  return const_op.getResult();
}

-Value scalarToMhloTensor(ConversionPatternRewriter &rewriter, Operation *op,
-                         Value scalarValue, Type dtype) {
+Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
+                              Operation *op, Value scalarValue, Type dtype) {
  auto tensor = rewriter.create<tensor::FromElementsOp>(
      op->getLoc(), ArrayRef<Value>{scalarValue});
  auto dtype_tensor =
-      rewriter.create<mhlo::ConvertOp>(op->getLoc(), tensor, dtype);
+      rewriter.create<stablehlo::ConvertOp>(op->getLoc(), tensor, dtype);
-  return rewriter.create<mhlo::ReshapeOp>(
+  return rewriter.create<stablehlo::ReshapeOp>(
      op->getLoc(), RankedTensorType::get(mlir::ArrayRef<int64_t>{}, dtype),
      dtype_tensor);
}

@@ -192,7 +193,8 @@ Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) {
  if (in_type.getElementType() != outType.getElementType()) {
    TensorType promotedType =
        in_type.cloneWith(in_type.getShape(), outType.getElementType());
-    return rewriter.create<mhlo::ConvertOp>(op->getLoc(), promotedType, input);
+    return rewriter.create<stablehlo::ConvertOp>(op->getLoc(), promotedType,
+                                                 input);
  }
  return input;
}

@@ -210,8 +212,8 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
  if (in_type.getElementType() != outType.getElementType()) {
    TensorType promoted_type =
        in_type.cloneWith(in_type.getShape(), outType.getElementType());
-    input =
-        rewriter.create<mhlo::ConvertOp>(op->getLoc(), promoted_type, input);
+    input = rewriter.create<stablehlo::ConvertOp>(op->getLoc(), promoted_type,
+                                                  input);
  }

  ArrayRef<int64_t> inShape = in_type.getShape();

@@ -245,8 +247,8 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
      RankedTensorType::get({static_cast<long int>(bcastDims.size())},
                            rewriter.getI64Type()),
      bcastDims);
-  auto bcast_op = rewriter.create<mhlo::BroadcastInDimOp>(op->getLoc(), outType,
-                                                          input, bcast_attr);
+  auto bcast_op = rewriter.create<stablehlo::BroadcastInDimOp>(
+      op->getLoc(), outType, input, bcast_attr);
  return bcast_op.getResult();
}

@@ -348,8 +350,8 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
  }

  auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
-  auto mhloShape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
-  return rewriter.create<mhlo::DynamicReshapeOp>(loc, outTy, tensor, mhloShape)
+  auto shape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
+  return rewriter.create<stablehlo::DynamicReshapeOp>(loc, outTy, tensor, shape)
      .getResult();
}

@@ -357,11 +359,11 @@ Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
                         const APFloat &constant, Value shape,
                         TensorType outType) {
  auto constAttr = rewriter.getFloatAttr(outType.getElementType(), constant);
-  auto constTensor = rewriter.create<mhlo::ConstantOp>(loc, constAttr);
+  auto constTensor = rewriter.create<stablehlo::ConstantOp>(loc, constAttr);
  return rewriter
-      .create<mhlo::DynamicBroadcastInDimOp>(loc, outType, constTensor, shape,
-                                             rewriter.getI64TensorAttr({}))
+      .create<stablehlo::DynamicBroadcastInDimOp>(
+          loc, outType, constTensor, shape, rewriter.getI64TensorAttr({}))
      .getResult();
}
-} // namespace mhlo
+} // namespace hlo
} // namespace mlir
@@ -7,8 +7,8 @@
//
//===----------------------------------------------------------------------===//

-#ifndef TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H
-#define TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H
+#ifndef TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_STABLEHLOLEGALIZEUTILS_H
+#define TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_STABLEHLOLEGALIZEUTILS_H

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -18,22 +18,22 @@
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
-namespace mhlo {
+namespace hlo {

using mlir::ConversionPatternRewriter;

// Create a 32-bit float constant operator from a float
-Value getMhloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
+Value getStablehloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
                                       float val);

// Create a 64-bit float constant operator from a double
-Value getMhloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op,
+Value getStablehloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op,
                                       double val);

// Templated function to create a constant op for given type and shape.
// T: storage C type.
// Default template creates a constant tensor in T.
-// To create INT48 MHLO constant, need to pass in llvm::APInt instead.
+// To create INT48 StableHLO constant, need to pass in llvm::APInt instead.
template <typename T>
std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
                                    ArrayRef<T> vec, ArrayRef<int64_t> shape);

@@ -42,8 +42,8 @@ template <typename T>
Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
                          T val, Type dtype, llvm::ArrayRef<int64_t> dshape);

-Value scalarToMhloTensor(ConversionPatternRewriter &rewriter, Operation *op,
-                         Value scalarValue, Type dtype);
+Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
+                              Operation *op, Value scalarValue, Type dtype);

Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType);

@@ -71,7 +71,7 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
                         const APFloat &constant, Value shape,
                         TensorType outType);
-} // namespace mhlo
+} // namespace hlo
} // namespace mlir

-#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H
+#endif // TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_STABLEHLOLEGALIZEUTILS_H
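The helpers declared above (now in namespace `hlo`) are the entry points the op conversion patterns call into. A small, hypothetical usage sketch inside a pattern body, assuming `rewriter`, `op`, and a scalar `Value v` are already in scope; only calls declared in this header are used:

    // Sketch only: build a rank-0 i64 constant and convert a scalar Value
    // into a rank-0 tensor, using the helpers from StablehloLegalizeUtils.h.
    Value zero = hlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
    Value asTensor =
        hlo::scalarToStablehloTensor(rewriter, op, v, rewriter.getF32Type());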
@ -7,17 +7,18 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "../PassDetail.h"
|
||||||
#include "./PopulatePatterns.h"
|
#include "PopulatePatterns.h"
|
||||||
#include "mhlo/IR/hlo_ops.h"
|
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Dialect/Traits.h"
|
#include "mlir/Dialect/Traits.h"
|
||||||
#include "mlir/IR/Matchers.h"
|
#include "mlir/IR/Matchers.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
#include "stablehlo/dialect/ChloOps.h"
|
#include "stablehlo/dialect/ChloOps.h"
|
||||||
|
#include "stablehlo/dialect/StablehloOps.h"
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||||
|
@ -30,17 +31,18 @@ using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
class ConvertTorchToMhlo : public ConvertTorchToMhloBase<ConvertTorchToMhlo> {
|
class ConvertTorchToStablehlo
|
||||||
|
: public ConvertTorchToStablehloBase<ConvertTorchToStablehlo> {
|
||||||
public:
|
public:
|
||||||
ConvertTorchToMhlo() = default;
|
ConvertTorchToStablehlo() = default;
|
||||||
ConvertTorchToMhlo(bool enableStaticShape, bool enableI32Index) {
|
ConvertTorchToStablehlo(bool enableStaticShape, bool enableI32Index) {
|
||||||
this->enableStaticShape = enableStaticShape;
|
this->enableStaticShape = enableStaticShape;
|
||||||
this->enableI32Index = enableI32Index;
|
this->enableI32Index = enableI32Index;
|
||||||
}
|
}
|
||||||
|
|
||||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
registry.insert<chlo::ChloDialect>();
|
registry.insert<chlo::ChloDialect>();
|
||||||
registry.insert<mhlo::MhloDialect>();
|
registry.insert<stablehlo::StablehloDialect>();
|
||||||
registry.insert<tensor::TensorDialect>();
|
registry.insert<tensor::TensorDialect>();
|
||||||
registry.insert<arith::ArithDialect>();
|
registry.insert<arith::ArithDialect>();
|
||||||
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
||||||
|
@ -48,7 +50,7 @@ public:
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
MLIRContext *context = &getContext();
|
MLIRContext *context = &getContext();
|
||||||
ConversionTarget target(*context);
|
ConversionTarget target(*context);
|
||||||
target.addLegalDialect<chlo::ChloDialect, mhlo::MhloDialect,
|
target.addLegalDialect<chlo::ChloDialect, stablehlo::StablehloDialect,
|
||||||
tensor::TensorDialect, arith::ArithDialect>();
|
tensor::TensorDialect, arith::ArithDialect>();
|
||||||
|
|
||||||
TypeConverter typeConverter;
|
TypeConverter typeConverter;
|
||||||
|
@ -57,20 +59,20 @@ public:
|
||||||
|
|
||||||
RewritePatternSet patterns(context);
|
RewritePatternSet patterns(context);
|
||||||
|
|
||||||
torch_to_mhlo::TorchToMhloOptions options{enableStaticShape,
|
torch_to_stablehlo::TorchToStablehloOptions options{
|
||||||
enableI32Index ? 32u : 64u};
|
enableStaticShape, enableI32Index ? 32u : 64u};
|
||||||
torch_to_mhlo::populateBasicOpPatternsAndLegality(typeConverter, patterns,
|
torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
||||||
target, options);
|
|
||||||
torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
|
|
||||||
typeConverter, patterns, target, options);
|
typeConverter, patterns, target, options);
|
||||||
torch_to_mhlo::populateGatherOpPatternsAndLegality(typeConverter, patterns,
|
torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
|
||||||
target, options);
|
typeConverter, patterns, target, options);
|
||||||
torch_to_mhlo::populateReductionOpPatternsAndLegality(
|
torch_to_stablehlo::populateGatherOpPatternsAndLegality(
|
||||||
|
typeConverter, patterns, target, options);
|
||||||
|
torch_to_stablehlo::populateReductionOpPatternsAndLegality(
|
||||||
|
typeConverter, patterns, target, options);
|
||||||
|
torch_to_stablehlo::populateLinearOpPatternsAndLegality(
|
||||||
|
typeConverter, patterns, target, options);
|
||||||
|
torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
|
||||||
typeConverter, patterns, target, options);
|
typeConverter, patterns, target, options);
|
||||||
torch_to_mhlo::populateLinearOpPatternsAndLegality(typeConverter, patterns,
|
|
||||||
target, options);
|
|
||||||
torch_to_mhlo::populatePoolingOpPatternsAndLegality(typeConverter, patterns,
|
|
||||||
target, options);
|
|
||||||
|
|
||||||
if (failed(applyPartialConversion(getOperation(), target,
|
if (failed(applyPartialConversion(getOperation(), target,
|
||||||
std::move(patterns)))) {
|
std::move(patterns)))) {
|
||||||
|
@ -82,13 +84,13 @@ public:
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||||
mlir::torch::createConvertTorchToMhloPass() {
|
mlir::torch::createConvertTorchToStablehloPass() {
|
||||||
return std::make_unique<ConvertTorchToMhlo>(false, false);
|
return std::make_unique<ConvertTorchToStablehlo>(false, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||||
mlir::torch::createConvertTorchToMhloPass(bool enableStaticShape,
|
mlir::torch::createConvertTorchToStablehloPass(bool enableStaticShape,
|
||||||
bool enableI32Index) {
|
bool enableI32Index) {
|
||||||
return std::make_unique<ConvertTorchToMhlo>(enableStaticShape,
|
return std::make_unique<ConvertTorchToStablehlo>(enableStaticShape,
|
||||||
enableI32Index);
|
enableI32Index);
|
||||||
}
|
}
|
|
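For reference, a minimal sketch of driving just this conversion pass from the Python API (a sketch under assumptions: the toy module is illustrative, and `run_pipeline_with_repro_report` is the helper used elsewhere in this patch; the pass name matches the -convert-torch-to-stablehlo flag exercised by the lit tests below):

    import torch
    import torch_mlir
    from torch_mlir.compiler_utils import run_pipeline_with_repro_report

    class AddOne(torch.nn.Module):
        def forward(self, x):
            return x + 1

    # Get a module in the Torch backend contract, then run only the new
    # Torch -> StableHLO conversion pass on every function in the module.
    module = torch_mlir.compile(AddOne(), torch.ones(3, 4), output_type="torch")
    run_pipeline_with_repro_report(
        module,
        "builtin.module(func.func(convert-torch-to-stablehlo))",
        "Lowering Torch Backend IR -> StableHLO")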
@ -7,14 +7,15 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "../PassDetail.h"
|
||||||
#include "./MhloLegalizeUtils.h"
|
#include "PopulatePatterns.h"
|
||||||
#include "./PopulatePatterns.h"
|
#include "StablehloLegalizeUtils.h"
|
||||||
#include "mhlo/IR/hlo_ops.h"
|
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "stablehlo/dialect/StablehloOps.h"
|
||||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
|
@ -28,7 +29,7 @@ using namespace mlir;
|
||||||
using namespace mlir::torch;
|
using namespace mlir::torch;
|
||||||
using namespace mlir::torch::Torch;
|
using namespace mlir::torch::Torch;
|
||||||
using namespace mlir::torch::TorchConversion;
|
using namespace mlir::torch::TorchConversion;
|
||||||
using namespace mlir::torch::torch_to_mhlo;
|
using namespace mlir::torch::torch_to_stablehlo;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// A dimension index from torch.dialect might outside the range [0, dimSize].
|
// A dimension index from torch.dialect might outside the range [0, dimSize].
|
||||||
|
@ -100,7 +101,7 @@ Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op,
|
||||||
auto stridesTensor =
|
auto stridesTensor =
|
||||||
rewriter.create<tensor::FromElementsOp>(loc, strides).getResult();
|
rewriter.create<tensor::FromElementsOp>(loc, strides).getResult();
|
||||||
|
|
||||||
return rewriter.create<mhlo::RealDynamicSliceOp>(
|
return rewriter.create<stablehlo::RealDynamicSliceOp>(
|
||||||
loc, outTy, input, startTensor, endTensor, stridesTensor);
|
loc, outTy, input, startTensor, endTensor, stridesTensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -144,7 +145,7 @@ FailureOr<Value> getDynamicSlice(PatternRewriter &rewriter, Operation *op,
|
||||||
step = rewriter.create<arith::TruncIOp>(loc, intType, step);
|
step = rewriter.create<arith::TruncIOp>(loc, intType, step);
|
||||||
}
|
}
|
||||||
FailureOr<SmallVector<Value, 4>> dimSizesInfo =
|
FailureOr<SmallVector<Value, 4>> dimSizesInfo =
|
||||||
mhlo::getDimSizesOfTensor(rewriter, op, input, dimSizeIndexBits);
|
hlo::getDimSizesOfTensor(rewriter, op, input, dimSizeIndexBits);
|
||||||
if (failed(dimSizesInfo))
|
if (failed(dimSizesInfo))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "failed to get dimension sizes of the input");
|
op, "failed to get dimension sizes of the input");
|
||||||
|
@ -179,7 +180,7 @@ public:
|
||||||
auto loc = op.getLoc();
|
auto loc = op.getLoc();
|
||||||
auto newRank = dimSizes.size();
|
auto newRank = dimSizes.size();
|
||||||
if (newRank == 0 || rankType.getRank() == 0) {
|
if (newRank == 0 || rankType.getRank() == 0) {
|
||||||
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
|
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
|
||||||
op,
|
op,
|
||||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||||
op.getType()),
|
op.getType()),
|
||||||
|
@ -214,17 +215,18 @@ public:
|
||||||
numel);
|
numel);
|
||||||
|
|
||||||
if (dimSizes.size() == 0) {
|
if (dimSizes.size() == 0) {
|
||||||
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
|
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
|
||||||
op,
|
op,
|
||||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||||
op.getType()),
|
op.getType()),
|
||||||
adaptor.getSelf());
|
adaptor.getSelf());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
Value mhloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
|
Value stablehloShape =
|
||||||
Value computedShape = rewriter.create<mhlo::ComputeReshapeShapeOp>(
|
rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
|
||||||
loc, mhloShape.getType(), numel, mhloShape);
|
Value computedShape = rewriter.create<stablehlo::ComputeReshapeShapeOp>(
|
||||||
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
|
loc, stablehloShape.getType(), numel, stablehloShape);
|
||||||
|
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
|
||||||
op,
|
op,
|
||||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||||
op.getType()),
|
op.getType()),
|
||||||
|
@ -315,21 +317,21 @@ LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
|
||||||
dims.push_back(r);
|
dims.push_back(r);
|
||||||
}
|
}
|
||||||
if (dims.size() == 0) {
|
if (dims.size() == 0) {
|
||||||
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
|
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
|
||||||
op, getTypeConverter()->convertType(op.getType()), self);
|
op, getTypeConverter()->convertType(op.getType()), self);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims,
|
auto newDimSizesInfo = hlo::getDimSizesOfTensor(rewriter, op, self, dims,
|
||||||
options.dimSizeIndexBits);
|
options.dimSizeIndexBits);
|
||||||
if (failed(newDimSizesInfo))
|
if (failed(newDimSizesInfo))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "failed to get dimension sizes of the input");
|
op, "failed to get dimension sizes of the input");
|
||||||
auto newDimSizes = *newDimSizesInfo;
|
auto newDimSizes = *newDimSizesInfo;
|
||||||
auto mhloShape =
|
auto stablehloShape =
|
||||||
rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
|
rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
|
||||||
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
|
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
|
||||||
op, getTypeConverter()->convertType(op.getType()), self, mhloShape);
|
op, getTypeConverter()->convertType(op.getType()), self, stablehloShape);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -365,20 +367,20 @@ LogicalResult ConvertAtenOp<AtenSqueezeDimOp>::matchAndRewrite(
|
||||||
std::iota(dims.begin(), dims.end(), 0);
|
std::iota(dims.begin(), dims.end(), 0);
|
||||||
dims.erase(dims.begin() + dim);
|
dims.erase(dims.begin() + dim);
|
||||||
if (dims.size() == 0) {
|
if (dims.size() == 0) {
|
||||||
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
|
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
|
||||||
op, getTypeConverter()->convertType(op.getType()), self);
|
op, getTypeConverter()->convertType(op.getType()), self);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims,
|
auto newDimSizesInfo = hlo::getDimSizesOfTensor(rewriter, op, self, dims,
|
||||||
options.dimSizeIndexBits);
|
options.dimSizeIndexBits);
|
||||||
if (failed(newDimSizesInfo))
|
if (failed(newDimSizesInfo))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "failed to get dimension sizes of the input");
|
op, "failed to get dimension sizes of the input");
|
||||||
auto newDimSizes = *newDimSizesInfo;
|
auto newDimSizes = *newDimSizesInfo;
|
||||||
auto mhloShape =
|
auto stablehloShape =
|
||||||
rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
|
rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
|
||||||
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
|
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
|
||||||
op, getTypeConverter()->convertType(op.getType()), self, mhloShape);
|
op, getTypeConverter()->convertType(op.getType()), self, stablehloShape);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -395,8 +397,8 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
|
||||||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||||
return op->emitError("dim must be a Scalar constant");
|
return op->emitError("dim must be a Scalar constant");
|
||||||
|
|
||||||
auto unsqzTensorInfo = mhlo::unsqueezeTensor(rewriter, op, adaptor.getSelf(),
|
auto unsqzTensorInfo = hlo::unsqueezeTensor(rewriter, op, adaptor.getSelf(),
|
||||||
{dim}, options.dimSizeIndexBits);
|
{dim}, options.dimSizeIndexBits);
|
||||||
if (failed(unsqzTensorInfo))
|
if (failed(unsqzTensorInfo))
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"failed to create unsqueezed tensor");
|
"failed to create unsqueezed tensor");
|
||||||
|
@ -405,9 +407,9 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
|
void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
|
||||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||||
ConversionTarget &target, const TorchToMhloOptions &options) {
|
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
||||||
MLIRContext *context = patterns.getContext();
|
MLIRContext *context = patterns.getContext();
|
||||||
|
|
||||||
#define INSERT_ATENOP_PATTERN(AtenOp) \
|
#define INSERT_ATENOP_PATTERN(AtenOp) \
|
|
@ -11,7 +11,7 @@ set(LinkedLibs MLIRIR
|
||||||
TorchMLIRTorchConversionToMLProgram
|
TorchMLIRTorchConversionToMLProgram
|
||||||
MLIRMemRefTransforms)
|
MLIRMemRefTransforms)
|
||||||
|
|
||||||
if(TORCH_MLIR_ENABLE_MHLO)
|
if(TORCH_MLIR_ENABLE_STABLEHLO)
|
||||||
list(APPEND LinkedLibs ChloPasses)
|
list(APPEND LinkedLibs ChloPasses)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
@ -21,7 +21,7 @@ add_mlir_library(TorchMLIRTorchConversionPasses
|
||||||
Passes.cpp
|
Passes.cpp
|
||||||
VerifyLinalgOnTensorsBackendContract.cpp
|
VerifyLinalgOnTensorsBackendContract.cpp
|
||||||
VerifyTosaBackendContract.cpp
|
VerifyTosaBackendContract.cpp
|
||||||
VerifyMhloBackendContract.cpp
|
VerifyStablehloBackendContract.cpp
|
||||||
|
|
||||||
ADDITIONAL_HEADER_DIRS
|
ADDITIONAL_HEADER_DIRS
|
||||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/TorchConversion/Transforms
|
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/TorchConversion/Transforms
|
||||||
|
|
|
@ -21,9 +21,8 @@
|
||||||
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
|
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
|
||||||
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
|
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
|
||||||
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
|
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
|
||||||
#ifdef TORCH_MLIR_ENABLE_MHLO
|
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||||
#include "mhlo/transforms/passes.h"
|
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
||||||
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
|
||||||
#endif
|
#endif
|
||||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
|
|
||||||
|
@ -53,12 +52,13 @@ void mlir::torch::registerTorchConversionPasses() {
|
||||||
"Pipeline lowering torch backend contract to TOSA backend "
|
"Pipeline lowering torch backend contract to TOSA backend "
|
||||||
"contract.",
|
"contract.",
|
||||||
TorchConversion::createTorchBackendToTosaBackendPipeline);
|
TorchConversion::createTorchBackendToTosaBackendPipeline);
|
||||||
#ifdef TORCH_MLIR_ENABLE_MHLO
|
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||||
mlir::PassPipelineRegistration<TorchConversion::MhloBackendPipelineOptions>(
|
mlir::PassPipelineRegistration<
|
||||||
"torch-backend-to-mhlo-backend-pipeline",
|
TorchConversion::StablehloBackendPipelineOptions>(
|
||||||
"Pipeline lowering torch backend contract to MHLO backend "
|
"torch-backend-to-stablehlo-backend-pipeline",
|
||||||
|
"Pipeline lowering torch backend contract to StableHLO backend "
|
||||||
"contract.",
|
"contract.",
|
||||||
TorchConversion::createTorchBackendToMhloBackendPipeline);
|
TorchConversion::createTorchBackendToStablehloBackendPipeline);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -121,11 +121,12 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
|
||||||
pm.addPass(TorchConversion::createVerifyTosaBackendContractPass());
|
pm.addPass(TorchConversion::createVerifyTosaBackendContractPass());
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef TORCH_MLIR_ENABLE_MHLO
|
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||||
void TorchConversion::createTorchBackendToMhloBackendPipeline(
|
void TorchConversion::createTorchBackendToStablehloBackendPipeline(
|
||||||
OpPassManager &pm,
|
OpPassManager &pm,
|
||||||
const TorchConversion::MhloBackendPipelineOptions &options) {
|
const TorchConversion::StablehloBackendPipelineOptions &options) {
|
||||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToMhloPass(
|
// Generate Stablehlo ops.
|
||||||
|
pm.addNestedPass<func::FuncOp>(createConvertTorchToStablehloPass(
|
||||||
options.enableStaticShape, options.enableI32Index));
|
options.enableStaticShape, options.enableI32Index));
|
||||||
|
|
||||||
// Clean up any non-canonical code introduced above.
|
// Clean up any non-canonical code introduced above.
|
||||||
|
@ -133,21 +134,13 @@ void TorchConversion::createTorchBackendToMhloBackendPipeline(
|
||||||
// The resolution of `dim` ops tends to create identical ops. CSE them.
|
// The resolution of `dim` ops tends to create identical ops. CSE them.
|
||||||
pm.addNestedPass<func::FuncOp>(createCSEPass());
|
pm.addNestedPass<func::FuncOp>(createCSEPass());
|
||||||
|
|
||||||
// Convert CHLO ops to MHLO ops
|
|
||||||
pm.addNestedPass<func::FuncOp>(mhlo::createChloLegalizeToHloPass());
|
|
||||||
// Clean up any non-canonical code introduced above..
|
|
||||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
|
||||||
// The resolution of `dim` ops tends to create identical ops. CSE them.
|
|
||||||
pm.addNestedPass<func::FuncOp>(createCSEPass());
|
|
||||||
|
|
||||||
// Finish the type conversion from `torch` types to the types of the
|
// Finish the type conversion from `torch` types to the types of the
|
||||||
// MHLO backend contract.
|
// StableHLO backend contract.
|
||||||
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
|
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
|
||||||
pm.addNestedPass<func::FuncOp>(
|
pm.addNestedPass<func::FuncOp>(
|
||||||
TorchConversion::createFinalizingBackendTypeConversionPass());
|
TorchConversion::createFinalizingBackendTypeConversionPass());
|
||||||
// Verify that we have lowered to the form that MHLO backends
|
|
||||||
// expect. This fails compilation (signalPassFailure) if the IR is not in the
|
// Verify that we have lowered to Stablehlo and Chlo ops.
|
||||||
// correct form.
|
pm.addPass(TorchConversion::createVerifyStablehloBackendContractPass());
|
||||||
pm.addPass(TorchConversion::createVerifyMhloBackendContractPass());
|
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -6,10 +6,9 @@
|
||||||
// Also available under a BSD-style license. See LICENSE.
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
#ifdef TORCH_MLIR_ENABLE_MHLO
|
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||||
#include "PassDetail.h"
|
#include "PassDetail.h"
|
||||||
|
|
||||||
#include "mhlo/IR/hlo_ops.h"
|
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||||
|
@ -18,6 +17,7 @@
|
||||||
#include "mlir/IR/OpDefinition.h"
|
#include "mlir/IR/OpDefinition.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
#include "stablehlo/dialect/ChloOps.h"
|
#include "stablehlo/dialect/ChloOps.h"
|
||||||
|
#include "stablehlo/dialect/StablehloOps.h"
|
||||||
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
|
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
@ -25,17 +25,15 @@ using namespace mlir::torch;
|
||||||
using namespace mlir::torch::TorchConversion;
|
using namespace mlir::torch::TorchConversion;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class VerifyMhloBackendContractPass
|
class VerifyStablehloBackendContractPass
|
||||||
: public VerifyMhloBackendContractBase<VerifyMhloBackendContractPass> {
|
: public VerifyStablehloBackendContractBase<
|
||||||
|
VerifyStablehloBackendContractPass> {
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
MLIRContext *context = &getContext();
|
|
||||||
auto module = getOperation();
|
|
||||||
TypeConverter converter;
|
TypeConverter converter;
|
||||||
converter.addConversion([](Type type) -> Type {
|
converter.addConversion([](Type type) -> Type {
|
||||||
auto elemTy = type;
|
auto elemTy = type;
|
||||||
if (isa<TensorType>(type)) {
|
if (isa<TensorType>(type))
|
||||||
elemTy = type.cast<TensorType>().getElementType();
|
elemTy = type.cast<TensorType>().getElementType();
|
||||||
}
|
|
||||||
if (BaseMemRefType::isValidElementType(elemTy))
|
if (BaseMemRefType::isValidElementType(elemTy))
|
||||||
return type;
|
return type;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -43,6 +41,7 @@ class VerifyMhloBackendContractPass
|
||||||
|
|
||||||
auto opHasLegalTypes = [&](Operation *op) { return converter.isLegal(op); };
|
auto opHasLegalTypes = [&](Operation *op) { return converter.isLegal(op); };
|
||||||
|
|
||||||
|
MLIRContext *context = &getContext();
|
||||||
ConversionTarget target(*context);
|
ConversionTarget target(*context);
|
||||||
|
|
||||||
// Structural operations.
|
// Structural operations.
|
||||||
|
@ -50,26 +49,16 @@ class VerifyMhloBackendContractPass
|
||||||
// Shape operations.
|
// Shape operations.
|
||||||
target.addDynamicallyLegalOp<shape::ShapeOfOp>(opHasLegalTypes);
|
target.addDynamicallyLegalOp<shape::ShapeOfOp>(opHasLegalTypes);
|
||||||
|
|
||||||
target.addLegalDialect<mhlo::MhloDialect>();
|
|
||||||
target.addLegalDialect<chlo::ChloDialect>();
|
target.addLegalDialect<chlo::ChloDialect>();
|
||||||
|
target.addLegalDialect<stablehlo::StablehloDialect>();
|
||||||
target.addLegalDialect<tensor::TensorDialect>();
|
target.addLegalDialect<tensor::TensorDialect>();
|
||||||
target.addLegalDialect<arith::ArithDialect>();
|
target.addLegalDialect<arith::ArithDialect>();
|
||||||
|
|
||||||
RewritePatternSet patterns(context);
|
|
||||||
if (failed(applyFullConversion(module, target, std::move(patterns)))) {
|
|
||||||
// We avoid `module.emitError()` so that mlir-print-op-on-diagnostics
|
|
||||||
// doesn't unnecessarily spew out the entire module.
|
|
||||||
emitError(module.getLoc())
|
|
||||||
<< "Module does not conform to the MHLO backend contract. "
|
|
||||||
"See dialect conversion legality information above.";
|
|
||||||
return signalPassFailure();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
mlir::torch::TorchConversion::createVerifyMhloBackendContractPass() {
|
mlir::torch::TorchConversion::createVerifyStablehloBackendContractPass() {
|
||||||
return std::make_unique<VerifyMhloBackendContractPass>();
|
return std::make_unique<VerifyStablehloBackendContractPass>();
|
||||||
}
|
}
|
||||||
#endif // TORCH_MLIR_ENABLE_MHLO
|
#endif // TORCH_MLIR_ENABLE_STABLEHLO
|
|
@ -20,6 +20,10 @@
|
||||||
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
|
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
|
||||||
#include "torch-mlir/RefBackend/Passes.h"
|
#include "torch-mlir/RefBackend/Passes.h"
|
||||||
|
|
||||||
|
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||||
|
#include "mhlo/transforms/passes.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
void mlir::torch::registerAllDialects(mlir::DialectRegistry ®istry) {
|
void mlir::torch::registerAllDialects(mlir::DialectRegistry ®istry) {
|
||||||
registry.insert<mlir::func::FuncDialect>();
|
registry.insert<mlir::func::FuncDialect>();
|
||||||
registry.insert<mlir::torch::Torch::TorchDialect>();
|
registry.insert<mlir::torch::Torch::TorchDialect>();
|
||||||
|
@ -34,4 +38,11 @@ void mlir::torch::registerAllPasses() {
|
||||||
mlir::torch::registerConversionPasses();
|
mlir::torch::registerConversionPasses();
|
||||||
mlir::torch::RefBackend::registerRefBackendPasses();
|
mlir::torch::RefBackend::registerRefBackendPasses();
|
||||||
mlir::torch::TMTensor::registerPasses();
|
mlir::torch::TMTensor::registerPasses();
|
||||||
|
|
||||||
|
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||||
|
mlir::mhlo::registerSymbolicShapeOptimizationPass();
|
||||||
|
mlir::mhlo::registerStablehloLegalizeToHloPass();
|
||||||
|
mlir::mhlo::registerChloLegalizeToHloPass();
|
||||||
|
mlir::mhlo::registerHloLegalizeToLinalgPass();
|
||||||
|
#endif // TORCH_MLIR_ENABLE_STABLEHLO
|
||||||
}
|
}
|
||||||
|
|
|
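These registrations make the StableHLO-to-Linalg lowering path addressable by pass name; a minimal sketch of exercising it from Python (an assumption-laden sketch: the toy module is illustrative, and the pipeline string mirrors the one used by the RefBackend-based StableHLO backend later in this patch):

    import torch
    import torch_mlir
    from torch_mlir.compiler_utils import run_pipeline_with_repro_report

    class Scale(torch.nn.Module):
        def forward(self, x):
            return x * 2.0

    # Produce a StableHLO-contract module, then legalize CHLO and StableHLO to
    # MHLO and lower to Linalg-on-Tensors. Every pass named in the pipeline is
    # registered above when TORCH_MLIR_ENABLE_STABLEHLO is on.
    module = torch_mlir.compile(Scale(), torch.ones(3, 4), output_type="stablehlo")
    run_pipeline_with_repro_report(
        module,
        "builtin.module(func.func(chlo-legalize-to-hlo),stablehlo-legalize-to-hlo,"
        "func.func(canonicalize,cse,symbolic-shape-optimization,hlo-legalize-to-linalg,canonicalize))",
        "Lowering StableHLO to Linalg-on-Tensors")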
@ -44,9 +44,9 @@ class OutputType(Enum):
|
||||||
# as taking the `TORCH` output type and lowering it to TOSA.
|
# as taking the `TORCH` output type and lowering it to TOSA.
|
||||||
TOSA = "tosa"
|
TOSA = "tosa"
|
||||||
|
|
||||||
# This output type consists of `mhlo` dialect ops. It can be thought of
|
# This output type consists of `stablehlo` dialect ops. It can be thought of
|
||||||
# as taking the `TORCH` output type and lowering it to MHLO.
|
# as taking the `TORCH` output type and lowering it to StableHLO.
|
||||||
MHLO = "mhlo"
|
STABLEHLO = "stablehlo"
|
||||||
|
|
||||||
# Raw output of the JIT IR importer. This is not expected to be useful
|
# Raw output of the JIT IR importer. This is not expected to be useful
|
||||||
# for end-users, but can be convenient for development or reporting bugs.
|
# for end-users, but can be convenient for development or reporting bugs.
|
||||||
|
@ -242,7 +242,7 @@ class ExampleArgs:
|
||||||
BACKEND_LEGAL_OPS = {
|
BACKEND_LEGAL_OPS = {
|
||||||
OutputType.TOSA: ['torch.aten.flatten.using_ints', 'torch.aten.native_layer_norm', 'torch.aten.linear'],
|
OutputType.TOSA: ['torch.aten.flatten.using_ints', 'torch.aten.native_layer_norm', 'torch.aten.linear'],
|
||||||
OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints', ],
|
OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints', ],
|
||||||
OutputType.MHLO: [],
|
OutputType.STABLEHLO: [],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -290,7 +290,7 @@ def compile(model: torch.nn.Module,
|
||||||
|
|
||||||
# We only allow `backend_legal_ops` to be specified for the `"torch"`
|
# We only allow `backend_legal_ops` to be specified for the `"torch"`
|
||||||
# output type because the other output types actually invoke their
|
# output type because the other output types actually invoke their
|
||||||
# respective backends (Linalg, TOSA, or MHLO), and those backends have
|
# respective backends (Linalg, TOSA, or STABLEHLO), and those backends have
|
||||||
# very specific requirements about the ops which are legal.
|
# very specific requirements about the ops which are legal.
|
||||||
# See `BACKEND_LEGAL_OPS` for more details.
|
# See `BACKEND_LEGAL_OPS` for more details.
|
||||||
if backend_legal_ops is not None:
|
if backend_legal_ops is not None:
|
||||||
|
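A hedged usage sketch of the constraint described above (the toy module and example input are illustrative only; the listed op is taken from the `BACKEND_LEGAL_OPS` table earlier in this patch):

    import torch
    import torch_mlir

    class Flatten(torch.nn.Module):
        def forward(self, x):
            return torch.flatten(x, 1)

    # backend_legal_ops is only accepted together with the "torch" output type;
    # the listed op is kept as-is for a downstream backend to handle.
    module = torch_mlir.compile(
        Flatten(),
        torch.ones(2, 3, 4),
        output_type="torch",
        backend_legal_ops=["torch.aten.flatten.using_ints"])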
@ -404,14 +404,14 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
|
||||||
print(mb.module)
|
print(mb.module)
|
||||||
return mb.module
|
return mb.module
|
||||||
|
|
||||||
elif output_type == OutputType.MHLO:
|
elif output_type == OutputType.STABLEHLO:
|
||||||
run_pipeline_with_repro_report(
|
run_pipeline_with_repro_report(
|
||||||
mb.module,
|
mb.module,
|
||||||
"builtin.module(torch-backend-to-mhlo-backend-pipeline)",
|
"builtin.module(torch-backend-to-stablehlo-backend-pipeline)",
|
||||||
"Lowering Torch Backend IR -> MHLO Backend IR")
|
"Lowering Torch Backend IR -> StableHLO Backend IR")
|
||||||
if verbose:
|
if verbose:
|
||||||
print("\n====================")
|
print("\n====================")
|
||||||
print("MHLO Backend IR")
|
print("StableHLO Backend IR")
|
||||||
print(mb.module)
|
print(mb.module)
|
||||||
return mb.module
|
return mb.module
|
||||||
raise Exception(f"Unknown OutputType: {output_type}")
|
raise Exception(f"Unknown OutputType: {output_type}")
|
||||||
|
|
|
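Putting the new output type together, a minimal end-to-end sketch (the toy module is illustrative only):

    import torch
    import torch_mlir

    class AddOne(torch.nn.Module):
        def forward(self, x):
            return x + 1

    # Lowers through the torch-backend-to-stablehlo-backend-pipeline registered
    # above and returns a module holding stablehlo (and possibly chlo) ops.
    module = torch_mlir.compile(AddOne(), torch.ones(3, 4), output_type="stablehlo")
    print(module)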
@ -7,6 +7,6 @@ from .lazy_tensor_core import LazyTensorCoreTestConfig
|
||||||
from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig
|
from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig
|
||||||
from .native_torch import NativeTorchTestConfig
|
from .native_torch import NativeTorchTestConfig
|
||||||
from .torchscript import TorchScriptTestConfig
|
from .torchscript import TorchScriptTestConfig
|
||||||
from .mhlo_backend import MhloBackendTestConfig
|
from .stablehlo_backend import StablehloBackendTestConfig
|
||||||
from .tosa_backend import TosaBackendTestConfig
|
from .tosa_backend import TosaBackendTestConfig
|
||||||
from .torchdynamo import TorchDynamoTestConfig
|
from .torchdynamo import TorchDynamoTestConfig
|
||||||
|
|
|
@ -8,12 +8,8 @@ from typing import Any
|
||||||
import torch
|
import torch
|
||||||
import torch_mlir
|
import torch_mlir
|
||||||
|
|
||||||
from torch_mlir_e2e_test.mhlo_backends.abc import MhloBackend
|
from torch_mlir_e2e_test.stablehlo_backends.abc import StablehloBackend
|
||||||
from torch_mlir_e2e_test.framework import (
|
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
|
||||||
TestConfig,
|
|
||||||
Trace,
|
|
||||||
TraceItem
|
|
||||||
)
|
|
||||||
from torch_mlir_e2e_test.utils import convert_annotations_to_placeholders
|
from torch_mlir_e2e_test.utils import convert_annotations_to_placeholders
|
||||||
from .utils import (
|
from .utils import (
|
||||||
recursively_convert_to_numpy,
|
recursively_convert_to_numpy,
|
||||||
|
@ -21,20 +17,20 @@ from .utils import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MhloBackendTestConfig(TestConfig):
|
class StablehloBackendTestConfig(TestConfig):
|
||||||
"""Base class for TestConfigs that are implemented with linalg-on-tensors.
|
"""Base class for TestConfigs that are implemented with linalg-on-tensors.
|
||||||
|
|
||||||
This class handles all the common lowering that torch-mlir does before
|
This class handles all the common lowering that torch-mlir does before
|
||||||
reaching the linalg-on-tensors abstraction level.
|
reaching the linalg-on-tensors abstraction level.
|
||||||
"""
|
"""
|
||||||
def __init__(self, backend: MhloBackend):
|
|
||||||
|
def __init__(self, backend: StablehloBackend):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.backend = backend
|
self.backend = backend
|
||||||
|
|
||||||
def compile(self, program: torch.nn.Module) -> Any:
|
def compile(self, program: torch.nn.Module) -> Any:
|
||||||
example_args = convert_annotations_to_placeholders(program.forward)
|
example_args = convert_annotations_to_placeholders(program.forward)
|
||||||
module = torch_mlir.compile(
|
module = torch_mlir.compile(program, example_args, output_type="stablehlo")
|
||||||
program, example_args, output_type="mhlo")
|
|
||||||
|
|
||||||
return self.backend.compile(module)
|
return self.backend.compile(module)
|
||||||
|
|
||||||
|
@ -46,7 +42,6 @@ class MhloBackendTestConfig(TestConfig):
|
||||||
outputs = getattr(backend_module, item.symbol)(*numpy_inputs)
|
outputs = getattr(backend_module, item.symbol)(*numpy_inputs)
|
||||||
output = recursively_convert_from_numpy(outputs)
|
output = recursively_convert_from_numpy(outputs)
|
||||||
result.append(
|
result.append(
|
||||||
TraceItem(symbol=item.symbol,
|
TraceItem(symbol=item.symbol, inputs=item.inputs, output=output)
|
||||||
inputs=item.inputs,
|
)
|
||||||
output=output))
|
|
||||||
return result
|
return result
|
|
@ -10,29 +10,30 @@ import torch
|
||||||
|
|
||||||
from torch_mlir.ir import Module
|
from torch_mlir.ir import Module
|
||||||
|
|
||||||
# A type shared between the result of `MhloBackend.compile` and the
|
# A type shared between the result of `StablehloBackend.compile` and the
|
||||||
# input to `MhloBackend.load`. Each backend will likely have a
|
# input to `StablehloBackend.load`. Each backend will likely have a
|
||||||
# different definition of this type.
|
# different definition of this type.
|
||||||
CompiledArtifact = TypeVar('CompiledArtifact')
|
CompiledArtifact = TypeVar("CompiledArtifact")
|
||||||
|
|
||||||
# A wrapper around a backend-specific loaded program representation
|
# A wrapper around a backend-specific loaded program representation
|
||||||
# that uniformly translates the `x.method(...)` interface expected of
|
# that uniformly translates the `x.method(...)` interface expected of
|
||||||
# Torch modules into appropriate lower-level operations.
|
# Torch modules into appropriate lower-level operations.
|
||||||
Invoker = TypeVar('Invoker')
|
Invoker = TypeVar("Invoker")
|
||||||
|
|
||||||
|
|
||||||
class MhloBackend(abc.ABC):
|
class StablehloBackend(abc.ABC):
|
||||||
"""The interface to an MHLO backend.
|
"""The interface to a StableHLO backend.
|
||||||
|
|
||||||
Backends are recommended to raise meaningful exceptions in case of error,
|
Backends are recommended to raise meaningful exceptions in case of error,
|
||||||
ideally with easy reproduction instructions.
|
ideally with easy reproduction instructions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def compile(self, module: Module) -> CompiledArtifact:
|
def compile(self, module: Module) -> CompiledArtifact:
|
||||||
"""Compile the provided MLIR module into a compiled artifact.
|
"""Compile the provided MLIR module into a compiled artifact.
|
||||||
|
|
||||||
The module adheres to the MHLO backend contract
|
The module adheres to the StableHLO backend contract
|
||||||
(see the VerifyMhloBackendContract pass).
|
(see the VerifyStablehloBackendContract pass).
|
||||||
|
|
||||||
The compiled artifact can be any type, but must be correctly
|
The compiled artifact can be any type, but must be correctly
|
||||||
interpreted by the `load` method.
|
interpreted by the `load` method.
|
|
@ -7,28 +7,32 @@ from torch_mlir.ir import *
|
||||||
from torch_mlir.passmanager import *
|
from torch_mlir.passmanager import *
|
||||||
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
|
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
|
||||||
|
|
||||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
|
||||||
|
RefBackendLinalgOnTensorsBackend,
|
||||||
|
)
|
||||||
|
|
||||||
from .abc import MhloBackend
|
from .abc import StablehloBackend
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"LinalgOnTensorsMhloBackend",
|
"LinalgOnTensorsStablehloBackend",
|
||||||
]
|
]
|
||||||
|
|
||||||
class LinalgOnTensorsMhloBackend(MhloBackend):
|
|
||||||
"""Main entry-point for the linalg-on-tensors based MHLO backend.
|
class LinalgOnTensorsStablehloBackend(StablehloBackend):
|
||||||
|
"""Main entry-point for the linalg-on-tensors based StableHLO backend.
|
||||||
|
|
||||||
This currently uses the linalg-on-tensors RefBackend for actual execution.
|
This currently uses the linalg-on-tensors RefBackend for actual execution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.refbackend = RefBackendLinalgOnTensorsBackend()
|
self.refbackend = RefBackendLinalgOnTensorsBackend()
|
||||||
|
|
||||||
def compile(self, imported_module: Module):
|
def compile(self, imported_module: Module):
|
||||||
"""Compiles an imported module that satisfied the MHLO backend contract.
|
"""Compiles an imported module that satisfies the StableHLO backend contract.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
imported_module: The MLIR module consisting of funcs in the MHLO
|
imported_module: The MLIR module consisting of funcs in the StableHLO
|
||||||
dialect.
|
dialect.
|
||||||
Returns:
|
Returns:
|
||||||
An opaque, backend specific compiled artifact object that can be
|
An opaque, backend specific compiled artifact object that can be
|
||||||
|
@ -36,8 +40,9 @@ class LinalgOnTensorsMhloBackend(MhloBackend):
|
||||||
"""
|
"""
|
||||||
run_pipeline_with_repro_report(
|
run_pipeline_with_repro_report(
|
||||||
imported_module,
|
imported_module,
|
||||||
"builtin.module(func.func(symbolic-shape-optimization),func.func(hlo-legalize-to-linalg),func.func(canonicalize))",
|
"builtin.module(func.func(chlo-legalize-to-hlo),stablehlo-legalize-to-hlo,func.func(canonicalize,cse,symbolic-shape-optimization,hlo-legalize-to-linalg,canonicalize))",
|
||||||
"Lowering MLIR-HLO to Linalg-on-Tensors")
|
"Lowering StableHLO to Linalg-on-Tensors",
|
||||||
|
)
|
||||||
return self.refbackend.compile(imported_module)
|
return self.refbackend.compile(imported_module)
|
||||||
|
|
||||||
def load(self, module):
|
def load(self, module):
|
|
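A short usage sketch combining this backend with the test config added earlier (an assumption-laden sketch: the backend's module path and the annotated `program` are placeholders not shown in this patch):

    from torch_mlir_e2e_test.configs import StablehloBackendTestConfig
    # Import path assumed; this is the file defining LinalgOnTensorsStablehloBackend.
    from torch_mlir_e2e_test.stablehlo_backends.linalg_on_tensors import (
        LinalgOnTensorsStablehloBackend,
    )

    config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend())
    # `program` is a torch.nn.Module annotated for the e2e suite; compile()
    # lowers it to StableHLO and then through Linalg to the RefBackend.
    artifact = config.compile(program)
    loaded = config.backend.load(artifact)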
@ -1,7 +1,7 @@
|
||||||
llvm_canonicalize_cmake_booleans(
|
llvm_canonicalize_cmake_booleans(
|
||||||
MLIR_ENABLE_BINDINGS_PYTHON
|
MLIR_ENABLE_BINDINGS_PYTHON
|
||||||
TORCH_MLIR_ENABLE_JIT_IR_IMPORTER
|
TORCH_MLIR_ENABLE_JIT_IR_IMPORTER
|
||||||
TORCH_MLIR_ENABLE_MHLO
|
TORCH_MLIR_ENABLE_STABLEHLO
|
||||||
)
|
)
|
||||||
|
|
||||||
configure_lit_site_cfg(
|
configure_lit_site_cfg(
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s
|
// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s
|
||||||
|
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
@ -7,7 +7,7 @@
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||||
// CHECK: %[[T1:.*]] = mhlo.copy %[[T0]] : tensor<?x?xf32>
|
// CHECK: %[[T1:.*]] = stablehlo.convert %[[T0]] : tensor<?x?xf32>
|
||||||
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||||
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
|
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
|
||||||
func.func @torch.aten.clone$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
func.func @torch.aten.clone$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
|
@ -19,7 +19,7 @@ func.func @torch.aten.clone$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vt
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> {
|
// CHECK-LABEL: func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> {
|
||||||
// CHECK: %[[VAL_0:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
// CHECK: %[[VAL_0:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
|
||||||
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<f32> -> !torch.vtensor<[],f32>
|
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<f32> -> !torch.vtensor<[],f32>
|
||||||
// CHECK: return %[[VAL_1]] : !torch.vtensor<[],f32>
|
// CHECK: return %[[VAL_1]] : !torch.vtensor<[],f32>
|
||||||
func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> {
|
func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> {
|
||||||
|
@ -30,7 +30,7 @@ func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> {
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> {
|
// CHECK-LABEL: func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> {
|
||||||
// CHECK: %[[VAL_0:.*]] = mhlo.constant dense<1> : tensor<2xi64>
|
// CHECK: %[[VAL_0:.*]] = stablehlo.constant dense<1> : tensor<2xi64>
|
||||||
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<2xi64> -> !torch.vtensor<[2],si64>
|
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<2xi64> -> !torch.vtensor<[2],si64>
|
||||||
// CHECK: return %[[VAL_1]] : !torch.vtensor<[2],si64>
|
// CHECK: return %[[VAL_1]] : !torch.vtensor<[2],si64>
|
||||||
func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> {
|
func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> {
|
||||||
|
@ -45,8 +45,8 @@ func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> {
|
||||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||||
// CHECK: %[[T0:.*]] = torch_c.to_i64 %[[INT1]]
|
// CHECK: %[[T0:.*]] = torch_c.to_i64 %[[INT1]]
|
||||||
// CHECK: %[[T1:.*]] = tensor.from_elements %[[T0]] : tensor<1xi64>
|
// CHECK: %[[T1:.*]] = tensor.from_elements %[[T0]] : tensor<1xi64>
|
||||||
// CHECK: %[[T2:.*]] = mhlo.convert %[[T1]] : tensor<1xi64>
|
// CHECK: %[[T2:.*]] = stablehlo.convert %[[T1]] : tensor<1xi64>
|
||||||
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xi64>) -> tensor<i64>
|
// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xi64>) -> tensor<i64>
|
||||||
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<i64> -> !torch.vtensor<[],si64>
|
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<i64> -> !torch.vtensor<[],si64>
|
||||||
// CHECK: return %[[T4]] : !torch.vtensor<[],si64>
|
// CHECK: return %[[T4]] : !torch.vtensor<[],si64>
|
||||||
func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[], si64> {
|
func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[], si64> {
|
||||||
|
@ -75,7 +75,7 @@ func.func @torch.aten.contiguous(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vt
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
|
||||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
|
||||||
// CHECK: %[[VAL_2:.*]] = "chlo.constant_like"(%[[VAL_1]]) {value = 1.000000e+00 : f32} : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
// CHECK: %[[VAL_2:.*]] = "chlo.constant_like"(%[[VAL_1]]) {value = 1.000000e+00 : f32} : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||||
// CHECK: %[[VAL_3:.*]] = mhlo.divide %[[VAL_2]], %[[VAL_1]] : tensor<?x?x?xf32>
|
// CHECK: %[[VAL_3:.*]] = stablehlo.divide %[[VAL_2]], %[[VAL_1]] : tensor<?x?x?xf32>
|
||||||
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
|
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
|
||||||
// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?,?],f32>
|
// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?,?],f32>
|
||||||
func.func @torch.aten.reciprocal(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
|
func.func @torch.aten.reciprocal(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
|
||||||
|
@ -91,7 +91,7 @@ func.func @torch.aten.reciprocal(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.v
|
||||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],f32> -> tensor<4x3xf32>
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],f32> -> tensor<4x3xf32>
|
||||||
// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
|
// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
|
||||||
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
|
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
|
||||||
// CHECK: %[[VAL_4:.*]] = "mhlo.transpose"(%[[VAL_1]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x3xf32>) -> tensor<3x4xf32>
|
// CHECK: %[[VAL_4:.*]] = stablehlo.transpose %[[VAL_1]], dims = [1, 0] : (tensor<4x3xf32>) -> tensor<3x4xf32>
|
||||||
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
|
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
|
||||||
// CHECK: return %[[VAL_5]] : !torch.vtensor<[3,4],f32>
|
// CHECK: return %[[VAL_5]] : !torch.vtensor<[3,4],f32>
|
||||||
func.func @torch.aten.transpose$basic(%arg0: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,4],f32> {
|
func.func @torch.aten.transpose$basic(%arg0: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,4],f32> {
|
||||||
|
@ -118,7 +118,7 @@ func.func @torch.aten.transpose$basic(%arg0: !torch.vtensor<[4,3],f32>) -> !torc
|
||||||
// CHECK: %[[VAL_7:.*]] = arith.constant 1 : index
|
// CHECK: %[[VAL_7:.*]] = arith.constant 1 : index
|
||||||
// CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1:.*]], %[[VAL_7]] : tensor<?x?xf32>
|
// CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1:.*]], %[[VAL_7]] : tensor<?x?xf32>
|
||||||
// CHECK: %[[VAL_9:.*]] = tensor.from_elements %[[VAL_4]], %[[VAL_6]], %[[VAL_8]] : tensor<3xindex>
|
// CHECK: %[[VAL_9:.*]] = tensor.from_elements %[[VAL_4]], %[[VAL_6]], %[[VAL_8]] : tensor<3xindex>
|
||||||
// CHECK: %[[VAL_10:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_1]], %[[VAL_9]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<3xindex>) -> tensor<8x4x?xf32>
|
// CHECK: %[[VAL_10:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_1]], %[[VAL_9]], dims = [1, 2] : (tensor<?x?xf32>, tensor<3xindex>) -> tensor<8x4x?xf32>
|
||||||
// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<8x4x?xf32> -> !torch.vtensor<[8,4,?],f32>
|
// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<8x4x?xf32> -> !torch.vtensor<[8,4,?],f32>
|
||||||
// CHECK: return %[[VAL_11]] : !torch.vtensor<[8,4,?],f32>
|
// CHECK: return %[[VAL_11]] : !torch.vtensor<[8,4,?],f32>
|
||||||
func.func @torch.aten.broadcast_to$dynamic_implicit(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[8,4,?],f32> {
|
func.func @torch.aten.broadcast_to$dynamic_implicit(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[8,4,?],f32> {
|
||||||
|
@ -135,15 +135,15 @@ func.func @torch.aten.broadcast_to$dynamic_implicit(%arg0: !torch.vtensor<[?,?],
|
||||||
// CHECK-LABEL: func.func @torch.aten.batch_norm$training(
|
// CHECK-LABEL: func.func @torch.aten.batch_norm$training(
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
|
||||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor<?x3x?x?xf32>
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor<?x3x?x?xf32>
|
||||||
// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32>
|
// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<3xf32>
|
||||||
// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32>
|
// CHECK: %[[VAL_3:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<3xf32>
|
||||||
// CHECK: %true = torch.constant.bool true
|
// CHECK: %true = torch.constant.bool true
|
||||||
// CHECK: %float1.000000e-01 = torch.constant.float 1.000000e-01
|
// CHECK: %float1.000000e-01 = torch.constant.float 1.000000e-01
|
||||||
// CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05
|
// CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05
|
||||||
// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
|
// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
|
||||||
// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x3x?x?xf32>
|
// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x3x?x?xf32>
|
||||||
// CHECK: %[[VAL_6:.*]] = tensor.from_elements %[[VAL_5]] : tensor<1xindex>
|
// CHECK: %[[VAL_6:.*]] = tensor.from_elements %[[VAL_5]] : tensor<1xindex>
|
||||||
// CHECK: %[[VAL_7:.*]], %[[VAL_8:.*]], %[[VAL_9:.*]] = "mhlo.batch_norm_training"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>)
|
// CHECK: %[[VAL_7:.*]], %[[VAL_8:.*]], %[[VAL_9:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>)
|
||||||
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
|
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
|
||||||
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,3,?,?],f32>
|
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,3,?,?],f32>
|
||||||
func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
|
func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
|
||||||
|
@ -161,8 +161,8 @@ func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>)
|
||||||
// CHECK-LABEL: func.func @torch.aten.batch_norm$inference(
|
// CHECK-LABEL: func.func @torch.aten.batch_norm$inference(
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
|
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
|
||||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor<?x3x?x?xf32>
|
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor<?x3x?x?xf32>
|
||||||
// CHECK: %[[T1:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32>
|
// CHECK: %[[T1:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<3xf32>
|
||||||
// CHECK: %[[T2:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32>
|
// CHECK: %[[T2:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<3xf32>
|
||||||
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||||
// CHECK: %[[FLOAT1:.*]].000000e-01 = torch.constant.float 1.000000e-01
|
// CHECK: %[[FLOAT1:.*]].000000e-01 = torch.constant.float 1.000000e-01
|
||||||
|
@ -171,7 +171,7 @@ func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>)
|
||||||
// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x3x?x?xf32>
|
// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x3x?x?xf32>
|
||||||
// CHECK: %[[T4:.*]] = tensor.from_elements %[[T3]] : tensor<1xindex>
|
// CHECK: %[[T4:.*]] = tensor.from_elements %[[T3]] : tensor<1xindex>
|
||||||
// CHECK: %[[T5:.*]] = tensor.cast %[[T0]] : tensor<?x3x?x?xf32> to tensor<?x3x?x?xf32>
|
// CHECK: %[[T5:.*]] = tensor.cast %[[T0]] : tensor<?x3x?x?xf32> to tensor<?x3x?x?xf32>
|
||||||
// CHECK: %[[T6:.*]] = "mhlo.batch_norm_inference"(%[[T5]], %[[T2]], %[[T1]], %[[T1]], %[[T2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<?x3x?x?xf32>
|
// CHECK: %[[T6:.*]] = "stablehlo.batch_norm_inference"(%[[T5]], %[[T2]], %[[T1]], %[[T1]], %[[T2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<?x3x?x?xf32>
|
||||||
// CHECK: %[[T7:.*]] = tensor.cast %[[T6]] : tensor<?x3x?x?xf32> to tensor<?x3x?x?xf32>
|
// CHECK: %[[T7:.*]] = tensor.cast %[[T6]] : tensor<?x3x?x?xf32> to tensor<?x3x?x?xf32>
|
||||||
// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
|
// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
|
||||||
// CHECK: return %[[T8]] : !torch.vtensor<[?,3,?,?],f32>
|
// CHECK: return %[[T8]] : !torch.vtensor<[?,3,?,?],f32>
|
||||||
|
@ -192,19 +192,19 @@ func.func @torch.aten.batch_norm$inference(%arg0: !torch.vtensor<[?,3,?,?],f32>)
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
|
||||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor<?x3x?x?xf32>
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor<?x3x?x?xf32>
|
||||||
// CHECK: %none = torch.constant.none
|
// CHECK: %none = torch.constant.none
|
||||||
// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32>
|
// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<3xf32>
|
||||||
// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32>
|
// CHECK: %[[VAL_3:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<3xf32>
|
||||||
// CHECK: %true = torch.constant.bool true
|
// CHECK: %true = torch.constant.bool true
|
||||||
// CHECK: %float1.000000e-01 = torch.constant.float 1.000000e-01
|
// CHECK: %float1.000000e-01 = torch.constant.float 1.000000e-01
|
||||||
// CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05
|
// CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05
// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x3x?x?xf32>
// CHECK: %[[VAL_6:.*]] = tensor.from_elements %[[VAL_5]] : tensor<1xindex>
-// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
+// CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f32>
-// CHECK: %[[VAL_8:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_7]], %[[VAL_6]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<3xf32>
+// CHECK: %[[VAL_8:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_6]], dims = [] : (tensor<f32>, tensor<1xindex>) -> tensor<3xf32>
-// CHECK: %[[VAL_9:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK: %[[VAL_9:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
-// CHECK: %[[VAL_10:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_9]], %[[VAL_6]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<3xf32>
+// CHECK: %[[VAL_10:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_9]], %[[VAL_6]], dims = [] : (tensor<f32>, tensor<1xindex>) -> tensor<3xf32>
-// CHECK: %[[VAL_11:.*]], %[[VAL_12:.*]], %[[VAL_13:.*]] = "mhlo.batch_norm_training"(%[[VAL_1]], %[[VAL_8]], %[[VAL_10]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>)
+// CHECK: %[[VAL_11:.*]], %[[VAL_12:.*]], %[[VAL_13:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_8]], %[[VAL_10]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>)
// CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
// CHECK: return %[[VAL_14]] : !torch.vtensor<[?,3,?,?],f32>
func.func @torch.aten.batch_norm$no_bias_weight(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
@@ -222,28 +222,28 @@ func.func @torch.aten.batch_norm$no_bias_weight(%arg0: !torch.vtensor<[?,3,?,?],
// CHECK-LABEL: func @torch.aten.native_layer_norm(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,7,4,5],f32>) -> !torch.vtensor<[3,7,4,5],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,7,4,5],f32> -> tensor<3x7x4x5xf32>
-// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<4x5xf32>
+// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<4x5xf32>
-// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<4x5xf32>
+// CHECK: %[[VAL_3:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<4x5xf32>
// CHECK: %int4 = torch.constant.int 4
// CHECK: %int5 = torch.constant.int 5
// CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05
// CHECK: %true = torch.constant.bool true
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<[1, 21, 20]> : tensor<3xi64>
+// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<[1, 21, 20]> : tensor<3xi64>
-// CHECK: %[[VAL_6:.*]] = mhlo.dynamic_reshape %[[VAL_1]], %[[VAL_5]] : (tensor<3x7x4x5xf32>, tensor<3xi64>) -> tensor<1x21x20xf32>
+// CHECK: %[[VAL_6:.*]] = stablehlo.dynamic_reshape %[[VAL_1]], %[[VAL_5]] : (tensor<3x7x4x5xf32>, tensor<3xi64>) -> tensor<1x21x20xf32>
-// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor<21xf32>
+// CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<21xf32>
-// CHECK: %[[VAL_8:.*]] = mhlo.constant dense<0.000000e+00> : tensor<21xf32>
+// CHECK: %[[VAL_8:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<21xf32>
-// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]] = "mhlo.batch_norm_training"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) -> (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>)
+// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]] = "stablehlo.batch_norm_training"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) -> (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>)
-// CHECK: %[[VAL_12:.*]] = mhlo.constant dense<[3, 7, 4, 5]> : tensor<4xi64>
+// CHECK: %[[VAL_12:.*]] = stablehlo.constant dense<[3, 7, 4, 5]> : tensor<4xi64>
-// CHECK: %[[VAL_13:.*]] = mhlo.dynamic_reshape %[[VAL_9]], %[[VAL_12]] : (tensor<1x21x20xf32>, tensor<4xi64>) -> tensor<3x7x4x5xf32>
+// CHECK: %[[VAL_13:.*]] = stablehlo.dynamic_reshape %[[VAL_9]], %[[VAL_12]] : (tensor<1x21x20xf32>, tensor<4xi64>) -> tensor<3x7x4x5xf32>
-// CHECK: %[[VAL_14:.*]] = mhlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64>
+// CHECK: %[[VAL_14:.*]] = stablehlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64>
-// CHECK: %[[VAL_15:.*]] = mhlo.dynamic_reshape %[[VAL_10]], %[[VAL_14]] : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32>
+// CHECK: %[[VAL_15:.*]] = stablehlo.dynamic_reshape %[[VAL_10]], %[[VAL_14]] : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32>
-// CHECK: %[[VAL_16:.*]] = mhlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64>
+// CHECK: %[[VAL_16:.*]] = stablehlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64>
-// CHECK: %[[VAL_17:.*]] = mhlo.dynamic_reshape %[[VAL_11]], %[[VAL_16]] : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32>
+// CHECK: %[[VAL_17:.*]] = stablehlo.dynamic_reshape %[[VAL_11]], %[[VAL_16]] : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32>
-// CHECK: %[[VAL_18:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_3]]) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32>
+// CHECK: %[[VAL_18:.*]] = stablehlo.broadcast_in_dim %[[VAL_3]], dims = [2, 3] : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32>
-// CHECK: %[[VAL_19:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_2]]) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32>
+// CHECK: %[[VAL_19:.*]] = stablehlo.broadcast_in_dim %[[VAL_2]], dims = [2, 3] : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32>
-// CHECK: %[[VAL_20:.*]] = mhlo.multiply %[[VAL_13]], %[[VAL_18]] : tensor<3x7x4x5xf32>
+// CHECK: %[[VAL_20:.*]] = stablehlo.multiply %[[VAL_13]], %[[VAL_18]] : tensor<3x7x4x5xf32>
-// CHECK: %[[VAL_21:.*]] = mhlo.add %[[VAL_20]], %[[VAL_19]] : tensor<3x7x4x5xf32>
+// CHECK: %[[VAL_21:.*]] = stablehlo.add %[[VAL_20]], %[[VAL_19]] : tensor<3x7x4x5xf32>
// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21:.*]] : tensor<3x7x4x5xf32> -> !torch.vtensor<[3,7,4,5],f32>
// CHECK: return %[[VAL_22]] : !torch.vtensor<[3,7,4,5],f32>
func.func @torch.aten.native_layer_norm(%arg0: !torch.vtensor<[3,7,4,5],f32>) -> !torch.vtensor<[3,7,4,5],f32> {
@@ -267,8 +267,8 @@ func.func @torch.aten.native_layer_norm(%arg0: !torch.vtensor<[3,7,4,5],f32>) ->
// CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[ARG0]], %[[ARG1]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],si32>) -> !torch.list<vtensor>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T2:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
-// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : (tensor<?x?xi32>) -> tensor<?x?xf32>
+// CHECK: %[[T3:.*]] = stablehlo.convert %[[T2]] : (tensor<?x?xi32>) -> tensor<?x?xf32>
-// CHECK: %[[T4:.*]] = "mhlo.concatenate"(%[[T1]], %[[T3]]) {dimension = 0 : i64} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[T4:.*]] = stablehlo.concatenate %[[T1]], %[[T3]], dim = 0 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.cat$convert(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],f32> {
@@ -287,7 +287,7 @@ func.func @torch.aten.cat$convert(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torc
// CHECK: %[[VAL_0:.*]] = torch.prim.ListConstruct %[[ARG_0]], %[[ARG_1]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list<vtensor>
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[VAL_3:.*]] = "mhlo.concatenate"(%[[VAL_1]], %[[VAL_2]]) {dimension = 0 : i64} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_3:.*]] = stablehlo.concatenate %[[VAL_1]], %[[VAL_2]], dim = 0 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.cat(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {

@@ -1,4 +1,4 @@
-// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s

// CHECK-LABEL: func.func @torch.aten.gelu(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@@ -7,12 +7,12 @@
// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) {value = 1.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T2:.*]] = "chlo.constant_like"(%[[T0]]) {value = 2.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T3:.*]] = "chlo.constant_like"(%[[T0]]) {value = 5.000000e-01 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK: %[[T4:.*]] = mhlo.rsqrt %[[T2]] : tensor<?x?xf32>
+// CHECK: %[[T4:.*]] = stablehlo.rsqrt %[[T2]] : tensor<?x?xf32>
-// CHECK: %[[T5:.*]] = mhlo.multiply %[[T0]], %[[T4]] : tensor<?x?xf32>
+// CHECK: %[[T5:.*]] = stablehlo.multiply %[[T0]], %[[T4]] : tensor<?x?xf32>
// CHECK: %[[T6:.*]] = chlo.erf %[[T5]] : tensor<?x?xf32> -> tensor<?x?xf32>
-// CHECK: %[[T7:.*]] = mhlo.add %[[T6]], %[[T1]] : tensor<?x?xf32>
+// CHECK: %[[T7:.*]] = stablehlo.add %[[T6]], %[[T1]] : tensor<?x?xf32>
-// CHECK: %[[T8:.*]] = mhlo.multiply %[[T7]], %[[T3]] : tensor<?x?xf32>
+// CHECK: %[[T8:.*]] = stablehlo.multiply %[[T7]], %[[T3]] : tensor<?x?xf32>
-// CHECK: %[[T9:.*]] = mhlo.multiply %[[T0]], %[[T8]] : tensor<?x?xf32>
+// CHECK: %[[T9:.*]] = stablehlo.multiply %[[T0]], %[[T8]] : tensor<?x?xf32>
// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T10]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.gelu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@@ -26,7 +26,7 @@ func.func @torch.aten.gelu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
// CHECK-LABEL: func.func @torch.aten.tanh$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[T1:.*]] = mhlo.tanh %[[T0]] : tensor<?x?xf32>
+// CHECK: %[[T1:.*]] = stablehlo.tanh %[[T0]] : tensor<?x?xf32>
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@@ -39,7 +39,7 @@ func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte
// CHECK-LABEL: func.func @torch.aten.log$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[T1:.*]] = mhlo.log %[[T0]] : tensor<?x?xf32>
+// CHECK: %[[T1:.*]] = stablehlo.log %[[T0]] : tensor<?x?xf32>
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.log$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@@ -52,7 +52,7 @@ func.func @torch.aten.log$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten
// CHECK-LABEL: func.func @torch.aten.exp$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[T1:.*]] = mhlo.exponential %[[T0]] : tensor<?x?xf32>
+// CHECK: %[[T1:.*]] = stablehlo.exponential %[[T0]] : tensor<?x?xf32>
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.exp$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@@ -65,7 +65,7 @@ func.func @torch.aten.exp$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten
// CHECK-LABEL: func.func @torch.aten.neg$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[T1:.*]] = mhlo.negate %[[T0]] : tensor<?x?xf32>
+// CHECK: %[[T1:.*]] = stablehlo.negate %[[T0]] : tensor<?x?xf32>
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.neg$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@@ -78,7 +78,7 @@ func.func @torch.aten.neg$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten
// CHECK-LABEL: func.func @torch.aten.rsqrt$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[T1:.*]] = mhlo.rsqrt %[[T0]] : tensor<?x?xf32>
+// CHECK: %[[T1:.*]] = stablehlo.rsqrt %[[T0]] : tensor<?x?xf32>
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@@ -91,7 +91,7 @@ func.func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vt
// CHECK-LABEL: func.func @torch.aten.sigmoid$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[T1:.*]] = mhlo.logistic %[[T0]] : tensor<?x?xf32>
+// CHECK: %[[T1:.*]] = stablehlo.logistic %[[T0]] : tensor<?x?xf32>
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.sigmoid$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@@ -108,8 +108,8 @@ func.func @torch.aten.sigmoid$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
-// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
-// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T4:.*]] = chlo.broadcast_add %[[T0]], %[[T3]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
@@ -130,11 +130,11 @@ func.func @torch.aten.addscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
-// CHECK: %[[T3:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
-// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[FROM_ELEMENTS_0:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64>
-// CHECK: %[[T5:.*]] = mhlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK: %[[T5:.*]] = stablehlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xi64>) -> tensor<1xf32>
-// CHECK: %[[T6:.*]] = mhlo.reshape %[[T5]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T6:.*]] = stablehlo.reshape %[[T5]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T7:.*]] = chlo.broadcast_multiply %[[T4]], %[[T6]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: %[[T8:.*]] = chlo.broadcast_add %[[T0]], %[[T7]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
@@ -171,8 +171,8 @@ func.func @torch.aten.addtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64>
-// CHECK: %[[T3:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
-// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T5:.*]] = chlo.broadcast_multiply %[[T1]], %[[T4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T6:.*]] = chlo.broadcast_add %[[T0]], %[[T5]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
@@ -190,7 +190,7 @@ func.func @torch.aten.addtensor$alpha(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
// CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[T2:.*]] = mhlo.convert %[[T0]] : (tensor<?x?xi32>) -> tensor<?x?xi64>
+// CHECK: %[[T2:.*]] = stablehlo.convert %[[T0]] : (tensor<?x?xi32>) -> tensor<?x?xi64>
// CHECK: %[[T3:.*]] = chlo.broadcast_add %[[T2]], %[[T1]] : (tensor<?x?xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
// CHECK: return %[[T4]] : !torch.vtensor<[?,?],si64>
@@ -209,8 +209,8 @@ func.func @torch.aten.addtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
-// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
-// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T4:.*]] = chlo.broadcast_subtract %[[T0]], %[[T3]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
@@ -230,8 +230,8 @@ func.func @torch.aten.subscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
-// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
-// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T4:.*]] = chlo.broadcast_subtract %[[T3]], %[[T0]] : (tensor<f32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
@@ -252,11 +252,11 @@ func.func @torch.aten.rsubscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !tor
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
-// CHECK: %[[T3:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
-// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[FROM_ELEMENTS_0:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64>
-// CHECK: %[[T5:.*]] = mhlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK: %[[T5:.*]] = stablehlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xi64>) -> tensor<1xf32>
-// CHECK: %[[T6:.*]] = mhlo.reshape %[[T5]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T6:.*]] = stablehlo.reshape %[[T5]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T7:.*]] = chlo.broadcast_multiply %[[T4]], %[[T6]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: %[[T8:.*]] = chlo.broadcast_subtract %[[T0]], %[[T7]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
@@ -293,8 +293,8 @@ func.func @torch.aten.subtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64>
-// CHECK: %[[T3:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
-// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T5:.*]] = chlo.broadcast_multiply %[[T1]], %[[T4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T6:.*]] = chlo.broadcast_subtract %[[T0]], %[[T5]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
@@ -312,7 +312,7 @@ func.func @torch.aten.subtensor$alpha(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
// CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[T2:.*]] = mhlo.convert %[[T0]] : (tensor<?x?xi32>) -> tensor<?x?xi64>
+// CHECK: %[[T2:.*]] = stablehlo.convert %[[T0]] : (tensor<?x?xi32>) -> tensor<?x?xi64>
// CHECK: %[[T3:.*]] = chlo.broadcast_subtract %[[T2]], %[[T1]] : (tensor<?x?xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
// CHECK: return %[[T4]] : !torch.vtensor<[?,?],si64>
@@ -330,8 +330,8 @@ func.func @torch.aten.subtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1
// CHECK: %[[INT9:.*]] = torch.constant.int 9
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
-// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
-// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T4:.*]] = chlo.broadcast_multiply %[[T0]], %[[T3]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
@@ -363,8 +363,8 @@ func.func @torch.aten.multensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
// CHECK: %[[INT9:.*]] = torch.constant.int 9
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
-// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
-// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T4:.*]] = chlo.broadcast_divide %[[T0]], %[[T3]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
@@ -396,8 +396,8 @@ func.func @torch.aten.divtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT3]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
-// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
-// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T4:.*]] = chlo.broadcast_compare %[[T0]], %[[T3]] {compare_type = #chlo<comparison_type FLOAT>, comparison_direction = #chlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],i1>
@@ -471,7 +471,7 @@ func.func @torch.aten.ne.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT0]] : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[T2:.*]] = "mhlo.transpose"(%[[T0]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x64xf32>) -> tensor<64x4xf32>
+// CHECK: %[[T2:.*]] = stablehlo.transpose %[[T0]], dims = [1, 0] : (tensor<4x64xf32>) -> tensor<64x4xf32>
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<64x4xf32> -> !torch.vtensor<[64,4],f32>
// CHECK: return %[[T3]] : !torch.vtensor<[64,4],f32>
func.func @torch.aten.permute$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[64,4],f32> {
@@ -488,7 +488,7 @@ func.func @torch.aten.permute$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) {value = 0.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK: %[[T2:.*]] = mhlo.maximum %[[T0]], %[[T1]] : tensor<?x?xf32>
+// CHECK: %[[T2:.*]] = stablehlo.maximum %[[T0]], %[[T1]] : tensor<?x?xf32>
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T3]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.relu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@@ -503,11 +503,11 @@ func.func @torch.aten.relu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = torch_c.to_f64 %[[ARG1]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xf64>
-// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xf64>) -> tensor<1xf32>
+// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xf64>) -> tensor<1xf32>
-// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[FROM_ELEMENTS_0:.*]] = tensor.from_elements %[[T1]] : tensor<1xf64>
-// CHECK: %[[T4:.*]] = mhlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xf64>) -> tensor<1xf32>
+// CHECK: %[[T4:.*]] = stablehlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xf64>) -> tensor<1xf32>
-// CHECK: %[[T5:.*]] = mhlo.reshape %[[T4]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T5:.*]] = stablehlo.reshape %[[T4]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T6:.*]] = chlo.broadcast_multiply %[[T3]], %[[T5]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: %[[T7:.*]] = chlo.broadcast_add %[[T0]], %[[T6]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
@@ -525,8 +525,8 @@ func.func @torch.aten.addscalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T2:.*]] = torch_c.to_f64 %[[ARG2]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]] : tensor<1xf64>
-// CHECK: %[[T3:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xf64>) -> tensor<1xf32>
+// CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xf64>) -> tensor<1xf32>
-// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T5:.*]] = chlo.broadcast_multiply %[[T1]], %[[T4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T6:.*]] = chlo.broadcast_add %[[T0]], %[[T5]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
@@ -543,8 +543,8 @@ func.func @torch.aten.addtensor$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[ARG1]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
-// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
-// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T4:.*]] = chlo.broadcast_multiply %[[T0]], %[[T3]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
@@ -560,8 +560,8 @@ func.func @torch.aten.mulscalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[ARG1]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
-// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
-// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T4:.*]] = chlo.broadcast_divide %[[T0]], %[[T3]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
@@ -577,8 +577,8 @@ func.func @torch.aten.divscalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[ARG1]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
-// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
-// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T4:.*]] = chlo.broadcast_compare %[[T0]], %[[T3]] {compare_type = #chlo<comparison_type FLOAT>, comparison_direction = #chlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],i1>
@@ -595,10 +595,10 @@ func.func @torch.aten.gt.scalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[STR:.*]] = torch.constant.str "trunc"
// CHECK: %[[T2:.*]] = chlo.broadcast_divide %[[T0]], %[[T1]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
-// CHECK: %[[T3:.*]] = mhlo.sign %[[T2]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T3:.*]] = stablehlo.sign %[[T2]] : tensor<?x?x?x?xf32>
-// CHECK: %[[T4:.*]] = mhlo.abs %[[T2]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T4:.*]] = stablehlo.abs %[[T2]] : tensor<?x?x?x?xf32>
-// CHECK: %[[T5:.*]] = mhlo.floor %[[T4]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T5:.*]] = stablehlo.floor %[[T4]] : tensor<?x?x?x?xf32>
-// CHECK: %[[T6:.*]] = mhlo.multiply %[[T3]], %[[T5]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T6:.*]] = stablehlo.multiply %[[T3]], %[[T5]] : tensor<?x?x?x?xf32>
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[T7]] : !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten.div.Tensor_mode$trunc(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
@@ -615,7 +615,7 @@ func.func @torch.aten.div.Tensor_mode$trunc(%arg0: !torch.vtensor<[?,?,?,?],f32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[STR:.*]] = torch.constant.str "floor"
// CHECK: %[[T2:.*]] = chlo.broadcast_divide %[[T0]], %[[T1]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
-// CHECK: %[[T3:.*]] = mhlo.floor %[[T2]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T3:.*]] = stablehlo.floor %[[T2]] : tensor<?x?x?x?xf32>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[T4]] : !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten.div.Tensor_mode$floor(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {

@ -1,4 +1,4 @@
|
||||||
// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s
|
// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.index_select$basic(
|
// CHECK-LABEL: func.func @torch.aten.index_select$basic(
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,4],f32>, %[[ARG1:.*]]: !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,4],f32> {
|
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,4],f32>, %[[ARG1:.*]]: !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,4],f32> {
|
||||||
|
@ -10,8 +10,8 @@
|
||||||
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x4xf32>
|
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x4xf32>
|
||||||
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
|
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
|
||||||
// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
|
// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
|
||||||
// CHECK: %[[T5:.*]] = "mhlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #mhlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false} : (tensor<?x4xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<2x4xf32>
|
// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false} : (tensor<?x4xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<2x4xf32>
|
||||||
// CHECK: %[[T6:.*]] = mhlo.convert %[[T5]] : tensor<2x4xf32>
|
// CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor<2x4xf32>
|
||||||
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32>
|
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32>
|
||||||
// CHECK: return %[[T7]] : !torch.vtensor<[2,4],f32>
|
// CHECK: return %[[T7]] : !torch.vtensor<[2,4],f32>
|
||||||
func.func @torch.aten.index_select$basic(%arg0: !torch.vtensor<[?,4],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,4],f32> {
|
func.func @torch.aten.index_select$basic(%arg0: !torch.vtensor<[?,4],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,4],f32> {
|
||||||
|
@@ -31,8 +31,8 @@ func.func @torch.aten.index_select$basic(%arg0: !torch.vtensor<[?,4],f32>, %arg1
 // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf32>
 // CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
 // CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
-// CHECK: %[[T5:.*]] = "mhlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #mhlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false} : (tensor<?x?xf32>, tensor<?xi64>, tensor<2xi64>) -> tensor<?x?xf32>
+// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false} : (tensor<?x?xf32>, tensor<?xi64>, tensor<2xi64>) -> tensor<?x?xf32>
-// CHECK: %[[T6:.*]] = mhlo.convert %[[T5]] : tensor<?x?xf32>
+// CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor<?x?xf32>
 // CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[T7]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.embedding$basic(%weight: !torch.vtensor<[?,?],f32>, %indices: !torch.vtensor<[?], si64>) -> !torch.vtensor<[?,?],f32> {
@@ -53,8 +53,8 @@ func.func @torch.aten.embedding$basic(%weight: !torch.vtensor<[?,?],f32>, %indic
 // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf32>
 // CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
 // CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
-// CHECK: %[[T5:.*]] = "mhlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #mhlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false} : (tensor<?x?xf32>, tensor<?x1xi64>, tensor<2xi64>) -> tensor<?x1x?xf32>
+// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false} : (tensor<?x?xf32>, tensor<?x1xi64>, tensor<2xi64>) -> tensor<?x1x?xf32>
-// CHECK: %[[T6:.*]] = mhlo.convert %[[T5]] : tensor<?x1x?xf32>
+// CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor<?x1x?xf32>
 // CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x1x?xf32> -> !torch.vtensor<[?,1,?],f32>
 // CHECK: return %[[T7]] : !torch.vtensor<[?,1,?],f32>
 func.func @torch.aten.embedding$rank_two_indices(%weight: !torch.vtensor<[?,?],f32>, %indices: !torch.vtensor<[?,1], si64>) -> !torch.vtensor<[?,1,?],f32> {

@@ -1,10 +1,10 @@
-// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s
 
 // CHECK-LABEL: func.func @torch.aten.mm$basic$static(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,3],f32>, %[[ARG1:.*]]: !torch.vtensor<[3,3],f32>) -> !torch.vtensor<[2,3],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
 // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[3,3],f32> -> tensor<3x3xf32>
-// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<2x3xf32>, tensor<3x3xf32>) -> tensor<2x3xf32>
+// CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor<2x3xf32>, tensor<3x3xf32>) -> tensor<2x3xf32>
 // CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<2x3xf32> to tensor<2x3xf32>
 // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
 // CHECK: return %[[T4]] : !torch.vtensor<[2,3],f32>
@@ -19,7 +19,7 @@ func.func @torch.aten.mm$basic$static(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3],f32>, %[[ARG1:.*]]: !torch.vtensor<[3,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,3],f32> -> tensor<?x3xf32>
 // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[3,?],f32> -> tensor<3x?xf32>
-// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<?x3xf32>, tensor<3x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor<?x3xf32>, tensor<3x?xf32>) -> tensor<?x?xf32>
 // CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<?x?xf32> to tensor<?x?xf32>
 // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[T4]] : !torch.vtensor<[?,?],f32>
@@ -44,8 +44,8 @@ func.func @torch.aten.mm$basic$dynamic(%arg0: !torch.vtensor<[?,3],f32>, %arg1:
 // CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C2]] : tensor<10x4x5xf32>
 // CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
 // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
-// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<10x4x5xf32>, tensor<3xi64>) -> tensor<10x4x5xf32>
+// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [0, 1, 2] : (tensor<10x4x5xf32>, tensor<3xi64>) -> tensor<10x4x5xf32>
-// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<10x3x4xf32>, tensor<10x4x5xf32>) -> tensor<10x3x5xf32>
+// CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<10x3x4xf32>, tensor<10x4x5xf32>) -> tensor<10x3x5xf32>
 // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<10x3x5xf32> to tensor<10x3x5xf32>
 // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<10x3x5xf32> -> !torch.vtensor<[10,3,5],f32>
 // CHECK: return %[[T12]] : !torch.vtensor<[10,3,5],f32>
@@ -70,8 +70,8 @@ func.func @torch.aten.bmm$basic$static(%arg0: !torch.vtensor<[10,3,4],f32>, %arg
 // CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C2]] : tensor<?x4x?xf32>
 // CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
 // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
-// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<?x4x?xf32>, tensor<3xi64>) -> tensor<?x4x?xf32>
+// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [0, 1, 2] : (tensor<?x4x?xf32>, tensor<3xi64>) -> tensor<?x4x?xf32>
-// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<?x?x4xf32>, tensor<?x4x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<?x?x4xf32>, tensor<?x4x?xf32>) -> tensor<?x?x?xf32>
 // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
 // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
 // CHECK: return %[[T12]] : !torch.vtensor<[?,?,?],f32>
@@ -96,8 +96,8 @@ func.func @torch.aten.bmm$basic$dynamic(%arg0: !torch.vtensor<[?,?,4],f32>, %arg
 // CHECK: %[[T6:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<256x120xf32>
 // CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
 // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
-// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T0]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x120xf32>, tensor<3xi64>) -> tensor<4x256x120xf32>
+// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T0]], %[[T8]], dims = [1, 2] : (tensor<256x120xf32>, tensor<3xi64>) -> tensor<4x256x120xf32>
-// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T9]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<4x256x120xf32>, tensor<4x120x256xf32>) -> tensor<4x256x256xf32>
+// CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T9]], %[[T1]]) {dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<4x256x120xf32>, tensor<4x120x256xf32>) -> tensor<4x256x256xf32>
 // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<4x256x256xf32> to tensor<4x256x256xf32>
 // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x256x256xf32> -> !torch.vtensor<[4,256,256],f32>
 // CHECK: return %[[T12]] : !torch.vtensor<[4,256,256],f32>
@@ -122,8 +122,8 @@ func.func @torch.aten.matmul$basic$static(%arg0: !torch.vtensor<[256,120],f32>,
 // CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<256x?xf32>
 // CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
 // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
-// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x?xf32>, tensor<3xi64>) -> tensor<4x256x?xf32>
+// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [1, 2] : (tensor<256x?xf32>, tensor<3xi64>) -> tensor<4x256x?xf32>
-// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<4x?x256xf32>, tensor<4x256x?xf32>) -> tensor<4x?x?xf32>
+// CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<4x?x256xf32>, tensor<4x256x?xf32>) -> tensor<4x?x?xf32>
 // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<4x?x?xf32> to tensor<4x?x?xf32>
 // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x?x?xf32> -> !torch.vtensor<[4,?,?],f32>
 // CHECK: return %[[T12]] : !torch.vtensor<[4,?,?],f32>
@@ -145,8 +145,8 @@ func.func @torch.aten.matmul$basic$dynamic(%arg0: !torch.vtensor<[4,?,256],f32>,
 // CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor<256xf32>
 // CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64
 // CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64>
-// CHECK: %[[T7:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32>
+// CHECK: %[[T7:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T6]], dims = [1] : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32>
-// CHECK: %[[T8:.*]] = "mhlo.dot_general"(%[[T0]], %[[T7]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<1x?x256xf32>, tensor<1x256xf32>) -> tensor<1x?xf32>
+// CHECK: %[[T8:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T7]]) {dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<1x?x256xf32>, tensor<1x256xf32>) -> tensor<1x?xf32>
 // CHECK: %[[T9:.*]] = tensor.cast %[[T8]] : tensor<1x?xf32> to tensor<1x?xf32>
 // CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<1x?xf32> -> !torch.vtensor<[1,?],f32>
 // CHECK: return %[[T10]] : !torch.vtensor<[1,?],f32>
@@ -168,8 +168,8 @@ func.func @torch.aten.matmul$3dx1d(%arg0: !torch.vtensor<[1,?,256],f32>, %arg1:
 // CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0_0]] : tensor<256xf32>
 // CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64
 // CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64>
-// CHECK: %[[T7:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T0]], %[[T6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, tensor<2xi64>) -> tensor<?x256xf32>
+// CHECK: %[[T7:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T0]], %[[T6]], dims = [1] : (tensor<256xf32>, tensor<2xi64>) -> tensor<?x256xf32>
-// CHECK: %[[T8:.*]] = "mhlo.dot_general"(%[[T7]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [1]>} : (tensor<?x256xf32>, tensor<?x256x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[T8:.*]] = "stablehlo.dot_general"(%[[T7]], %[[T1]]) {dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [1]>} : (tensor<?x256xf32>, tensor<?x256x?xf32>) -> tensor<?x?xf32>
 // CHECK: %[[T9:.*]] = tensor.cast %[[T8]] : tensor<?x?xf32> to tensor<?x?xf32>
 // CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[T10]] : !torch.vtensor<[?,?],f32>
@@ -184,7 +184,7 @@ func.func @torch.aten.matmul$1dx3d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256],f32>) -> !torch.vtensor<[?],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,256],f32> -> tensor<?x256xf32>
 // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32>
-// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<?x256xf32>, tensor<256xf32>) -> tensor<?xf32>
+// CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor<?x256xf32>, tensor<256xf32>) -> tensor<?xf32>
 // CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<?xf32> to tensor<?xf32>
 // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?xf32> -> !torch.vtensor<[?],f32>
 // CHECK: return %[[T4]] : !torch.vtensor<[?],f32>
@@ -199,7 +199,7 @@ func.func @torch.aten.matmul$2dx1d(%arg0: !torch.vtensor<[?,256],f32>, %arg1: !t
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[?],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32>
 // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256,?],f32> -> tensor<256x?xf32>
-// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<256xf32>, tensor<256x?xf32>) -> tensor<?xf32>
+// CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor<256xf32>, tensor<256x?xf32>) -> tensor<?xf32>
 // CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<?xf32> to tensor<?xf32>
 // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?xf32> -> !torch.vtensor<[?],f32>
 // CHECK: return %[[T4]] : !torch.vtensor<[?],f32>
@@ -214,7 +214,7 @@ func.func @torch.aten.matmul$1dx2d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256],f32>) -> !torch.vtensor<[],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32>
 // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32>
-// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<256xf32>, tensor<256xf32>) -> tensor<f32>
+// CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor<256xf32>, tensor<256xf32>) -> tensor<f32>
 // CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<f32> to tensor<f32>
 // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<f32> -> !torch.vtensor<[],f32>
 // CHECK: return %[[T4]] : !torch.vtensor<[],f32>
@@ -228,7 +228,7 @@ func.func @torch.aten.matmul$1dx1d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor
 // CHECK-LABEL: func.func @torch.aten.matmul$proj(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,256],f32>) -> !torch.vtensor<[?,?,256],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,256],f32> -> tensor<?x?x256xf32>
-// CHECK: %[[T1:.*]] = mhlo.constant dense<1.000000e+00> : tensor<256x256xf32>
+// CHECK: %[[T1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<256x256xf32>
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x256xf32>
 // CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
@@ -239,8 +239,8 @@ func.func @torch.aten.matmul$1dx1d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor
 // CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<256x256xf32>
 // CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
 // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
-// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x256xf32>, tensor<3xi64>) -> tensor<?x256x256xf32>
+// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [1, 2] : (tensor<256x256xf32>, tensor<3xi64>) -> tensor<?x256x256xf32>
-// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<?x?x256xf32>, tensor<?x256x256xf32>) -> tensor<?x?x256xf32>
+// CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<?x?x256xf32>, tensor<?x256x256xf32>) -> tensor<?x?x256xf32>
 // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<?x?x256xf32> to tensor<?x?x256xf32>
 // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<?x?x256xf32> -> !torch.vtensor<[?,?,256],f32>
 // CHECK: return %[[T12]] : !torch.vtensor<[?,?,256],f32>
@@ -255,8 +255,8 @@ func.func @torch.aten.matmul$proj(%arg0: !torch.vtensor<[?,?,256],f32>) -> !torc
 // CHECK-LABEL: func.func @torch.aten.mm$proj(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,256],f32>) -> !torch.vtensor<[?,256],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,256],f32> -> tensor<?x256xf32>
-// CHECK: %[[T1:.*]] = mhlo.constant dense<1.000000e+00> : tensor<256x256xf32>
+// CHECK: %[[T1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<256x256xf32>
-// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<?x256xf32>, tensor<256x256xf32>) -> tensor<?x256xf32>
+// CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor<?x256xf32>, tensor<256x256xf32>) -> tensor<?x256xf32>
 // CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<?x256xf32> to tensor<?x256xf32>
 // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x256xf32> -> !torch.vtensor<[?,256],f32>
 // CHECK: return %[[T4]] : !torch.vtensor<[?,256],f32>
@@ -284,7 +284,7 @@ func.func @torch.aten.mm$proj(%arg0: !torch.vtensor<[?,256],f32>) -> !torch.vten
 // CHECK: %[[T_11:.*]] = torch.prim.ListConstruct %[[T_7]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_12:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
 // CHECK: %[[T_13:.*]] = torch.constant.bool false
-// CHECK: %[[T_14:.*]] = mhlo.convolution(%[[T_0]], %[[T_1]])
+// CHECK: %[[T_14:.*]] = stablehlo.convolution(%[[T_0]], %[[T_1]])
 // CHECK-SAME{LITERAL}: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 1], pad = [[4, 4], [2, 2]], rhs_dilate = [3, 1]} {batch_group_count = 1 : i64, feature_group_count = 3 : i64} : (tensor<?x?x?x?xf32>, tensor<?x?x3x3xf32>) -> tensor<?x?x?x?xf32>
 // CHECK: %[[T_15:.*]] = torch_c.from_builtin_tensor %[[T_14]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
 // CHECK: return %[[T_15]] : !torch.vtensor<[?,?,?,?],f32>
@@ -321,14 +321,14 @@ func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !
 // CHECK: %[[T_6:.*]] = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_7:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
 // CHECK: %false = torch.constant.bool false
-// CHECK: %[[T_8:.*]] = mhlo.convolution(%[[T_0]], %[[T_1]])
+// CHECK: %[[T_8:.*]] = stablehlo.convolution(%[[T_0]], %[[T_1]])
 // CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 1], pad = [[4, 4], [2, 2]], rhs_dilate = [3, 1]} {batch_group_count = 1 : i64, feature_group_count = 3 : i64} : (tensor<?x?x?x?xf32>, tensor<?x?x3x3xf32>) -> tensor<?x?x?x?xf32>
 // CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
 // CHECK: %[[T_9:.*]] = tensor.dim %[[T_2]], %[[IDX_0]] : tensor<?xf32>
 // CHECK: %[[T_10:.*]] = arith.index_cast %[[T_9]] : index to i64
 // CHECK: %[[VAL_0:.*]] = arith.constant 1 : i64
 // CHECK: %[[T_11:.*]] = tensor.from_elements %[[T_10]], %[[VAL_0]], %[[VAL_0]] : tensor<3xi64>
-// CHECK: %[[T_12:.*]] = mhlo.dynamic_reshape %[[T_2]], %[[T_11]] : (tensor<?xf32>, tensor<3xi64>) -> tensor<?x1x1xf32>
+// CHECK: %[[T_12:.*]] = stablehlo.dynamic_reshape %[[T_2]], %[[T_11]] : (tensor<?xf32>, tensor<3xi64>) -> tensor<?x1x1xf32>
 // CHECK: %[[T_13:.*]] = chlo.broadcast_add %[[T_8]], %[[T_12]] : (tensor<?x?x?x?xf32>, tensor<?x1x1xf32>) -> tensor<?x?x?x?xf32>
 // CHECK: %[[T_14:.*]] = torch_c.from_builtin_tensor %[[T_13]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
 // CHECK: return %[[T_14]] : !torch.vtensor<[?,?,?,?],f32>
@@ -360,8 +360,8 @@ func.func @torch.aten.convolution$bias(%arg0: !torch.vtensor<[?,?,?,?],f32>, %ar
 // CHECK: %[[T_2:.*]] = torch_c.to_i64 %int1
 // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[T_5:.*]] = "mhlo.reverse"(%[[T_1]]) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x4x3x3xf32>) -> tensor<2x4x3x3xf32>
+// CHECK: %[[T_5:.*]] = stablehlo.reverse %[[T_1]], dims = [2, 3] : tensor<2x4x3x3xf32>
-// CHECK: %[[T_6:.*]] = mhlo.convolution(%[[T_0]], %[[T_5]])
+// CHECK: %[[T_6:.*]] = stablehlo.convolution(%[[T_0]], %[[T_5]])
 // CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<2x4x3x3xf32>) -> tensor<1x4x9x9xf32>
 // CHECK: %[[T_7:.*]] = torch_c.from_builtin_tensor %[[T_6]] : tensor<1x4x9x9xf32> -> !torch.vtensor<[1,4,9,9],f32>
 // CHECK: return %[[T_7]] : !torch.vtensor<[1,4,9,9],f32>
@@ -392,8 +392,8 @@ func.func @torch.aten.convolution$transposed_basic(%arg0: !torch.vtensor<[1,2,7,
 // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[T_6:.*]] = "mhlo.reverse"(%1) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x4x3x3xf32>) -> tensor<2x4x3x3xf32>
+// CHECK: %[[T_6:.*]] = stablehlo.reverse %1, dims = [2, 3] : tensor<2x4x3x3xf32>
-// CHECK: %[[T_7:.*]] = mhlo.convolution(%[[T_0]], %[[T_6]])
+// CHECK: %[[T_7:.*]] = stablehlo.convolution(%[[T_0]], %[[T_6]])
 // CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<2x4x3x3xf32>) -> tensor<1x4x15x15xf32>
 // CHECK: %[[T_8:.*]] = torch_c.from_builtin_tensor %[[T_7]] : tensor<1x4x15x15xf32> -> !torch.vtensor<[1,4,15,15],f32>
 // CHECK: return %[[T_8]] : !torch.vtensor<[1,4,15,15],f32>
@@ -426,11 +426,11 @@ func.func @torch.aten.convolution$transposed_stride(%arg0: !torch.vtensor<[1,2,7
 // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[T_6:.*]] = "mhlo.reverse"(%[[T_1]]) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x4x3x3xf32>) -> tensor<2x4x3x3xf32>
+// CHECK: %[[T_6:.*]] = stablehlo.reverse %[[T_1]], dims = [2, 3] : tensor<2x4x3x3xf32>
-// CHECK: %[[T_7:.*]] = mhlo.convolution(%[[T_0]], %[[T_6]])
+// CHECK: %[[T_7:.*]] = stablehlo.convolution(%[[T_0]], %[[T_6]])
 // CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<2x4x3x3xf32>) -> tensor<1x4x15x15xf32>
-// CHECK: %[[T_8:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK: %[[T_8:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
-// CHECK: %[[T_9:.*]] = "mhlo.pad"(%[[T_7]], %[[T_8]]) {edge_padding_high = dense<[0, 0, 1, 1]> : vector<4xi64>, edge_padding_low = dense<0> : vector<4xi64>, interior_padding = dense<0> : vector<4xi64>} : (tensor<1x4x15x15xf32>, tensor<f32>) -> tensor<1x4x16x16xf32>
+// CHECK: %[[T_9:.*]] = stablehlo.pad %[[T_7]], %[[T_8]], low = [0, 0, 0, 0], high = [0, 0, 1, 1], interior = [0, 0, 0, 0] : (tensor<1x4x15x15xf32>, tensor<f32>) -> tensor<1x4x16x16xf32>
 // CHECK: %[[T_10:.*]] = torch_c.from_builtin_tensor %[[T_9:.*]] : tensor<1x4x16x16xf32> -> !torch.vtensor<[1,4,16,16],f32>
 // CHECK: return %[[T_10]] : !torch.vtensor<[1,4,16,16],f32>
 func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor<[1,2,7,7],f32>, %arg1: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,16,16],f32> {
@@ -462,7 +462,7 @@ func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor
 // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[T_6:.*]] = "mhlo.reverse"(%1) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x2x3x3xf32>) -> tensor<2x2x3x3xf32>
+// CHECK: %[[T_6:.*]] = stablehlo.reverse %1, dims = [2, 3] : tensor<2x2x3x3xf32>
 // CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
 // CHECK: %[[T_7:.*]] = tensor.dim %[[T_6]], %[[IDX_0]] : tensor<2x2x3x3xf32>
 // CHECK: %[[T_8:.*]] = arith.index_cast %[[T_7]] : index to i64
@@ -479,11 +479,11 @@
 // CHECK: %[[T_15:.*]] = arith.divsi %[[T_8]], %[[T_24]] : i64
 // CHECK: %[[T_16:.*]] = arith.muli %[[T_10]], %[[T_24]] : i64
 // CHECK: %[[T_17:.*]] = tensor.from_elements %[[T_24]], %[[T_15]], %[[T_10]], %[[T_12]], %[[T_14]] : tensor<5xi64>
-// CHECK: %[[T_18:.*]] = mhlo.dynamic_reshape %[[T_6]], %[[T_17]] : (tensor<2x2x3x3xf32>, tensor<5xi64>) -> tensor<2x1x2x3x3xf32>
+// CHECK: %[[T_18:.*]] = stablehlo.dynamic_reshape %[[T_6]], %[[T_17]] : (tensor<2x2x3x3xf32>, tensor<5xi64>) -> tensor<2x1x2x3x3xf32>
-// CHECK: %[[T_19:.*]] = "mhlo.transpose"(%[[T_18]]) {permutation = dense<[1, 0, 2, 3, 4]> : tensor<5xi64>} : (tensor<2x1x2x3x3xf32>) -> tensor<1x2x2x3x3xf32>
+// CHECK: %[[T_19:.*]] = stablehlo.transpose %[[T_18]], dims = [1, 0, 2, 3, 4] : (tensor<2x1x2x3x3xf32>) -> tensor<1x2x2x3x3xf32>
 // CHECK: %[[T_20:.*]] = tensor.from_elements %[[T_15]], %[[T_16]], %[[T_12]], %[[T_14]] : tensor<4xi64>
-// CHECK: %[[T_21:.*]] = mhlo.dynamic_reshape %[[T_19]], %[[T_20]] : (tensor<1x2x2x3x3xf32>, tensor<4xi64>) -> tensor<1x4x3x3xf32>
+// CHECK: %[[T_21:.*]] = stablehlo.dynamic_reshape %[[T_19]], %[[T_20]] : (tensor<1x2x2x3x3xf32>, tensor<4xi64>) -> tensor<1x4x3x3xf32>
-// CHECK: %[[T_22:.*]] = mhlo.convolution(%[[T_0]], %[[T_21]])
+// CHECK: %[[T_22:.*]] = stablehlo.convolution(%[[T_0]], %[[T_21]])
 // CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 2 : i64} : (tensor<1x2x7x7xf32>, tensor<1x4x3x3xf32>) -> tensor<1x4x15x15xf32>
 // CHECK: %[[T_23:.*]] = torch_c.from_builtin_tensor %[[T_22]] : tensor<1x4x15x15xf32> -> !torch.vtensor<[1,4,15,15],f32>
 // CHECK: return %[[T_23]] : !torch.vtensor<[1,4,15,15],f32>

@@ -1,2 +1,2 @@
-if not config.enable_mhlo:
+if not config.enable_stablehlo:
 config.unsupported = True

@@ -1,4 +1,4 @@
-// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s
 
 // -----
 
@@ -13,11 +13,11 @@
 // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[VAL_6:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
+// CHECK: %[[VAL_6:.*]] = stablehlo.constant dense<-3.40282347E+38> : tensor<f32>
-// CHECK: %[[VAL_7:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_6]]) ({
+// CHECK: %[[VAL_7:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_6]]) ({
 // CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
-// CHECK: %[[VAL_10:.*]] = mhlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
+// CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
-// CHECK: mhlo.return %[[VAL_10]] : tensor<f32>
+// CHECK: stablehlo.return %[[VAL_10]] : tensor<f32>
 // CHECK: }) {padding = dense<0> : tensor<4x2xi64>, window_dilations = dense<[1, 1, 2, 1]> : tensor<4xi64>, window_dimensions = dense<[1, 1, 2, 2]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
 // CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
 // CHECK: return %[[VAL_11]] : !torch.vtensor<[?,?,?,?],f32>
@@ -45,11 +45,11 @@ func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
 // CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
+// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<-3.40282347E+38> : tensor<f32>
-// CHECK: %[[VAL_6:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
+// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
 // CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
-// CHECK: %[[VAL_10:.*]] = mhlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
+// CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
-// CHECK: mhlo.return %[[VAL_10]] : tensor<f32>
+// CHECK: stablehlo.return %[[VAL_10]] : tensor<f32>
 // CHECK: })
 // CHECK-SAME{LITERAL}: {padding = dense<[[0, 0], [0, 0], [2, 2], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<[1, 1, 2, 1]> : tensor<4xi64>, window_dimensions = dense<[1, 1, 2, 2]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
 // CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
@@ -80,7 +80,7 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -
 // CHECK: %[[T2:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT0]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T4:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1]] : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[T5:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
+// CHECK: %[[T5:.*]] = stablehlo.constant dense<-3.40282347E+38> : tensor<f32>
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x?xf32>
 // CHECK: %[[T6:.*]] = arith.index_cast %[[DIM]] : index to i64
@@ -93,18 +93,18 @@
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T6]], %[[T7]], %[[T8]] : tensor<3xi64>
 // CHECK: %[[T9:.*]] = arith.muli %[[T8]], %[[T7]] : i64
 // CHECK: %[[FROM_ELEMENTS_2:.*]] = tensor.from_elements %[[T6]], %[[T9]] : tensor<2xi64>
-// CHECK: %[[T10:.*]] = "mhlo.dynamic_iota"(%[[FROM_ELEMENTS_2]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64>
+// CHECK: %[[T10:.*]] = stablehlo.dynamic_iota %[[FROM_ELEMENTS_2]], dim = 1 : (tensor<2xi64>) -> tensor<?x?xi64>
-// CHECK: %[[T11:.*]] = mhlo.dynamic_reshape %[[T10]], %[[FROM_ELEMENTS]] : (tensor<?x?xi64>, tensor<3xi64>) -> tensor<?x?x?xi64>
+// CHECK: %[[T11:.*]] = stablehlo.dynamic_reshape %[[T10]], %[[FROM_ELEMENTS]] : (tensor<?x?xi64>, tensor<3xi64>) -> tensor<?x?x?xi64>
-// CHECK: %[[T12:.*]] = mhlo.constant dense<0> : tensor<i64>
+// CHECK: %[[T12:.*]] = stablehlo.constant dense<0> : tensor<i64>
-// CHECK: %[[T13:.*]]:2 = "mhlo.reduce_window"(%[[T0]], %[[T11]], %[[T5]], %[[T12]]) ({
+// CHECK: %[[T13:.*]]:2 = "stablehlo.reduce_window"(%[[T0]], %[[T11]], %[[T5]], %[[T12]]) ({
 // CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<i64>, %[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<i64>):
-// CHECK: %[[T16:.*]] = mhlo.compare GE, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
+// CHECK: %[[T16:.*]] = stablehlo.compare GE, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
-// CHECK: %[[T17:.*]] = mhlo.select %[[T16]], %[[ARG1]], %[[ARG3]] : tensor<i1>, tensor<f32>
+// CHECK: %[[T17:.*]] = stablehlo.select %[[T16]], %[[ARG1]], %[[ARG3]] : tensor<i1>, tensor<f32>
-// CHECK: %[[T18:.*]] = mhlo.compare EQ, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
+// CHECK: %[[T18:.*]] = stablehlo.compare EQ, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
-// CHECK: %[[T19:.*]] = mhlo.minimum %[[ARG2]], %[[ARG4]] : tensor<i64>
+// CHECK: %[[T19:.*]] = stablehlo.minimum %[[ARG2]], %[[ARG4]] : tensor<i64>
-// CHECK: %[[T20:.*]] = mhlo.select %[[T16]], %[[ARG2]], %[[ARG4]] : tensor<i1>, tensor<i64>
+// CHECK: %[[T20:.*]] = stablehlo.select %[[T16]], %[[ARG2]], %[[ARG4]] : tensor<i1>, tensor<i64>
-// CHECK: %[[T21:.*]] = mhlo.select %[[T18]], %[[T19]], %[[T20]] : tensor<i1>, tensor<i64>
+// CHECK: %[[T21:.*]] = stablehlo.select %[[T18]], %[[T19]], %[[T20]] : tensor<i1>, tensor<i64>
-// CHECK: mhlo.return %[[T17]], %[[T21]] : tensor<f32>, tensor<i64>
+// CHECK: stablehlo.return %[[T17]], %[[T21]] : tensor<f32>, tensor<i64>
 // CHECK: }) {padding = dense<0> : tensor<3x2xi64>, window_dilations = dense<1> : tensor<3xi64>, window_dimensions = dense<[1, 3, 3]> : tensor<3xi64>, window_strides = dense<[1, 2, 2]> : tensor<3xi64>} : (tensor<?x?x?xf32>, tensor<?x?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?x?x?xf32>, tensor<?x?x?xi64>)
 // CHECK: %[[T14:.*]] = torch_c.from_builtin_tensor %[[T13]]#0 : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
 // CHECK: %[[T15:.*]] = torch_c.from_builtin_tensor %[[T13]]#1 : tensor<?x?x?xi64> -> !torch.vtensor<[?,?,?],si64>
@@ -136,13 +136,13 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32>
 // CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
-// CHECK: %[[VAL_6:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
+// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
 // CHECK: ^bb0(%[[IVAL_0:.*]]: tensor<f32>, %[[IVAL_1:.*]]: tensor<f32>):
-// CHECK: %[[IVAL_2:.*]] = mhlo.add %[[IVAL_0]], %[[IVAL_1]] : tensor<f32>
+// CHECK: %[[IVAL_2:.*]] = stablehlo.add %[[IVAL_0]], %[[IVAL_1]] : tensor<f32>
-// CHECK: mhlo.return %[[IVAL_2]] : tensor<f32>
+// CHECK: stablehlo.return %[[IVAL_2]] : tensor<f32>
 // CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
-// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
+// CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f32>
 // CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
 // CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor<?x?x?x?xf32>
 // CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : index to i64
@ -156,14 +156,14 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32>
|
||||||
// CHECK: %[[VAL_14:.*]] = tensor.dim %[[VAL_1]], %[[IDX_3]] : tensor<?x?x?x?xf32>
|
// CHECK: %[[VAL_14:.*]] = tensor.dim %[[VAL_1]], %[[IDX_3]] : tensor<?x?x?x?xf32>
|
||||||
// CHECK: %[[VAL_15:.*]] = arith.index_cast %[[VAL_14]] : index to i64
|
// CHECK: %[[VAL_15:.*]] = arith.index_cast %[[VAL_14]] : index to i64
|
||||||
// CHECK: %[[VAL_16:.*]] = tensor.from_elements %[[VAL_9]], %[[VAL_11]], %[[VAL_13]], %[[VAL_15]] : tensor<4xi64>
|
// CHECK: %[[VAL_16:.*]] = tensor.from_elements %[[VAL_9]], %[[VAL_11]], %[[VAL_13]], %[[VAL_15]] : tensor<4xi64>
|
||||||
// CHECK: %[[VAL_17:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_7]], %[[VAL_16]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
|
// CHECK: %[[VAL_17:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_16]], dims = [] : (tensor<f32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
|
||||||
// CHECK: %[[VAL_18:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
// CHECK: %[[VAL_18:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
|
||||||
// CHECK: %[[VAL_19:.*]] = "mhlo.reduce_window"(%[[VAL_17]], %[[VAL_18]]) ({
|
// CHECK: %[[VAL_19:.*]] = "stablehlo.reduce_window"(%[[VAL_17]], %[[VAL_18]]) ({
|
||||||
// CHECK: ^bb0(%[[IVAL_3:.*]]: tensor<f32>, %[[IVAL_4:.*]]: tensor<f32>):
|
// CHECK: ^bb0(%[[IVAL_3:.*]]: tensor<f32>, %[[IVAL_4:.*]]: tensor<f32>):
|
||||||
// CHECK: %[[IVAL_5:.*]] = mhlo.add %[[IVAL_3]], %[[IVAL_4]] : tensor<f32>
|
// CHECK: %[[IVAL_5:.*]] = stablehlo.add %[[IVAL_3]], %[[IVAL_4]] : tensor<f32>
|
||||||
// CHECK: mhlo.return %[[IVAL_5]] : tensor<f32>
|
// CHECK: stablehlo.return %[[IVAL_5]] : tensor<f32>
|
||||||
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
||||||
// CHECK: %[[VAL_20:.*]] = mhlo.divide %[[VAL_6]], %[[VAL_19]] : tensor<?x?x?x?xf32>
|
// CHECK: %[[VAL_20:.*]] = stablehlo.divide %[[VAL_6]], %[[VAL_19]] : tensor<?x?x?x?xf32>
|
||||||
// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
|
// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
|
||||||
// CHECK: return %[[VAL_21]] : !torch.vtensor<[?,?,?,?],f32>
|
// CHECK: return %[[VAL_21]] : !torch.vtensor<[?,?,?,?],f32>
|
||||||
func.func @torch.aten.avg_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
|
func.func @torch.aten.avg_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
|
||||||
|
@ -193,14 +193,14 @@ func.func @torch.aten.avg_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
|
||||||
// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list<int>
|
// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
// CHECK: %[[T2:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2]] : (!torch.int, !torch.int) -> !torch.list<int>
|
// CHECK: %[[T2:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
// CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
// CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
// CHECK: %[[T4:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
// CHECK: %[[T4:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
|
||||||
// CHECK: %[[T5:.*]] = "mhlo.reduce_window"(%[[T0]], %[[T4]]) ({
|
// CHECK: %[[T5:.*]] = "stablehlo.reduce_window"(%[[T0]], %[[T4]]) ({
|
||||||
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
|
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
|
||||||
// CHECK: %[[T10:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
|
// CHECK: %[[T10:.*]] = stablehlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
|
||||||
// CHECK: mhlo.return %[[T10]] : tensor<f32>
|
// CHECK: stablehlo.return %[[T10]] : tensor<f32>
|
||||||
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
||||||
// CHECK: %[[T6:.*]] = mhlo.constant dense<9> : tensor<i64>
|
// CHECK: %[[T6:.*]] = stablehlo.constant dense<9> : tensor<i64>
|
||||||
// CHECK: %[[T7:.*]] = mhlo.convert %[[T6]] : (tensor<i64>) -> tensor<f32>
|
// CHECK: %[[T7:.*]] = stablehlo.convert %[[T6]] : (tensor<i64>) -> tensor<f32>
|
||||||
// CHECK: %[[T8:.*]] = chlo.broadcast_divide %[[T5]], %[[T7]] : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
// CHECK: %[[T8:.*]] = chlo.broadcast_divide %[[T5]], %[[T7]] : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
||||||
// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
|
// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
|
||||||
// CHECK: return %[[T9]] : !torch.vtensor<[?,?,?,?],f32>
|
// CHECK: return %[[T9]] : !torch.vtensor<[?,?,?,?],f32>
|
||||||
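
The avg_pool2d checks above describe the same computation a user sees when compiling a pooling module through the StableHLO output path. A minimal sketch of driving that path from Python, assuming the torch_mlir.compile entry point accepts a "stablehlo" output type after this migration; both the entry point and the spelling of the output type are assumptions here, not shown in this diff:

import torch
import torch_mlir

class AvgPool(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Same kernel/stride/padding as the lit test above.
        self.pool = torch.nn.AvgPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.pool(x)

# "stablehlo" as an output type is an assumption about the post-migration API.
module = torch_mlir.compile(AvgPool(), torch.randn(1, 4, 16, 16),
                            output_type="stablehlo")
print(module)  # Expect stablehlo.reduce_window in the printed IR, as in the CHECK lines above.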

@@ -1,4 +1,4 @@
-// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s

 // CHECK-LABEL: func.func @torch.aten.slice.strided$slice_like(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
@@ -42,7 +42,7 @@
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T10]], %[[C0_I64_5]], %[[C0_I64_5]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T21]], %[[T18]], %[[T19]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64>
-// CHECK: %[[T22:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32>
+// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32>
 // CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
 // CHECK: return %[[T23]] : !torch.vtensor<[?,?,?],f32>
 func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
@@ -97,7 +97,7 @@ func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T10]], %[[C0_I64_5]], %[[C0_I64_5]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T21]], %[[T18]], %[[T19]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64>
-// CHECK: %[[T22:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<2x65x256xf32>
+// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<2x65x256xf32>
 // CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<2x65x256xf32> -> !torch.vtensor<[2,65,256],f32>
 // CHECK: return %[[T23]] : !torch.vtensor<[2,65,256],f32>
 func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[2,65,256],f32> {
@@ -152,7 +152,7 @@ func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,6
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_5]], %[[T10]], %[[C0_I64_5]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T17]], %[[T21]], %[[T19]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64>
-// CHECK: %[[T22:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x1x?xf32>
+// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x1x?xf32>
 // CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<?x1x?xf32> -> !torch.vtensor<[?,1,?],f32>
 // CHECK: return %[[T23]] : !torch.vtensor<[?,1,?],f32>
 func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,1,?],f32> {
@@ -207,7 +207,7 @@ func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>)
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_5]], %[[T10]], %[[C0_I64_5]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T17]], %[[T21]], %[[T19]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64>
-// CHECK: %[[T22:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x1x256xf32>
+// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x1x256xf32>
 // CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<4x1x256xf32> -> !torch.vtensor<[4,1,256],f32>
 // CHECK: return %[[T23]] : !torch.vtensor<[4,1,256],f32>
 func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,1,256],f32> {
@@ -247,7 +247,7 @@ func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,2
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_4]], %[[C0_I64]], %[[C0_I64_4]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_5:.*]] = tensor.from_elements %[[T3]], %[[T7]], %[[T5]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64>
-// CHECK: %[[T8:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS]]_6 : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32>
+// CHECK: %[[T8:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS]]_6 : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32>
 // CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
 // CHECK: return %[[T9]] : !torch.vtensor<[?,?,?],f32>
 func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
@@ -287,7 +287,7 @@ func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>)
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_4]], %[[C0_I64]], %[[C0_I64_4]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_5:.*]] = tensor.from_elements %[[T3]], %[[T7]], %[[T5]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64>
-// CHECK: %[[T8:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS]]_6 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x33x256xf32>
+// CHECK: %[[T8:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS]]_6 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x33x256xf32>
 // CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<4x33x256xf32> -> !torch.vtensor<[4,33,256],f32>
 // CHECK: return %[[T9]] : !torch.vtensor<[4,33,256],f32>
 func.func @torch.aten.slice.none.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,33,256],f32> {
@@ -313,8 +313,8 @@ func.func @torch.aten.slice.none.static$slice_like(%arg0: !torch.vtensor<[4,65,2
 // CHECK: %[[T5:.*]] = arith.muli %[[T4]], %[[T3]] : i64
 // CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : i64 to index
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]], %[[T3]] : tensor<2xi64>
-// CHECK: %[[T7:.*]] = mhlo.compute_reshape_shape %[[T6]], %[[FROM_ELEMENTS]] : (index, tensor<2xi64>) -> tensor<2xi64>
+// CHECK: %[[T7:.*]] = stablehlo.compute_reshape_shape %[[T6]], %[[FROM_ELEMENTS]] : (index, tensor<2xi64>) -> tensor<2xi64>
-// CHECK: %[[T8:.*]] = mhlo.dynamic_reshape %[[T0]], %[[T7]] : (tensor<?x?x?x?xf32>, tensor<2xi64>) -> tensor<?x224xf32>
+// CHECK: %[[T8:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[T7]] : (tensor<?x?x?x?xf32>, tensor<2xi64>) -> tensor<?x224xf32>
 // CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x224xf32> -> !torch.vtensor<[?,224],f32>
 // CHECK: return %[[T9]] : !torch.vtensor<[?,224],f32>
 func.func @torch.aten.view$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,224],f32> {
@@ -346,8 +346,8 @@ func.func @torch.aten.view$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
 // CHECK: %[[T9:.*]] = arith.muli %[[T8]], %[[T5]] : i64
 // CHECK: %[[T10:.*]] = arith.index_cast %[[T9]] : i64 to index
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]], %[[T3]], %[[T4]], %[[T5]] : tensor<4xi64>
-// CHECK: %[[T11:.*]] = mhlo.compute_reshape_shape %[[T10]], %[[FROM_ELEMENTS]] : (index, tensor<4xi64>) -> tensor<4xi64>
+// CHECK: %[[T11:.*]] = stablehlo.compute_reshape_shape %[[T10]], %[[FROM_ELEMENTS]] : (index, tensor<4xi64>) -> tensor<4xi64>
-// CHECK: %[[T12:.*]] = mhlo.dynamic_reshape %[[T0]], %[[T11]] : (tensor<?x?x?x?x?xf32>, tensor<4xi64>) -> tensor<?x120x4x64xf32>
+// CHECK: %[[T12:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[T11]] : (tensor<?x?x?x?x?xf32>, tensor<4xi64>) -> tensor<?x120x4x64xf32>
 // CHECK: %[[T13:.*]] = torch_c.from_builtin_tensor %[[T12]] : tensor<?x120x4x64xf32> -> !torch.vtensor<[?,120,4,64],f32>
 // CHECK: return %[[T13]] : !torch.vtensor<[?,120,4,64],f32>
 func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.vtensor<[?,120,4,64],f32> {
@@ -367,7 +367,7 @@ func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[],f32> -> tensor<f32>
 // CHECK: %[[INT1:.*]] = torch.constant.int 1
 // CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list<int>
-// CHECK: %[[T2:.*]] = mhlo.reshape %[[T0]] : (tensor<f32>) -> tensor<1xf32>
+// CHECK: %[[T2:.*]] = stablehlo.reshape %[[T0]] : (tensor<f32>) -> tensor<1xf32>
 // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
 // CHECK: return %[[T3]] : !torch.vtensor<[1],f32>
 func.func @torch.aten.view$to_rank1(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
@@ -383,7 +383,7 @@ func.func @torch.aten.view$to_rank1(%arg0: !torch.vtensor<[],f32>) -> !torch.vte
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1],f32>) -> !torch.vtensor<[],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1],f32> -> tensor<1xf32>
 // CHECK: %[[T1:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
-// CHECK: %[[T2:.*]] = mhlo.reshape %[[T0]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T2:.*]] = stablehlo.reshape %[[T0]] : (tensor<1xf32>) -> tensor<f32>
 // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<f32> -> !torch.vtensor<[],f32>
 // CHECK: return %[[T3]] : !torch.vtensor<[],f32>
 func.func @torch.aten.view$to_rank0(%arg0: !torch.vtensor<[1],f32>) -> !torch.vtensor<[],f32> {
@@ -425,7 +425,7 @@ func.func @torch.aten.squeeze.dim$0$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32
 // CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<?x1x?x1x?xf32>
 // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<4xi64>
-// CHECK: %[[T5:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x1x?x1x?xf32>, tensor<4xi64>) -> tensor<?x?x1x?xf32>
+// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x1x?x1x?xf32>, tensor<4xi64>) -> tensor<?x?x1x?xf32>
 // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?x1x?xf32> -> !torch.vtensor<[?,?,1,?],f32>
 // CHECK: return %[[T6]] : !torch.vtensor<[?,?,1,?],f32>
 func.func @torch.aten.squeeze.dim$1(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,?,1,?],f32> {
@@ -453,7 +453,7 @@ func.func @torch.aten.squeeze.dim$1(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !
 // CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<?x1x?x1x?xf32>
 // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<4xi64>
-// CHECK: %[[T5:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x1x?x1x?xf32>, tensor<4xi64>) -> tensor<?x1x?x?xf32>
+// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x1x?x1x?xf32>, tensor<4xi64>) -> tensor<?x1x?x?xf32>
 // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x1x?x?xf32> -> !torch.vtensor<[?,1,?,?],f32>
 // CHECK: return %[[T6]] : !torch.vtensor<[?,1,?,?],f32>
 func.func @torch.aten.squeeze.dim$from_end(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,1,?,?],f32> {
@@ -477,7 +477,7 @@ func.func @torch.aten.squeeze.dim$from_end(%arg0: !torch.vtensor<[?,1,?,1,?],f32
 // CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<2x1x2x1x2xf32>
 // CHECK: %[[T3:.*]] = arith.index_cast %[[DIM_1]] : index to i64
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]] : tensor<3xi64>
-// CHECK: %[[T4:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<2x1x2x1x2xf32>, tensor<3xi64>) -> tensor<2x2x2xf32>
+// CHECK: %[[T4:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<2x1x2x1x2xf32>, tensor<3xi64>) -> tensor<2x2x2xf32>
 // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<2x2x2xf32> -> !torch.vtensor<[2,2,2],f32>
 // CHECK: return %[[T5]] : !torch.vtensor<[2,2,2],f32>
 func.func @torch.aten.squeeze$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,2,2],f32> {
@@ -505,7 +505,7 @@ func.func @torch.aten.squeeze$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) ->
 // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64
 // CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<5xi64>
-// CHECK: %[[T5:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<1x?x?x?x?xf32>
+// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<1x?x?x?x?xf32>
 // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<1x?x?x?x?xf32> -> !torch.vtensor<[1,?,?,?,?],f32>
 // CHECK: return %[[T6]] : !torch.vtensor<[1,?,?,?,?],f32>
 func.func @torch.aten.unsqueeze$dim$0(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1,?,?,?,?],f32> {
@@ -534,7 +534,7 @@ func.func @torch.aten.unsqueeze$dim$0(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !
 // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64
 // CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[C1_I64]], %[[T2]], %[[T3]], %[[T4]] : tensor<5xi64>
-// CHECK: %[[T5:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<?x1x?x?x?xf32>
+// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<?x1x?x?x?xf32>
 // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x1x?x?x?xf32> -> !torch.vtensor<[?,1,?,?,?],f32>
 // CHECK: return %[[T6]] : !torch.vtensor<[?,1,?,?,?],f32>
 func.func @torch.aten.unsqueeze$dim$1(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,1,?,?,?],f32> {
@@ -563,7 +563,7 @@ func.func @torch.aten.unsqueeze$dim$1(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !
 // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64
 // CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]], %[[C1_I64]], %[[T4]] : tensor<5xi64>
-// CHECK: %[[T5:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<?x?x?x1x?xf32>
+// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<?x?x?x1x?xf32>
 // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?x?x1x?xf32> -> !torch.vtensor<[?,?,?,1,?],f32>
 // CHECK: return %[[T6]] : !torch.vtensor<[?,?,?,1,?],f32>
 func.func @torch.aten.unsqueeze$from_end(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,1,?],f32> {

@@ -17,7 +17,7 @@ config.llvm_exe_ext = "@EXEEXT@"
 config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@"
 config.python_executable = "@Python3_EXECUTABLE@"
 config.enable_jit_ir_importer = @TORCH_MLIR_ENABLE_JIT_IR_IMPORTER@
-config.enable_mhlo = @TORCH_MLIR_ENABLE_MHLO@
+config.enable_stablehlo = @TORCH_MLIR_ENABLE_STABLEHLO@

 import lit.llvm
 lit.llvm.initialize(lit_config, config)
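
The substituted config.enable_stablehlo boolean gives the lit configuration a value to key on when deciding whether StableHLO-dependent tests run. A short sketch of how such a flag is typically consumed in a lit.cfg.py; the feature name and the exclusion pattern below are illustrative assumptions, not part of this change:

# lit.cfg.py (sketch): gate tests on the value substituted above.
if config.enable_stablehlo:
    # Tests can then opt in with "REQUIRES: stablehlo".
    config.available_features.add("stablehlo")
else:
    # Hypothetical exclusion so StableHLO conversion tests are skipped
    # when the project is configured with -DTORCH_MLIR_ENABLE_STABLEHLO=OFF.
    config.excludes.append("TorchToStablehlo")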

@@ -268,7 +268,7 @@ gentbl_cc_library(
 (
 [
 "-gen-pass-decls",
-"-DTORCH_MLIR_ENABLE_MHLO",
+"-DTORCH_MLIR_ENABLE_STABLEHLO",
 ],
 "include/torch-mlir/Conversion/Passes.h.inc",
 ),
@@ -434,13 +434,13 @@ cc_library(
 )

 cc_library(
-name = "TorchMLIRTorchToMhlo",
+name = "TorchMLIRTorchToStablehlo",
 srcs = glob([
 "lib/Conversion/*.h",
-"lib/Conversion/TorchToMhlo/*.h",
+"lib/Conversion/TorchToStablehlo/*.h",
-"lib/Conversion/TorchToMhlo/*.cpp",
+"lib/Conversion/TorchToStablehlo/*.cpp",
 ]),
-hdrs = glob(["include/torch-mlir/Conversion/TorchToMhlo/*.h"]),
+hdrs = glob(["include/torch-mlir/Conversion/TorchToStablehlo/*.h"]),
 strip_include_prefix = "include",
 deps = [
 ":TorchMLIRConversionPassesIncGen",
@@ -465,8 +465,8 @@ cc_library(
 ":TorchMLIRTorchConversionToMLProgram",
 ":TorchMLIRTorchToArith",
 ":TorchMLIRTorchToLinalg",
-":TorchMLIRTorchToMhlo",
 ":TorchMLIRTorchToSCF",
+":TorchMLIRTorchToStablehlo",
 ":TorchMLIRTorchToTMTensor",
 ":TorchMLIRTorchToTosa",
 ],
@@ -489,8 +489,8 @@ cc_library(
 ":TorchMLIRTorchPasses",
 ":TorchMLIRTorchToArith",
 ":TorchMLIRTorchToLinalg",
-":TorchMLIRTorchToMhlo",
 ":TorchMLIRTorchToSCF",
+":TorchMLIRTorchToStablehlo",
 ":TorchMLIRTorchToTMTensor",
 ":TorchMLIRTorchToTosa",
 "@llvm-project//mlir:ConversionPasses",

@@ -23,7 +23,7 @@ expand_template(
 # All disabled, but required to substituted because they are not in quotes.
 "@MLIR_ENABLE_BINDINGS_PYTHON@": "0",
 "@TORCH_MLIR_ENABLE_JIT_IR_IMPORTER@": "0",
-"@TORCH_MLIR_ENABLE_MHLO@": "0",
+"@TORCH_MLIR_ENABLE_STABLEHLO@": "0",
 },
 template = "lit.site.cfg.py.in",
 )