mhlo: migrate conversion to stablehlo (#1840)

This patch replaces all MHLO operations with their StableHLO
counterparts and adds a validation pass to ensure that no MHLO operations
remain before translating all Stablehlo operations to the MHLO dialect
for further lowering to the Linalg dialect.

This patch also updates all lit tests so that they refer to the
`convert-torch-to-stablehlo` pass and so that they check for StableHLO
operations.
pull/1851/head
Ashay Rane 2023-02-02 07:29:47 -06:00 committed by GitHub
parent ed9d8d1fb7
commit 711646d095
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
55 changed files with 1190 additions and 1136 deletions

View File

@ -113,7 +113,7 @@ jobs:
-DLLVM_USE_HOST_TOOLS=ON \ -DLLVM_USE_HOST_TOOLS=ON \
-DLLVM_ENABLE_ZSTD=OFF \ -DLLVM_ENABLE_ZSTD=OFF \
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \
-DTORCH_MLIR_ENABLE_MHLO=OFF \ -DTORCH_MLIR_ENABLE_STABLEHLO=OFF \
-DTORCH_MLIR_ENABLE_LTC=OFF \ -DTORCH_MLIR_ENABLE_LTC=OFF \
-DTORCH_MLIR_USE_INSTALLED_PYTORCH="${{ matrix.torch-binary }}" \ -DTORCH_MLIR_USE_INSTALLED_PYTORCH="${{ matrix.torch-binary }}" \
-DMACOSX_DEPLOYMENT_TARGET=12.0 \ -DMACOSX_DEPLOYMENT_TARGET=12.0 \

View File

@ -36,9 +36,9 @@ macro(torch_mlir_add_llvm_external_project name identifier location)
set(LLVM_EXTERNAL_PROJECTS ${LLVM_EXTERNAL_PROJECTS} CACHE STRING "" FORCE) set(LLVM_EXTERNAL_PROJECTS ${LLVM_EXTERNAL_PROJECTS} CACHE STRING "" FORCE)
endmacro() endmacro()
option(TORCH_MLIR_ENABLE_MHLO "Add mhlo dialect" ON) option(TORCH_MLIR_ENABLE_STABLEHLO "Add stablehlo dialect" ON)
if(TORCH_MLIR_ENABLE_MHLO) if(TORCH_MLIR_ENABLE_STABLEHLO)
add_definitions(-DTORCH_MLIR_ENABLE_MHLO) add_definitions(-DTORCH_MLIR_ENABLE_STABLEHLO)
endif() endif()
option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF) option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF)
@ -128,8 +128,8 @@ else()
set(MLIR_INCLUDE_DIRS "${MLIR_INCLUDE_DIR};${MLIR_GENERATED_INCLUDE_DIR}") set(MLIR_INCLUDE_DIRS "${MLIR_INCLUDE_DIR};${MLIR_GENERATED_INCLUDE_DIR}")
endif() endif()
if (TORCH_MLIR_ENABLE_MHLO) if (TORCH_MLIR_ENABLE_STABLEHLO)
set(MHLO_BUILD_EMBEDDED ON) set(STABLEHLO_BUILD_EMBEDDED ON)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo
${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo ${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo
EXCLUDE_FROM_ALL) EXCLUDE_FROM_ALL)

View File

@ -267,8 +267,8 @@ function test_in_tree() {
echo ":::: Run Linalg e2e integration tests" echo ":::: Run Linalg e2e integration tests"
python -m e2e_testing.main --config=linalg -v python -m e2e_testing.main --config=linalg -v
echo ":::: Run MHLO e2e integration tests" echo ":::: Run StableHLO e2e integration tests"
python -m e2e_testing.main --config=mhlo -v python -m e2e_testing.main --config=stablehlo -v
echo ":::: Run TOSA e2e integration tests" echo ":::: Run TOSA e2e integration tests"
python -m e2e_testing.main --config=tosa -v python -m e2e_testing.main --config=tosa -v

View File

@ -30,14 +30,14 @@ it to various target dialects of interest to the MLIR ecosystem (various
- Linalg-on-Tensors (+ `arith`, `tensor`, etc.) - Linalg-on-Tensors (+ `arith`, `tensor`, etc.)
- [TOSA](https://mlir.llvm.org/docs/Dialects/TOSA/) - [TOSA](https://mlir.llvm.org/docs/Dialects/TOSA/)
- [MHLO](https://github.com/tensorflow/mlir-hlo) - [StableHLO](https://github.com/openxla/stablehlo)
The terms "frontend" and "backend" are highly overloaded in any compiler The terms "frontend" and "backend" are highly overloaded in any compiler
project, but frequently in Torch-MLIR this is the meaning that they have. project, but frequently in Torch-MLIR this is the meaning that they have.
Sometimes "frontend" can mean something even further up the stack, such as Sometimes "frontend" can mean something even further up the stack, such as
something in PyTorch itself. When there is ambiguity we will refer to this as something in PyTorch itself. When there is ambiguity we will refer to this as
"at the PyTorch level". Similarly, "backend" can sometimes refer to something "at the PyTorch level". Similarly, "backend" can sometimes refer to something
sitting below Linalg-on-Tensors, TOSA, or MHLO. sitting below Linalg-on-Tensors, TOSA, or StableHLO.
## The `torch` dialect ## The `torch` dialect
@ -118,8 +118,8 @@ See [satisfiesBackendContract](https://github.com/llvm/torch-mlir/blob/114f48e96
The backend contract is a normalized form of the `torch` dialect with a set of The backend contract is a normalized form of the `torch` dialect with a set of
properties that make it easy to lower into various forms such as properties that make it easy to lower into various forms such as
Linalg-on-Tensors, TOSA, MHLO, or other forms that we don't provide out of the Linalg-on-Tensors, TOSA, StableHLO, or other forms that we don't provide out of
box. The primary guarantees that we provide Torch-MLIR's backends are: the box. The primary guarantees that we provide Torch-MLIR's backends are:
- All tensors have been converted to value semantics. - All tensors have been converted to value semantics.
- All tensors have at least a known number of dimensions (i.e. rank), and - All tensors have at least a known number of dimensions (i.e. rank), and
@ -270,7 +270,7 @@ lower it to the requirements of each backend. The 3 backends are:
- [`linalg`](https://mlir.llvm.org/docs/Dialects/Linalg/) on tensors (+ `arith`, - [`linalg`](https://mlir.llvm.org/docs/Dialects/Linalg/) on tensors (+ `arith`,
`tensor`, etc.) `tensor`, etc.)
- [TOSA](https://mlir.llvm.org/docs/Dialects/TOSA/) - [TOSA](https://mlir.llvm.org/docs/Dialects/TOSA/)
- [MHLO](https://github.com/tensorflow/mlir-hlo) - [StableHLO](https://github.com/openxla/stablehlo)
### The Linalg Backend (Linalg-on-Tensors) ### The Linalg Backend (Linalg-on-Tensors)
@ -297,15 +297,15 @@ many users (especially "hardware" or "hardware-adjacent" folks). Some of its cha
- It is extremely solid with static shapes (and many of its users only care - It is extremely solid with static shapes (and many of its users only care
about static shapes, so that's fine). about static shapes, so that's fine).
### The MHLO Backend ### The StableHLO Backend
Code: https://github.com/llvm/torch-mlir/tree/main/lib/Conversion/TorchToMhlo Code: https://github.com/llvm/torch-mlir/tree/main/lib/Conversion/TorchToStablehlo
The MHLO backend was the third backend that we added, and it offers a reasonable The StableHLO backend was the third backend that we added, and it offers a
blend of the benefits of the other two. reasonable blend of the benefits of the other two.
- It is a coarse-grained named-op approach. - It is a coarse-grained named-op approach.
- It has a pretty clear spec for most of the ops (with a bit of mental - It has a pretty clear spec for most of the ops (with a bit of mental
translation and hoping that MHLO is the same as HLO): translation and hoping that StableHLO is the same as HLO):
https://www.tensorflow.org/xla/operation_semantics https://www.tensorflow.org/xla/operation_semantics
- It functionally supports dynamic shapes (though not as coherent and consistent - It functionally supports dynamic shapes (though not as coherent and consistent
as Linalg-on-Tensors, and the dynamic shape support falls outside the as Linalg-on-Tensors, and the dynamic shape support falls outside the
@ -317,7 +317,7 @@ blend of the benefits of the other two.
example, TOSA limits (for highly considered reasons) the number of dimensions example, TOSA limits (for highly considered reasons) the number of dimensions
that certain operators can handle to 1D-4D, when from a purely algebraic that certain operators can handle to 1D-4D, when from a purely algebraic
perspective there isn't a good reason to not be more general. Similarly, more perspective there isn't a good reason to not be more general. Similarly, more
general forms of reduction and scatter also fall into MHLO nicely while general forms of reduction and scatter also fall into StableHLO nicely while
TOSA's principles tend to bias it away from that. TOSA's principles tend to bias it away from that.
### Backend Implementation ### Backend Implementation
@ -433,8 +433,9 @@ filling in some corners missing upstream and
to pull together upstream functionality into a working system. to pull together upstream functionality into a working system.
The RefBackend accepts Linalg-on-Tensors as input. It mainly just bufferizes the The RefBackend accepts Linalg-on-Tensors as input. It mainly just bufferizes the
ops and lowers them to loops. Note that TOSA and MHLO support lowering to ops and lowers them to loops. Note that TOSA and StableHLO (via MHLO) support
Linalg-on-Tensors, so all our end-to-end testing bottoms out on RefBackend. lowering to Linalg-on-Tensors, so all our end-to-end testing bottoms out on
RefBackend.
The RefBackend is absolutely not suitable for any production use case. It leaks The RefBackend is absolutely not suitable for any production use case. It leaks
memory, doesn't support any error handling, performs no optimizations, and memory, doesn't support any error handling, performs no optimizations, and

View File

@ -34,7 +34,7 @@ and Clang's
- Eric Kunze (@eric-k256) - Eric Kunze (@eric-k256)
- Suraj Sudhir (@sjarus) - Suraj Sudhir (@sjarus)
### TorchToMHLO ### TorchToStablehlo
- Tianyo Kwok (@tanyokwok) - Tianyo Kwok (@tanyokwok)
- Ziheng Jiang (@ZihengJiang) - Ziheng Jiang (@ZihengJiang)

View File

@ -139,7 +139,7 @@ Ex:
module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch") module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
``` ```
Currently, `output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `MHLO`. Currently, `output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `STABLEHLO`.
## Jupyter ## Jupyter

View File

@ -46,7 +46,7 @@ the ecosystem are:
- The frontend work required to lower TorchScript to the backend contract. - The frontend work required to lower TorchScript to the backend contract.
- The irregular support surface area of the large number of PyTorch ops across - The irregular support surface area of the large number of PyTorch ops across
the Linalg, TOSA, and MHLO backends. the Linalg, TOSA, and StableHLO backends.
Most of this document describes long-term ecosystem changes that will address Most of this document describes long-term ecosystem changes that will address
these, drastically improving Torch-MLIR's ability to meet its goals. these, drastically improving Torch-MLIR's ability to meet its goals.
@ -108,7 +108,7 @@ more advanced).
### Refactoring the backend ### Refactoring the backend
Today in Torch-MLIR, we support 3 backends out of the box: Linalg-on-Tensors, Today in Torch-MLIR, we support 3 backends out of the box: Linalg-on-Tensors,
TOSA, and MHLO. These backends take IR in the backend contract form (see TOSA, and StableHLO. These backends take IR in the backend contract form (see
[architecture.md](architecture.md)) and lowers them to the respective dialects. [architecture.md](architecture.md)) and lowers them to the respective dialects.
Today, each backend is implemented completely independently. This leads to Today, each backend is implemented completely independently. This leads to
duplication and irregularity across the backends. duplication and irregularity across the backends.
@ -120,12 +120,10 @@ lowering of so many ops across backends. Additionally, there are 3
forward-looking efforts that intersect with this effort: forward-looking efforts that intersect with this effort:
- [StableHLO](https://github.com/openxla/stablehlo) - this is a dialect - [StableHLO](https://github.com/openxla/stablehlo) - this is a dialect
initially forked from MHLO which intends to create a stable support surface initially forked from MHLO. MHLO is a fairly complete op set, so it is very
area for what today is our "at head" dependency on MHLO. MHLO is a fairly attractive to have "almost all" models bottleneck through a stable interface
complete op set, so it is very attractive to have "almost all" models like StableHLO. StableHLO is currently under relatively early development,
bottleneck through a stable interface like StableHLO. StableHLO is currently but already delivers on many of the goals of stability.
under relatively early development, but already delivers on many of the goals
of stability.
- [TCP](https://github.com/llvm/torch-mlir/issues/1366) - this is a dialect - [TCP](https://github.com/llvm/torch-mlir/issues/1366) - this is a dialect
which could serve a role very similar to MHLO, while providing community which could serve a role very similar to MHLO, while providing community
ownership. TCP is still in early planning phases, but there is strong ownership. TCP is still in early planning phases, but there is strong

View File

@ -16,7 +16,7 @@ from torch_mlir_e2e_test.registry import GLOBAL_TEST_REGISTRY
from torch_mlir_e2e_test.configs import ( from torch_mlir_e2e_test.configs import (
LazyTensorCoreTestConfig, LazyTensorCoreTestConfig,
LinalgOnTensorsBackendTestConfig, LinalgOnTensorsBackendTestConfig,
MhloBackendTestConfig, StablehloBackendTestConfig,
NativeTorchTestConfig, NativeTorchTestConfig,
TorchScriptTestConfig, TorchScriptTestConfig,
TosaBackendTestConfig, TosaBackendTestConfig,
@ -24,17 +24,17 @@ from torch_mlir_e2e_test.configs import (
) )
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
from torch_mlir_e2e_test.mhlo_backends.linalg_on_tensors import LinalgOnTensorsMhloBackend from torch_mlir_e2e_test.stablehlo_backends.linalg_on_tensors import LinalgOnTensorsStablehloBackend
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend
from .xfail_sets import LINALG_XFAIL_SET, MHLO_PASS_SET, TOSA_PASS_SET, LTC_XFAIL_SET, TORCHDYNAMO_XFAIL_SET from .xfail_sets import LINALG_XFAIL_SET, STABLEHLO_PASS_SET, TOSA_PASS_SET, LTC_XFAIL_SET, TORCHDYNAMO_XFAIL_SET
# Import tests to register them in the global registry. # Import tests to register them in the global registry.
from torch_mlir_e2e_test.test_suite import register_all_tests from torch_mlir_e2e_test.test_suite import register_all_tests
register_all_tests() register_all_tests()
def _get_argparse(): def _get_argparse():
config_choices = ["native_torch", "torchscript", "linalg", "mhlo", "tosa", "lazy_tensor_core", "torchdynamo"] config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "tosa", "lazy_tensor_core", "torchdynamo"]
parser = argparse.ArgumentParser(description="Run torchscript e2e tests.") parser = argparse.ArgumentParser(description="Run torchscript e2e tests.")
parser.add_argument("-c", "--config", parser.add_argument("-c", "--config",
choices=config_choices, choices=config_choices,
@ -42,7 +42,7 @@ def _get_argparse():
help=f""" help=f"""
Meaning of options: Meaning of options:
"linalg": run through torch-mlir"s default Linalg-on-Tensors backend. "linalg": run through torch-mlir"s default Linalg-on-Tensors backend.
"mhlo": run through torch-mlir"s default MHLO backend. "stablehlo": run through torch-mlir"s default StableHLO backend.
"tosa": run through torch-mlir"s default TOSA backend. "tosa": run through torch-mlir"s default TOSA backend.
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration). "native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly). "torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
@ -80,9 +80,9 @@ def main():
if args.config == "tosa": if args.config == "tosa":
config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend()) config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend())
xfail_set = all_test_unique_names - TOSA_PASS_SET xfail_set = all_test_unique_names - TOSA_PASS_SET
if args.config == "mhlo": if args.config == "stablehlo":
config = MhloBackendTestConfig(LinalgOnTensorsMhloBackend()) config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend())
xfail_set = all_test_unique_names - MHLO_PASS_SET xfail_set = all_test_unique_names - STABLEHLO_PASS_SET
elif args.config == "native_torch": elif args.config == "native_torch":
config = NativeTorchTestConfig() config = NativeTorchTestConfig()
xfail_set = {} xfail_set = {}

View File

@ -87,8 +87,10 @@ TORCHDYNAMO_XFAIL_SET = {
"StdCorrectionKeepDimModule_basic", "StdCorrectionKeepDimModule_basic",
} }
MHLO_PASS_SET = { STABLEHLO_PASS_SET = {
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AddSizeIntModule_basic",
"AddSizeIntNegDimModule_basic",
"ArangeDtypeFloatModule_basic", "ArangeDtypeFloatModule_basic",
"ArangeDtypeIntModule_basic", "ArangeDtypeIntModule_basic",
"ArangeFalsePinMemoryModule_basic", "ArangeFalsePinMemoryModule_basic",
@ -103,6 +105,7 @@ MHLO_PASS_SET = {
"ArangeStartStepFloatModule_basic", "ArangeStartStepFloatModule_basic",
"ArangeStartStepIntModule_basic", "ArangeStartStepIntModule_basic",
"ArangeZeroElementOutputModule_basic", "ArangeZeroElementOutputModule_basic",
"BatchMlpLayerModule_basic",
"BmmModule_basic", "BmmModule_basic",
"BroadcastToModule_basic", "BroadcastToModule_basic",
"BroadcastToSameRankStaticModule_basic", "BroadcastToSameRankStaticModule_basic",
@ -124,12 +127,15 @@ MHLO_PASS_SET = {
"ElementwiseClampMinModule_basic", "ElementwiseClampMinModule_basic",
"ElementwiseClampMaxModule_basic", "ElementwiseClampMaxModule_basic",
"ElementwiseExpModule_basic", "ElementwiseExpModule_basic",
"ElementwiseFlattenBroadcastModule_basic",
"ElementwiseLeakyReluModule_basic",
"ElementwiseLogModule_basic", "ElementwiseLogModule_basic",
"ElementwiseNegModule_basic", "ElementwiseNegModule_basic",
"ElementwiseRsqrtModule_basic", "ElementwiseRsqrtModule_basic",
"ElementwiseSigmoidModule_basic", "ElementwiseSigmoidModule_basic",
"ElementwiseSqrtModule_basic", "ElementwiseSqrtModule_basic",
"ElementwiseUnaryModule_basic", "ElementwiseUnaryModule_basic",
"ElementwiseUnsqueezeBroadcastModule_basic",
"ElementwiseUnsqueezeNegDimsModule_basic", "ElementwiseUnsqueezeNegDimsModule_basic",
"ElementwiseToDtypeF32ToI64Module_basic", "ElementwiseToDtypeF32ToI64Module_basic",
"ElementwiseAddModule_basic", "ElementwiseAddModule_basic",
@ -198,6 +204,8 @@ MHLO_PASS_SET = {
"Gather2DInputModdule_basic", "Gather2DInputModdule_basic",
"GatherRandomIndexModule_basic", "GatherRandomIndexModule_basic",
"GeluBackwardModule_basic", "GeluBackwardModule_basic",
"HardswishModule_basic",
"HardswishRandomModule_basic",
"HardTanhIntModule_basic", "HardTanhIntModule_basic",
"HardTanhModule_basic", "HardTanhModule_basic",
"HardsigmoidModule_basic", "HardsigmoidModule_basic",
@ -220,6 +228,8 @@ MHLO_PASS_SET = {
"MeanDynamicSizesModule_basic", "MeanDynamicSizesModule_basic",
"MeanLargeInputModule_basic", "MeanLargeInputModule_basic",
"MeanModule_basic", "MeanModule_basic",
"Mlp1LayerModule_basic",
"Mlp2LayerModule_basic",
"MmTanhModule_basic", "MmTanhModule_basic",
"Mv_basic", "Mv_basic",
"NativeLayerNormModule4D_basic", "NativeLayerNormModule4D_basic",
@ -251,6 +261,8 @@ MHLO_PASS_SET = {
"LiftFreshCopyModule_basic", "LiftFreshCopyModule_basic",
"Mlp2LayerModuleNoBias_basic", "Mlp2LayerModuleNoBias_basic",
"NumelModule_basic", "NumelModule_basic",
"SiluModule_basic",
"SquareModule_basic",
"SqueezeModule_allUnitDim", "SqueezeModule_allUnitDim",
"SqueezeDimModule_unitDim", "SqueezeDimModule_unitDim",
"ViewCollapseOnesMiddleModule_basic", "ViewCollapseOnesMiddleModule_basic",
@ -420,6 +432,7 @@ MHLO_PASS_SET = {
"UnsafeViewDynamicExpandModule_basic", "UnsafeViewDynamicExpandModule_basic",
"AtenRoundIntModule_basic", "AtenRoundIntModule_basic",
"TestF16Return_basic", "TestF16Return_basic",
"_LogSoftmaxModuleStable_basic",
} }
# Write the TOSA set as a "passing" set as it is very early in development # Write the TOSA set as a "passing" set as it is very early in development

View File

@ -1,14 +0,0 @@
import torch
import torchvision.models as models
import torch_mlir
model = models.resnet18(pretrained=True)
model.eval()
data = torch.randn(2,3,200,200)
out_mhlo_mlir_path = "./resnet18_mhlo.mlir"
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=False)
with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf:
outf.write(str(module))
print(f"MHLO IR of resent18 successfully written into {out_mhlo_mlir_path}")

View File

@ -0,0 +1,14 @@
import torch
import torchvision.models as models
import torch_mlir
model = models.resnet18(pretrained=True)
model.eval()
data = torch.randn(2,3,200,200)
out_stablehlo_mlir_path = "./resnet18_stablehlo.mlir"
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=False)
with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
outf.write(str(module))
print(f"StableHLO IR of resent18 successfully written into {out_stablehlo_mlir_path}")

View File

@ -15,10 +15,10 @@ class BertTinyWrapper(torch.nn.Module):
model = BertTinyWrapper() model = BertTinyWrapper()
model.eval() model.eval()
data = torch.randint(30522, (2, 128)) data = torch.randint(30522, (2, 128))
out_mhlo_mlir_path = "./bert_tiny_mhlo.mlir" out_stablehlo_mlir_path = "./bert_tiny_stablehlo.mlir"
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=True) module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=True)
with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf: with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
outf.write(str(module)) outf.write(str(module))
print(f"MHLO IR of tiny bert successfully written into {out_mhlo_mlir_path}") print(f"StableHLO IR of tiny bert successfully written into {out_stablehlo_mlir_path}")

View File

@ -1,6 +1,6 @@
set(LLVM_TARGET_DEFINITIONS Passes.td) set(LLVM_TARGET_DEFINITIONS Passes.td)
if(TORCH_MLIR_ENABLE_MHLO) if(TORCH_MLIR_ENABLE_STABLEHLO)
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_MHLO) mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO)
else() else()
mlir_tablegen(Passes.h.inc -gen-pass-decls) mlir_tablegen(Passes.h.inc -gen-pass-decls)
endif() endif()

View File

@ -133,13 +133,13 @@ def ConvertTorchConversionToMLProgram : Pass<"convert-torch-conversion-to-mlprog
let constructor = "mlir::torch::createConvertTorchConversionToMLProgramPass()"; let constructor = "mlir::torch::createConvertTorchConversionToMLProgramPass()";
} }
#ifdef TORCH_MLIR_ENABLE_MHLO #ifdef TORCH_MLIR_ENABLE_STABLEHLO
def ConvertTorchToMhlo : Pass<"convert-torch-to-mhlo", "func::FuncOp"> { def ConvertTorchToStablehlo : Pass<"convert-torch-to-stablehlo", "func::FuncOp"> {
let summary = "Convert Torch ops to MHLO ops"; let summary = "Convert Torch ops to Stablehlo ops";
let description = [{ let description = [{
Convert Torch ops to mhlo ops. Convert Torch ops to Stablehlo ops.
}]; }];
let constructor = "mlir::torch::createConvertTorchToMhloPass()"; let constructor = "mlir::torch::createConvertTorchToStablehloPass()";
// Specify any options. // Specify any options.
let options = [ let options = [

View File

@ -7,8 +7,8 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H #ifndef TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H
#define TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H #define TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
@ -16,10 +16,11 @@
namespace mlir { namespace mlir {
namespace torch { namespace torch {
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToMhloPass();
std::unique_ptr<OperationPass<func::FuncOp>> std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchToMhloPass(bool enableStaticShape, bool enableI32Index); createConvertTorchToStablehloPass();
std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchToStablehloPass(bool enableStaticShape, bool enableI32Index);
} // namespace torch } // namespace torch
} // namespace mlir } // namespace mlir
#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H #endif // TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H

View File

@ -1,6 +1,6 @@
set(LLVM_TARGET_DEFINITIONS Passes.td) set(LLVM_TARGET_DEFINITIONS Passes.td)
if(TORCH_MLIR_ENABLE_MHLO) if(TORCH_MLIR_ENABLE_STABLEHLO)
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_MHLO) mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO)
else() else()
mlir_tablegen(Passes.h.inc -gen-pass-decls) mlir_tablegen(Passes.h.inc -gen-pass-decls)
endif() endif()

View File

@ -30,10 +30,10 @@ void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm);
/// TOSA backend contract. /// TOSA backend contract.
void createTorchBackendToTosaBackendPipeline(OpPassManager &pm); void createTorchBackendToTosaBackendPipeline(OpPassManager &pm);
// Do not register the torch-to-mhlo pipeline if mhlo target is disabled // Do not register the stablehlo options if the stablehlo target is disabled
#ifdef TORCH_MLIR_ENABLE_MHLO #ifdef TORCH_MLIR_ENABLE_STABLEHLO
struct MhloBackendPipelineOptions struct StablehloBackendPipelineOptions
: public PassPipelineOptions<MhloBackendPipelineOptions> { : public PassPipelineOptions<StablehloBackendPipelineOptions> {
Option<bool> enableStaticShape{ Option<bool> enableStaticShape{
*this, "enable-static-shape", *this, "enable-static-shape",
llvm::cl::desc("Enable static shape conversion."), llvm::cl::init(false)}; llvm::cl::desc("Enable static shape conversion."), llvm::cl::init(false)};
@ -46,9 +46,10 @@ struct MhloBackendPipelineOptions
llvm::cl::init(false)}; llvm::cl::init(false)};
}; };
void createTorchBackendToMhloBackendPipeline( void createTorchBackendToStablehloBackendPipeline(
OpPassManager &pm, const MhloBackendPipelineOptions &options); OpPassManager &pm, const StablehloBackendPipelineOptions &options);
std::unique_ptr<OperationPass<ModuleOp>> createVerifyMhloBackendContractPass(); std::unique_ptr<OperationPass<ModuleOp>>
createVerifyStablehloBackendContractPass();
#endif #endif
std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass(); std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass();

View File

@ -42,10 +42,10 @@ def VerifyTosaBackendContract : Pass<"torch-verify-tosa-backend-contract", "Modu
let constructor = "mlir::torch::TorchConversion::createVerifyTosaBackendContractPass()"; let constructor = "mlir::torch::TorchConversion::createVerifyTosaBackendContractPass()";
} }
#ifdef TORCH_MLIR_ENABLE_MHLO #ifdef TORCH_MLIR_ENABLE_STABLEHLO
def VerifyMhloBackendContract : Pass<"torch-verify-mhlo-backend-contract", "ModuleOp"> { def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contract", "ModuleOp"> {
let summary = "Verifies conformity to the mhlo backend contract"; let summary = "Verifies conformity to the stablehlo backend contract";
let constructor = "mlir::torch::TorchConversion::createVerifyMhloBackendContractPass()"; let constructor = "mlir::torch::TorchConversion::createVerifyStablehloBackendContractPass()";
} }
#endif // TORCH_MLIR_ENABLE_MHLO #endif // TORCH_MLIR_ENABLE_STABLEHLO
#endif // TORCHMLIR_TORCHCONVERSION_PASSES #endif // TORCHMLIR_TORCHCONVERSION_PASSES

View File

@ -3,13 +3,7 @@ add_subdirectory(Conversion)
add_subdirectory(Dialect) add_subdirectory(Dialect)
add_subdirectory(RefBackend) add_subdirectory(RefBackend)
add_mlir_library(TorchMLIRInitAll set(LinkedLibs
InitAll.cpp
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRFuncDialect MLIRFuncDialect
MLIRIR MLIRIR
MLIRSupport MLIRSupport
@ -27,4 +21,22 @@ add_mlir_library(TorchMLIRInitAll
TorchMLIRRefBackend TorchMLIRRefBackend
) )
if(TORCH_MLIR_ENABLE_STABLEHLO)
list(APPEND LinkedLibs
MhloPasses
MhloToLinalg
StablehloToMhlo
)
endif()
add_mlir_library(TorchMLIRInitAll
InitAll.cpp
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
${LinkedLibs}
)
torch_mlir_target_includes(TorchMLIRInitAll) torch_mlir_target_includes(TorchMLIRInitAll)

View File

@ -2,8 +2,8 @@ add_subdirectory(TorchToLinalg)
add_subdirectory(TorchToSCF) add_subdirectory(TorchToSCF)
add_subdirectory(TorchToArith) add_subdirectory(TorchToArith)
add_subdirectory(TorchToTosa) add_subdirectory(TorchToTosa)
if(TORCH_MLIR_ENABLE_MHLO) if(TORCH_MLIR_ENABLE_STABLEHLO)
add_subdirectory(TorchToMhlo) add_subdirectory(TorchToStablehlo)
endif() endif()
add_subdirectory(TorchToTMTensor) add_subdirectory(TorchToTMTensor)
add_subdirectory(TorchConversionToMLProgram) add_subdirectory(TorchConversionToMLProgram)
@ -17,10 +17,8 @@ set(linked_libs TorchMLIRTorchToLinalg
TorchMLIRTorchToTMTensor TorchMLIRTorchToTMTensor
TorchMLIRTorchConversionToMLProgram TorchMLIRTorchConversionToMLProgram
TorchMLIRConversionUtils) TorchMLIRConversionUtils)
if(TORCH_MLIR_ENABLE_MHLO) if(TORCH_MLIR_ENABLE_STABLEHLO)
list(APPEND linked_libs list(APPEND linked_libs TorchMLIRTorchToStablehlo)
MhloPasses
TorchMLIRTorchToMhlo)
endif() endif()
add_mlir_library(TorchMLIRConversionPasses add_mlir_library(TorchMLIRConversionPasses

View File

@ -9,15 +9,15 @@
#include "torch-mlir/Conversion/Passes.h" #include "torch-mlir/Conversion/Passes.h"
#ifdef TORCH_MLIR_ENABLE_MHLO #ifdef TORCH_MLIR_ENABLE_STABLEHLO
#include "mhlo/transforms/passes.h" #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "transforms/passes.h" #include "transforms/passes.h"
#endif // TORCH_MLIR_ENABLE_MHLO #endif // TORCH_MLIR_ENABLE_STABLEHLO
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h" #include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" #include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
@ -32,12 +32,4 @@ namespace {
void mlir::torch::registerConversionPasses() { void mlir::torch::registerConversionPasses() {
::registerPasses(); ::registerPasses();
#ifdef TORCH_MLIR_ENABLE_MHLO
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
return mlir::mhlo::createLegalizeHloToLinalgPass();
});
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
return mlir::mhlo::createSymbolicShapeOptimizationPass();
});
#endif // TORCH_MLIR_ENABLE_MHLO
} }

View File

@ -1,35 +0,0 @@
add_mlir_conversion_library(TorchMLIRTorchToMhlo
TorchToMhlo.cpp
MhloLegalizeUtils.cpp
Basic.cpp
Gather.cpp
Linear.cpp
ViewLike.cpp
Reduction.cpp
Pooling.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToMhlo
DEPENDS
MhloDialect
MhloToLinalg
MLIRMhloPassIncGen
LMHLOTransformsPassIncGen
TorchMLIRConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MhloDialect
MhloToLinalg
MLIRBufferTransforms
StablehloOps
TorchMLIRTorchDialect
TorchMLIRConversionUtils
)
torch_mlir_target_includes(TorchMLIRTorchToMhlo)

View File

@ -1,74 +0,0 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H
#define TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace torch {
namespace torch_to_mhlo {
struct TorchToMhloOptions {
bool enableStaticShape = false;
size_t dimSizeIndexBits = 64;
};
template <typename AtenOpT>
class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
public:
using OpAdaptor = typename AtenOpT::Adaptor;
ConvertAtenOp(TypeConverter &typeConverter, MLIRContext *context,
const TorchToMhloOptions &options)
: OpConversionPattern<AtenOpT>(typeConverter, context) {
this->options = options;
}
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
return rewriter.notifyMatchFailure(op, "haven't been implemented");
}
const TorchToMhloOptions &getOptions() const { return options; }
private:
TorchToMhloOptions options;
};
void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target,
const TorchToMhloOptions &options);
void populateViewLikeOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target,
const TorchToMhloOptions &options);
void populateGatherOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target,
const TorchToMhloOptions &options);
void populateReductionOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target,
const TorchToMhloOptions &options);
void populateLinearOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target,
const TorchToMhloOptions &options);
void populatePoolingOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target,
const TorchToMhloOptions &options);
} // namespace torch_to_mhlo
} // namespace torch
} // namespace mlir
#endif // TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H

View File

@ -7,15 +7,16 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "../PassDetail.h" #include "../PassDetail.h"
#include "./MhloLegalizeUtils.h" #include "PopulatePatterns.h"
#include "./PopulatePatterns.h" #include "StablehloLegalizeUtils.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@ -29,7 +30,7 @@
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
using namespace mlir::torch::Torch; using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_mhlo; using namespace mlir::torch::torch_to_stablehlo;
LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op, LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
mlir::Value &self, mlir::Value &other, mlir::Value &self, mlir::Value &other,
@ -43,16 +44,16 @@ LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
if (selfRank > otherRank) { if (selfRank > otherRank) {
auto unsqueezeDims = auto unsqueezeDims =
llvm::to_vector<4>(llvm::seq<int64_t>(0, selfRank - otherRank)); llvm::to_vector<4>(llvm::seq<int64_t>(0, selfRank - otherRank));
auto unsqueezeInfo = mhlo::unsqueezeTensor(rewriter, op, other, auto unsqueezeInfo = hlo::unsqueezeTensor(rewriter, op, other,
unsqueezeDims, dimSizeIndexBits); unsqueezeDims, dimSizeIndexBits);
if (failed(unsqueezeInfo)) if (failed(unsqueezeInfo))
return failure(); return failure();
other = *unsqueezeInfo; other = *unsqueezeInfo;
} else if (otherRank > selfRank) { } else if (otherRank > selfRank) {
auto unsqueezeDims = auto unsqueezeDims =
llvm::to_vector<4>(llvm::seq<int64_t>(0, otherRank - selfRank)); llvm::to_vector<4>(llvm::seq<int64_t>(0, otherRank - selfRank));
auto unsqueezeInfo = mhlo::unsqueezeTensor(rewriter, op, self, auto unsqueezeInfo = hlo::unsqueezeTensor(rewriter, op, self, unsqueezeDims,
unsqueezeDims, dimSizeIndexBits); dimSizeIndexBits);
if (failed(unsqueezeInfo)) if (failed(unsqueezeInfo))
return failure(); return failure();
self = *unsqueezeInfo; self = *unsqueezeInfo;
@ -78,7 +79,8 @@ static FailureOr<Value> getMaxValueOfDtype(Operation *op, Type elementType,
constType, constType,
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(), APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
/*negative=*/false)); /*negative=*/false));
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr) return rewriter
.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
.getResult(); .getResult();
} }
if (elementType.isa<mlir::IntegerType>()) { if (elementType.isa<mlir::IntegerType>()) {
@ -91,7 +93,8 @@ static FailureOr<Value> getMaxValueOfDtype(Operation *op, Type elementType,
constAttr = SplatElementsAttr::get( constAttr = SplatElementsAttr::get(
constType, APInt::getSignedMaxValue(integerType.getWidth())); constType, APInt::getSignedMaxValue(integerType.getWidth()));
} }
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr) return rewriter
.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
.getResult(); .getResult();
} }
return failure(); return failure();
@ -105,7 +108,8 @@ static FailureOr<Value> getMinValueOfDtype(Operation *op, Type elementType,
constType, constType,
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(), APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
/*negative=*/true)); /*negative=*/true));
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr) return rewriter
.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
.getResult(); .getResult();
} }
if (elementType.isa<mlir::IntegerType>()) { if (elementType.isa<mlir::IntegerType>()) {
@ -118,7 +122,8 @@ static FailureOr<Value> getMinValueOfDtype(Operation *op, Type elementType,
constAttr = SplatElementsAttr::get( constAttr = SplatElementsAttr::get(
constType, APInt::getSignedMinValue(integerType.getWidth())); constType, APInt::getSignedMinValue(integerType.getWidth()));
} }
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr) return rewriter
.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
.getResult(); .getResult();
} }
return failure(); return failure();
@ -126,7 +131,7 @@ static FailureOr<Value> getMinValueOfDtype(Operation *op, Type elementType,
// These legalizations are for unary ops. // These legalizations are for unary ops.
namespace { namespace {
template <typename AtenOpT, typename MhloOpT> template <typename AtenOpT, typename StablehloOpT>
class ConvertAtenUnaryOp : public OpConversionPattern<AtenOpT> { class ConvertAtenUnaryOp : public OpConversionPattern<AtenOpT> {
public: public:
using OpConversionPattern<AtenOpT>::OpConversionPattern; using OpConversionPattern<AtenOpT>::OpConversionPattern;
@ -137,13 +142,13 @@ public:
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
auto selfType = self.getType().cast<TensorType>(); auto selfType = self.getType().cast<TensorType>();
if (!selfType) { if (!selfType) {
return op.emitError("only Tensor types supported in MHLO"); return op.emitError("only Tensor types supported in StableHLO");
} }
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter() auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType()) ->convertType(op.getType())
.template cast<TensorType>(); .template cast<TensorType>();
self = mhlo::promoteType(rewriter, self, outType); self = hlo::promoteType(rewriter, self, outType);
rewriter.replaceOpWithNewOp<MhloOpT>(op, outType, self); rewriter.replaceOpWithNewOp<StablehloOpT>(op, outType, self);
return success(); return success();
} }
}; };
@ -152,7 +157,7 @@ public:
// These legalizations are for unary ops with only for floating point datatypes. // These legalizations are for unary ops with only for floating point datatypes.
// There is no supported quantized integer mode for these. // There is no supported quantized integer mode for these.
namespace { namespace {
template <typename AtenOpT, typename MhloOpT> template <typename AtenOpT, typename StablehloOpT>
class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern<AtenOpT> { class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern<AtenOpT> {
public: public:
using OpConversionPattern<AtenOpT>::OpConversionPattern; using OpConversionPattern<AtenOpT>::OpConversionPattern;
@ -164,10 +169,10 @@ public:
auto selfTy = self.getType().cast<TensorType>(); auto selfTy = self.getType().cast<TensorType>();
if (!selfTy) if (!selfTy)
return op.emitError("only Tensor types supported in MHLO"); return op.emitError("only Tensor types supported in StableHLO");
if (selfTy.getElementType().isa<mlir::FloatType>()) { if (selfTy.getElementType().isa<mlir::FloatType>()) {
rewriter.replaceOpWithNewOp<MhloOpT>( rewriter.replaceOpWithNewOp<StablehloOpT>(
op, op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType( OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()), op.getType()),
@ -198,7 +203,7 @@ public:
.template dyn_cast<TensorType>(); .template dyn_cast<TensorType>();
if (!outType) if (!outType)
return op.emitError("only Tensor types supported in MHLO"); return op.emitError("only Tensor types supported in StableHLO");
Type outElemTy = outType.getElementType(); Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat()) if (!outElemTy.isIntOrFloat())
@ -216,9 +221,9 @@ public:
SmallVector<int32_t> values(size, fillVal); SmallVector<int32_t> values(size, fillVal);
auto constOp = auto constOp =
mhlo::getConstTensor<int32_t>(rewriter, op, values, shape).value(); hlo::getConstTensor<int32_t>(rewriter, op, values, shape).value();
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(op, outType, constOp); rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(op, outType, constOp);
return success(); return success();
} }
}; };
@ -247,8 +252,8 @@ public:
->convertType(op.getType()) ->convertType(op.getType())
.template cast<TensorType>(); .template cast<TensorType>();
lhs = mhlo::promoteType(rewriter, lhs, outTy); lhs = hlo::promoteType(rewriter, lhs, outTy);
rhs = mhlo::promoteType(rewriter, rhs, outTy); rhs = hlo::promoteType(rewriter, rhs, outTy);
rewriter.replaceOpWithNewOp<ChloOpT>(op, outTy, lhs, rhs, rewriter.replaceOpWithNewOp<ChloOpT>(op, outTy, lhs, rhs,
/*broadcast_attr*/ nullptr); /*broadcast_attr*/ nullptr);
@ -274,7 +279,7 @@ public:
RankedTensorType rhsType = rhs.getType().dyn_cast<RankedTensorType>(); RankedTensorType rhsType = rhs.getType().dyn_cast<RankedTensorType>();
if (!lhsType) if (!lhsType)
return op.emitError("only Tensor types supported in MHLO"); return op.emitError("only Tensor types supported in StableHLO");
TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter() TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType()) ->convertType(op.getType())
@ -287,18 +292,19 @@ public:
} }
if (!rhsType) { if (!rhsType) {
rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getOther(), outElemTy); rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
outElemTy);
if (isa<AtenRsubScalarOp>(op)) { if (isa<AtenRsubScalarOp>(op)) {
std::swap(lhs, rhs); std::swap(lhs, rhs);
} }
} }
lhs = mhlo::promoteType(rewriter, lhs, outType); lhs = hlo::promoteType(rewriter, lhs, outType);
rhs = mhlo::promoteType(rewriter, rhs, outType); rhs = hlo::promoteType(rewriter, rhs, outType);
if (!skipMultiplyAlpha(op.getAlpha())) { if (!skipMultiplyAlpha(op.getAlpha())) {
Value alpha = Value alpha = hlo::scalarToStablehloTensor(rewriter, op,
mhlo::scalarToMhloTensor(rewriter, op, adaptor.getAlpha(), outElemTy); adaptor.getAlpha(), outElemTy);
DenseIntElementsAttr bcastDimensions; DenseIntElementsAttr bcastDimensions;
rhs = rewriter.create<chlo::BroadcastMulOp>(op->getLoc(), rhs, alpha, rhs = rewriter.create<chlo::BroadcastMulOp>(op->getLoc(), rhs, alpha,
bcastDimensions); bcastDimensions);
@ -328,7 +334,7 @@ public:
TensorType rhsType = rhs.getType().dyn_cast<TensorType>(); TensorType rhsType = rhs.getType().dyn_cast<TensorType>();
if (!lhsType) if (!lhsType)
return op.emitError("only Tensor types supported in MHLO"); return op.emitError("only Tensor types supported in StableHLO");
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter() auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType()) ->convertType(op.getType())
@ -343,11 +349,12 @@ public:
if (std::is_same<AtenOpT, AtenSquareOp>()) { if (std::is_same<AtenOpT, AtenSquareOp>()) {
rhs = lhs; rhs = lhs;
} else if (!rhsType) { } else if (!rhsType) {
rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getOther(), outElemTy); rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
outElemTy);
} }
DenseIntElementsAttr bcastDimensions; DenseIntElementsAttr bcastDimensions;
lhs = mhlo::promoteType(rewriter, lhs, outType); lhs = hlo::promoteType(rewriter, lhs, outType);
rhs = mhlo::promoteType(rewriter, rhs, outType); rhs = hlo::promoteType(rewriter, rhs, outType);
auto loc = op.getLoc(); auto loc = op.getLoc();
Value result = Value result =
rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions); rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);
@ -368,15 +375,15 @@ public:
if (roundingMode == "trunc") { if (roundingMode == "trunc") {
// "trunc" - rounds the results of the division towards zero. Equivalent // "trunc" - rounds the results of the division towards zero. Equivalent
// to C-style integer division. // to C-style integer division.
auto sign = rewriter.create<mhlo::SignOp>(loc, result); auto sign = rewriter.create<stablehlo::SignOp>(loc, result);
auto abs = rewriter.create<mhlo::AbsOp>(loc, result); auto abs = rewriter.create<stablehlo::AbsOp>(loc, result);
auto floor = rewriter.create<mhlo::FloorOp>(loc, abs); auto floor = rewriter.create<stablehlo::FloorOp>(loc, abs);
result = rewriter.create<mhlo::MulOp>(loc, sign, floor).getResult(); result = rewriter.create<stablehlo::MulOp>(loc, sign, floor).getResult();
} }
if (roundingMode == "floor") { if (roundingMode == "floor") {
// "floor" - rounds the results of the division down. Equivalent to // "floor" - rounds the results of the division down. Equivalent to
// floor division in Python (the // operator) // floor division in Python (the // operator)
result = rewriter.create<mhlo::FloorOp>(loc, result).getResult(); result = rewriter.create<stablehlo::FloorOp>(loc, result).getResult();
} }
rewriter.replaceOp(op, result); rewriter.replaceOp(op, result);
return success(); return success();
@ -401,7 +408,7 @@ public:
RankedTensorType rhsTy = rhs.getType().dyn_cast<RankedTensorType>(); RankedTensorType rhsTy = rhs.getType().dyn_cast<RankedTensorType>();
if (!lhsTy) if (!lhsTy)
return op.emitError("only Tensor types supported in MHLO"); return op.emitError("only Tensor types supported in StableHLO");
RankedTensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter() RankedTensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType()) ->convertType(op.getType())
@ -414,11 +421,12 @@ public:
} }
if (!rhsTy) { if (!rhsTy) {
rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getOther(), lhsElemTy); rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
lhsElemTy);
} }
// TODO: what is the PyTorch default type promotion? // TODO: what is the PyTorch default type promotion?
rhs = mhlo::promoteType(rewriter, rhs, lhsTy); rhs = hlo::promoteType(rewriter, rhs, lhsTy);
chlo::ComparisonTypeAttr compareTypeAttr; chlo::ComparisonTypeAttr compareTypeAttr;
chlo::ComparisonDirectionAttr compareDirectionAttr; chlo::ComparisonDirectionAttr compareDirectionAttr;
@ -485,8 +493,8 @@ public:
TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter() TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType()) ->convertType(op.getType())
.template cast<TensorType>(); .template cast<TensorType>();
Value lhs = mhlo::promoteType(rewriter, adaptor.getSelf(), outType); Value lhs = hlo::promoteType(rewriter, adaptor.getSelf(), outType);
Value rhs = mhlo::promoteType(rewriter, adaptor.getOther(), outType); Value rhs = hlo::promoteType(rewriter, adaptor.getOther(), outType);
DenseIntElementsAttr bcastDimensions; DenseIntElementsAttr bcastDimensions;
rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs, rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
@ -537,8 +545,8 @@ public:
RankedTensorType::get({static_cast<long int>(permValues.size())}, RankedTensorType::get({static_cast<long int>(permValues.size())},
rewriter.getI64Type()), rewriter.getI64Type()),
permValues); permValues);
rewriter.replaceOpWithNewOp<mhlo::TransposeOp>(op, outType, self, rewriter.replaceOpWithNewOp<stablehlo::TransposeOp>(op, outType, self,
permutation); permutation);
return success(); return success();
} }
}; };
@ -552,7 +560,7 @@ LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
auto outType = auto outType =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>(); getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(op, outType, self); rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(op, outType, self);
return success(); return success();
} }
@ -573,7 +581,8 @@ LogicalResult ConvertAtenOp<AtenSizeIntOp>::matchAndRewrite(
} else { } else {
Value inputRank = rewriter.create<arith::ConstantOp>( Value inputRank = rewriter.create<arith::ConstantOp>(
op.getLoc(), rewriter.getI64IntegerAttr(selfType.getRank())); op.getLoc(), rewriter.getI64IntegerAttr(selfType.getRank()));
dim = toPositiveDimDynamic(rewriter, op.getLoc(), adaptor.getDim(), inputRank); dim = toPositiveDimDynamic(rewriter, op.getLoc(), adaptor.getDim(),
inputRank);
dim = rewriter.create<arith::IndexCastOp>(op.getLoc(), dim = rewriter.create<arith::IndexCastOp>(op.getLoc(),
rewriter.getIndexType(), dim); rewriter.getIndexType(), dim);
} }
@ -589,9 +598,8 @@ LogicalResult ConvertAtenOp<AtenSizeIntOp>::matchAndRewrite(
template <> template <>
LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
AtenWhereSelfOp op, AtenWhereSelfOp op, OpAdaptor adaptor,
OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const {
ConversionPatternRewriter& rewriter) const {
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
Value cond = adaptor.getCondition(); Value cond = adaptor.getCondition();
Value other = adaptor.getOther(); Value other = adaptor.getOther();
@ -605,8 +613,7 @@ LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
return op.emitError("failed broadcast other and condition ranks"); return op.emitError("failed broadcast other and condition ranks");
rewriter.replaceOpWithNewOp<chlo::BroadcastSelectOp>( rewriter.replaceOpWithNewOp<chlo::BroadcastSelectOp>(
op, op, getTypeConverter()->convertType(op.getType()),
getTypeConverter()->convertType(op.getType()),
ArrayRef<Value>{cond, self, other}); ArrayRef<Value>{cond, self, other});
return success(); return success();
} }
@ -623,7 +630,7 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
.cast<RankedTensorType>(); .cast<RankedTensorType>();
if (options.enableStaticShape && selfTy.hasStaticShape()) { if (options.enableStaticShape && selfTy.hasStaticShape()) {
Value bcastOp = mhlo::promoteAndBroadcast(rewriter, self, outType); Value bcastOp = hlo::promoteAndBroadcast(rewriter, self, outType);
rewriter.replaceOp(op, bcastOp); rewriter.replaceOp(op, bcastOp);
return success(); return success();
} }
@ -670,7 +677,7 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
op->getLoc(), ValueRange{bcastShapeVec}); op->getLoc(), ValueRange{bcastShapeVec});
auto dimensionNumbers = auto dimensionNumbers =
llvm::to_vector<4>(llvm::seq<int64_t>(leadingRank, totalRank)); llvm::to_vector<4>(llvm::seq<int64_t>(leadingRank, totalRank));
rewriter.replaceOpWithNewOp<mhlo::DynamicBroadcastInDimOp>( rewriter.replaceOpWithNewOp<stablehlo::DynamicBroadcastInDimOp>(
op, outType, self, bcastShapeTensor, op, outType, self, bcastShapeTensor,
rewriter.getI64TensorAttr(dimensionNumbers)); rewriter.getI64TensorAttr(dimensionNumbers));
} }
@ -708,8 +715,8 @@ LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
RankedTensorType::get({static_cast<long int>(permValues.size())}, RankedTensorType::get({static_cast<long int>(permValues.size())},
rewriter.getI64Type()), rewriter.getI64Type()),
permValues); permValues);
rewriter.replaceOpWithNewOp<mhlo::TransposeOp>(op, outType, self, rewriter.replaceOpWithNewOp<stablehlo::TransposeOp>(op, outType, self,
permutation); permutation);
return success(); return success();
} }
@ -721,7 +728,7 @@ LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
auto selfTy = self.getType().cast<TensorType>(); auto selfTy = self.getType().cast<TensorType>();
if (selfTy && selfTy.getElementType().isa<mlir::FloatType>()) { if (selfTy && selfTy.getElementType().isa<mlir::FloatType>()) {
rewriter.replaceOpWithNewOp<mhlo::TanhOp>( rewriter.replaceOpWithNewOp<stablehlo::TanhOp>(
op, getTypeConverter()->convertType(op.getType()), self); op, getTypeConverter()->convertType(op.getType()), self);
return success(); return success();
} else { } else {
@ -751,16 +758,16 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
elements.mapValues(builtinTensorElemTy, [&](const APInt &v) { elements.mapValues(builtinTensorElemTy, [&](const APInt &v) {
return APInt(bitWidth, v.getSExtValue()); return APInt(bitWidth, v.getSExtValue());
}); });
rewriter.replaceOpWithNewOp<mhlo::ConstantOp>(op, resultType, valueAttr); rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(op, resultType,
valueAttr);
return success(); return success();
} }
rewriter.replaceOpWithNewOp<mhlo::ConstantOp>(op, resultType, rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(op, resultType,
adaptor.getValue()); adaptor.getValue());
return success(); return success();
} }
// AtenReciprocalOp // AtenReciprocalOp
// Reciprocal(x) = Div(1, x) // Reciprocal(x) = Div(1, x)
template <> template <>
@ -777,7 +784,7 @@ LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(
} }
Value oneTensor = chlo::getConstantLike(rewriter, op->getLoc(), 1, input); Value oneTensor = chlo::getConstantLike(rewriter, op->getLoc(), 1, input);
rewriter.replaceOpWithNewOp<mhlo::DivOp>(op, outTy, oneTensor, input); rewriter.replaceOpWithNewOp<stablehlo::DivOp>(op, outTy, oneTensor, input);
return success(); return success();
} }
@ -790,9 +797,9 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
->convertType(op->getResult(0).getType()) ->convertType(op->getResult(0).getType())
.cast<RankedTensorType>(); .cast<RankedTensorType>();
auto outputElemType = outputType.getElementType(); auto outputElemType = outputType.getElementType();
Value mhloTensor = Value stablehloTensor = hlo::scalarToStablehloTensor(
mhlo::scalarToMhloTensor(rewriter, op, adaptor.getA(), outputElemType); rewriter, op, adaptor.getA(), outputElemType);
rewriter.replaceOp(op, mhloTensor); rewriter.replaceOp(op, stablehloTensor);
return success(); return success();
} }
@ -815,7 +822,6 @@ LogicalResult ConvertAtenOp<AtenContiguousOp>::matchAndRewrite(
return success(); return success();
} }
// AtenReluOp // AtenReluOp
// Relu(x) = Max(0, x) // Relu(x) = Max(0, x)
template <> template <>
@ -836,11 +842,10 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
APFloat::getZero(lhsElemTy.cast<mlir::FloatType>().getFloatSemantics(), APFloat::getZero(lhsElemTy.cast<mlir::FloatType>().getFloatSemantics(),
false), false),
lhs); lhs);
rewriter.replaceOpWithNewOp<mhlo::MaxOp>(op, lhs, zeroTensor); rewriter.replaceOpWithNewOp<stablehlo::MaxOp>(op, lhs, zeroTensor);
return success(); return success();
} }
// Convert a Aten::GELU to HLO // Convert a Aten::GELU to HLO
// Gelu(x) = x * 1/2 * [1 + erf(x/(sqrt(2)))] // Gelu(x) = x * 1/2 * [1 + erf(x/(sqrt(2)))]
template <> template <>
@ -857,12 +862,12 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
Value one = chlo::getConstantLike(rewriter, loc, 1.0, input); Value one = chlo::getConstantLike(rewriter, loc, 1.0, input);
Value two = chlo::getConstantLike(rewriter, loc, 2.0, input); Value two = chlo::getConstantLike(rewriter, loc, 2.0, input);
Value half = chlo::getConstantLike(rewriter, loc, 0.5, input); Value half = chlo::getConstantLike(rewriter, loc, 0.5, input);
auto rsqrtTwo = rewriter.create<mlir::mhlo::RsqrtOp>(loc, two); auto rsqrtTwo = rewriter.create<mlir::stablehlo::RsqrtOp>(loc, two);
auto erfElement = rewriter.create<mhlo::MulOp>(loc, input, rsqrtTwo); auto erfElement = rewriter.create<stablehlo::MulOp>(loc, input, rsqrtTwo);
auto erf = rewriter.create<mlir::chlo::ErfOp>(loc, erfElement); auto erf = rewriter.create<mlir::chlo::ErfOp>(loc, erfElement);
auto erfAdd = rewriter.create<mhlo::AddOp>(loc, erf, one); auto erfAdd = rewriter.create<stablehlo::AddOp>(loc, erf, one);
auto halfMul = rewriter.create<mhlo::MulOp>(loc, erfAdd, half); auto halfMul = rewriter.create<stablehlo::MulOp>(loc, erfAdd, half);
rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, input, halfMul); rewriter.replaceOpWithNewOp<stablehlo::MulOp>(op, input, halfMul);
return success(); return success();
} }
@ -881,7 +886,6 @@ LogicalResult ConvertAtenOp<AtenErfOp>::matchAndRewrite(
return success(); return success();
} }
// AtenBatchNormOp // AtenBatchNormOp
template <> template <>
LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
@ -919,28 +923,28 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
Value channelShape = rewriter.create<tensor::FromElementsOp>( Value channelShape = rewriter.create<tensor::FromElementsOp>(
op->getLoc(), ValueRange{channelDim}); op->getLoc(), ValueRange{channelDim});
if (failed(checkNotNone(rewriter, op, weight))) { if (failed(checkNotNone(rewriter, op, weight))) {
weight = mhlo::getConstantOfShape( weight = hlo::getConstantOfShape(
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1), rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1),
channelShape, channelShape,
RankedTensorType::get({inputTy.getShape()[1]}, RankedTensorType::get({inputTy.getShape()[1]},
inputTy.getElementType())); inputTy.getElementType()));
} }
if (failed(checkNotNone(rewriter, op, bias))) { if (failed(checkNotNone(rewriter, op, bias))) {
bias = mhlo::getConstantOfShape( bias = hlo::getConstantOfShape(
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0), rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0),
channelShape, channelShape,
RankedTensorType::get({inputTy.getShape()[1]}, RankedTensorType::get({inputTy.getShape()[1]},
inputTy.getElementType())); inputTy.getElementType()));
} }
if (failed(checkNotNone(rewriter, op, runningVar))) { if (failed(checkNotNone(rewriter, op, runningVar))) {
runningVar = mhlo::getConstantOfShape( runningVar = hlo::getConstantOfShape(
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1), rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1),
channelShape, channelShape,
RankedTensorType::get({inputTy.getShape()[1]}, RankedTensorType::get({inputTy.getShape()[1]},
inputTy.getElementType())); inputTy.getElementType()));
} }
if (failed(checkNotNone(rewriter, op, runningMean))) { if (failed(checkNotNone(rewriter, op, runningMean))) {
runningMean = mhlo::getConstantOfShape( runningMean = hlo::getConstantOfShape(
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0), rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0),
channelShape, channelShape,
RankedTensorType::get({inputTy.getShape()[1]}, RankedTensorType::get({inputTy.getShape()[1]},
@ -983,10 +987,11 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
Type outputTy = getTypeConverter()->convertType(op.getType()); Type outputTy = getTypeConverter()->convertType(op.getType());
Type batchMeanOrVarTy = Type batchMeanOrVarTy =
RankedTensorType::get(weightTy.getShape(), inputTy.getElementType()); RankedTensorType::get(weightTy.getShape(), inputTy.getElementType());
auto batchNormTrainingResult = rewriter.create<mhlo::BatchNormTrainingOp>( auto batchNormTrainingResult =
op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input, rewriter.create<stablehlo::BatchNormTrainingOp>(
weight, bias, rewriter.getF32FloatAttr(eps), op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input,
rewriter.getI64IntegerAttr(1)); weight, bias, rewriter.getF32FloatAttr(eps),
rewriter.getI64IntegerAttr(1));
rewriter.replaceOp(op, batchNormTrainingResult.getResult(0)); rewriter.replaceOp(op, batchNormTrainingResult.getResult(0));
return success(); return success();
} else { } else {
@ -995,10 +1000,11 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
inputTy.getShape().end()}; inputTy.getShape().end()};
castShape[1] = weightTy.getShape()[0]; castShape[1] = weightTy.getShape()[0];
auto castTy = RankedTensorType::get(castShape, inputTy.getElementType()); auto castTy = RankedTensorType::get(castShape, inputTy.getElementType());
// Feature counts must match among operands of mhlo::BatchNormInferenceOp. // Feature counts must match among operands of
// stablehlo::BatchNormInferenceOp.
Value inputCasted = Value inputCasted =
rewriter.create<tensor::CastOp>(op.getLoc(), castTy, input); rewriter.create<tensor::CastOp>(op.getLoc(), castTy, input);
Value output = rewriter.create<mhlo::BatchNormInferenceOp>( Value output = rewriter.create<stablehlo::BatchNormInferenceOp>(
op.getLoc(), inputCasted.getType(), inputCasted, weight, bias, op.getLoc(), inputCasted.getType(), inputCasted, weight, bias,
runningMean, runningVar, runningMean, runningVar,
// 'epsilon' must satisfy constraint: 32-bit float attribute. // 'epsilon' must satisfy constraint: 32-bit float attribute.
@ -1008,7 +1014,6 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
} }
} }
// AtenNativeLayerNormOp // AtenNativeLayerNormOp
template <> template <>
LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
@ -1076,21 +1081,21 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
} }
SmallVector<int64_t> inputFlattenShape{1, numFeatureDimSize, SmallVector<int64_t> inputFlattenShape{1, numFeatureDimSize,
numEmbeddingDimSize}; numEmbeddingDimSize};
SmallVector<int64_t> meanOrVarMhloOutShape{numFeatureDimSize}; SmallVector<int64_t> meanOrVarStablehloOutShape{numFeatureDimSize};
auto mhloBatchNormOutTy = auto stablehloBatchNormOutTy =
RankedTensorType::get(inputFlattenShape, inputTy.getElementType()); RankedTensorType::get(inputFlattenShape, inputTy.getElementType());
auto mhloBathNormOutMeanOrVarTy = auto stablehloBathNormOutMeanOrVarTy = RankedTensorType::get(
RankedTensorType::get(meanOrVarMhloOutShape, inputTy.getElementType()); meanOrVarStablehloOutShape, inputTy.getElementType());
// Reshape input // Reshape input
auto mhloInput = rewriter.create<mhlo::DynamicReshapeOp>( auto stablehloInput = rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), mhloBatchNormOutTy, input, op->getLoc(), stablehloBatchNormOutTy, input,
mhlo::getConstTensor(rewriter, op, llvm::ArrayRef(inputFlattenShape), hlo::getConstTensor(rewriter, op, llvm::ArrayRef(inputFlattenShape),
{static_cast<int64_t>(inputFlattenShape.size())}) {static_cast<int64_t>(inputFlattenShape.size())})
.value()); .value());
// Generate "scale" and "offset" Value for mhlo.BatchNormTrainingOp. // Generate "scale" and "offset" Value for stablehlo.BatchNormTrainingOp.
SmallVector<APFloat> zeroConstVec( SmallVector<APFloat> zeroConstVec(
numFeatureDimSize, APFloat::getZero(inputTy.getElementType() numFeatureDimSize, APFloat::getZero(inputTy.getElementType()
.cast<mlir::FloatType>() .cast<mlir::FloatType>()
@ -1103,16 +1108,18 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
auto oneOrZeroConstType = auto oneOrZeroConstType =
RankedTensorType::get({numFeatureDimSize}, inputTy.getElementType()); RankedTensorType::get({numFeatureDimSize}, inputTy.getElementType());
Value scale = rewriter.create<mhlo::ConstantOp>( Value scale = rewriter.create<stablehlo::ConstantOp>(
op->getLoc(), oneOrZeroConstType, op->getLoc(), oneOrZeroConstType,
DenseElementsAttr::get(oneOrZeroConstType, oneConstVec)); DenseElementsAttr::get(oneOrZeroConstType, oneConstVec));
Value offset = rewriter.create<mhlo::ConstantOp>( Value offset = rewriter.create<stablehlo::ConstantOp>(
op->getLoc(), oneOrZeroConstType, op->getLoc(), oneOrZeroConstType,
DenseElementsAttr::get(oneOrZeroConstType, zeroConstVec)); DenseElementsAttr::get(oneOrZeroConstType, zeroConstVec));
auto batchNormTrainingResult = rewriter.create<mhlo::BatchNormTrainingOp>( auto batchNormTrainingResult =
op->getLoc(), mhloBatchNormOutTy, mhloBathNormOutMeanOrVarTy, rewriter.create<stablehlo::BatchNormTrainingOp>(
mhloBathNormOutMeanOrVarTy, mhloInput, scale, offset, op->getLoc(), stablehloBatchNormOutTy,
rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(1)); stablehloBathNormOutMeanOrVarTy, stablehloBathNormOutMeanOrVarTy,
stablehloInput, scale, offset, rewriter.getF32FloatAttr(eps),
rewriter.getI64IntegerAttr(1));
// Reshape back // Reshape back
auto outputTy = auto outputTy =
@ -1120,36 +1127,35 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
auto outputMeanOrVarTy = auto outputMeanOrVarTy =
getTypeConverter()->convertType(op.getType(1)).cast<RankedTensorType>(); getTypeConverter()->convertType(op.getType(1)).cast<RankedTensorType>();
auto output = rewriter.create<mhlo::DynamicReshapeOp>( auto output = rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), outputTy, batchNormTrainingResult.getResult(0), op->getLoc(), outputTy, batchNormTrainingResult.getResult(0),
mhlo::getConstTensor(rewriter, op, outputTy.getShape(), hlo::getConstTensor(rewriter, op, outputTy.getShape(),
{static_cast<int64_t>(outputTy.getShape().size())}) {static_cast<int64_t>(outputTy.getShape().size())})
.value()); .value());
auto mean = rewriter.create<mhlo::DynamicReshapeOp>( auto mean = rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(1), op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(1),
mhlo::getConstTensor( hlo::getConstTensor(
rewriter, op, outputMeanOrVarTy.getShape(), rewriter, op, outputMeanOrVarTy.getShape(),
{static_cast<int64_t>(outputMeanOrVarTy.getShape().size())}) {static_cast<int64_t>(outputMeanOrVarTy.getShape().size())})
.value()); .value());
auto var = rewriter.create<mhlo::DynamicReshapeOp>( auto var = rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(2), op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(2),
mhlo::getConstTensor( hlo::getConstTensor(
rewriter, op, outputMeanOrVarTy.getShape(), rewriter, op, outputMeanOrVarTy.getShape(),
{static_cast<int64_t>(outputMeanOrVarTy.getShape().size())}) {static_cast<int64_t>(outputMeanOrVarTy.getShape().size())})
.value()); .value());
// Apply affine transform: output x weight + bias [element-wise] // Apply affine transform: output x weight + bias [element-wise]
auto bcastedWeight = mhlo::promoteAndBroadcast(rewriter, weight, outputTy); auto bcastedWeight = hlo::promoteAndBroadcast(rewriter, weight, outputTy);
auto bcastedBias = mhlo::promoteAndBroadcast(rewriter, bias, outputTy); auto bcastedBias = hlo::promoteAndBroadcast(rewriter, bias, outputTy);
auto outputMulWeight = auto outputMulWeight =
rewriter.create<mhlo::MulOp>(op->getLoc(), output, bcastedWeight); rewriter.create<stablehlo::MulOp>(op->getLoc(), output, bcastedWeight);
auto finalOuput = auto finalOuput = rewriter.create<stablehlo::AddOp>(
rewriter.create<mhlo::AddOp>(op->getLoc(), outputMulWeight, bcastedBias); op->getLoc(), outputMulWeight, bcastedBias);
rewriter.replaceOp(op, {finalOuput, mean, var}); rewriter.replaceOp(op, {finalOuput, mean, var});
return success(); return success();
} }
// AtenCatOp // AtenCatOp
template <> template <>
LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
@ -1173,11 +1179,11 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
// Promote type // Promote type
for (auto &v : builtinTensors) { for (auto &v : builtinTensors) {
v = mhlo::promoteType(rewriter, v, outType); v = hlo::promoteType(rewriter, v, outType);
} }
size_t posDim = toPositiveDim(dim, outType.getRank()); size_t posDim = toPositiveDim(dim, outType.getRank());
rewriter.replaceOpWithNewOp<mhlo::ConcatenateOp>( rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(
op, outType, ValueRange(builtinTensors), posDim); op, outType, ValueRange(builtinTensors), posDim);
return success(); return success();
} }
@ -1225,7 +1231,8 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "this op should be folded as its `min` and `max` both are none"); op, "this op should be folded as its `min` and `max` both are none");
} else if (failed(checkNotNone(rewriter, op, minValue))) { } else if (failed(checkNotNone(rewriter, op, minValue))) {
maxValue = mhlo::scalarToMhloTensor(rewriter, op, maxValue, inputElemType); maxValue =
hlo::scalarToStablehloTensor(rewriter, op, maxValue, inputElemType);
auto minInfo = getMinValueOfDtype(op, inputElemType, rewriter); auto minInfo = getMinValueOfDtype(op, inputElemType, rewriter);
if (failed(minInfo)) { if (failed(minInfo)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -1233,7 +1240,8 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
} }
minValue = *minInfo; minValue = *minInfo;
} else if (failed(checkNotNone(rewriter, op, maxValue))) { } else if (failed(checkNotNone(rewriter, op, maxValue))) {
minValue = mhlo::scalarToMhloTensor(rewriter, op, minValue, inputElemType); minValue =
hlo::scalarToStablehloTensor(rewriter, op, minValue, inputElemType);
auto maxInfo = getMaxValueOfDtype(op, inputElemType, rewriter); auto maxInfo = getMaxValueOfDtype(op, inputElemType, rewriter);
if (failed(maxInfo)) { if (failed(maxInfo)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -1241,10 +1249,13 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
} }
maxValue = *maxInfo; maxValue = *maxInfo;
} else { } else {
minValue = mhlo::scalarToMhloTensor(rewriter, op, minValue, inputElemType); minValue =
maxValue = mhlo::scalarToMhloTensor(rewriter, op, maxValue, inputElemType); hlo::scalarToStablehloTensor(rewriter, op, minValue, inputElemType);
maxValue =
hlo::scalarToStablehloTensor(rewriter, op, maxValue, inputElemType);
} }
rewriter.replaceOpWithNewOp<mhlo::ClampOp>(op, minValue, input, maxValue); rewriter.replaceOpWithNewOp<stablehlo::ClampOp>(op, minValue, input,
maxValue);
return success(); return success();
} }
@ -1266,24 +1277,27 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
op, "unimplemented: only int or float dtype supported"); op, "unimplemented: only int or float dtype supported");
} }
Value start = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getStart(), dtype); Value start =
Value end = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getEnd(), dtype); hlo::scalarToStablehloTensor(rewriter, op, adaptor.getStart(), dtype);
Value step = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getStep(), dtype); Value end =
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getEnd(), dtype);
Value step =
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getStep(), dtype);
// Get length of the 1-d output tensor // Get length of the 1-d output tensor
Value subOut = rewriter.create<mhlo::SubtractOp>(loc, end, start); Value subOut = rewriter.create<stablehlo::SubtractOp>(loc, end, start);
Value divOut = rewriter.create<mhlo::DivOp>(loc, subOut, step); Value divOut = rewriter.create<stablehlo::DivOp>(loc, subOut, step);
Value resultLength = rewriter.create<mhlo::ReshapeOp>( Value resultLength = rewriter.create<stablehlo::ReshapeOp>(
loc, RankedTensorType::get({1}, dtype), divOut); loc, RankedTensorType::get({1}, dtype), divOut);
if (dtype.isa<mlir::FloatType>()) { if (dtype.isa<mlir::FloatType>()) {
resultLength = rewriter.create<mhlo::CeilOp>(loc, resultLength); resultLength = rewriter.create<stablehlo::CeilOp>(loc, resultLength);
resultLength = rewriter.create<mhlo::ConvertOp>( resultLength = rewriter.create<stablehlo::ConvertOp>(
loc, RankedTensorType::get({1}, rewriter.getI64Type()), resultLength); loc, RankedTensorType::get({1}, rewriter.getI64Type()), resultLength);
} }
Value window = Value window =
rewriter.create<mhlo::DynamicIotaOp>(loc, outType, resultLength, 0); rewriter.create<stablehlo::DynamicIotaOp>(loc, outType, resultLength, 0);
DenseIntElementsAttr broadcastDimensions; DenseIntElementsAttr broadcastDimensions;
Value mulOut = rewriter.create<chlo::BroadcastMulOp>(loc, window, step, Value mulOut = rewriter.create<chlo::BroadcastMulOp>(loc, window, step,
broadcastDimensions); broadcastDimensions);
@ -1298,9 +1312,8 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc(); Location loc = op.getLoc();
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
auto outType = this->getTypeConverter() auto outType =
->convertType(op.getType()) this->getTypeConverter()->convertType(op.getType()).cast<TensorType>();
.cast<TensorType>();
if (!outType) { if (!outType) {
return op.emitError("only tensor type is supported"); return op.emitError("only tensor type is supported");
} }
@ -1320,26 +1333,27 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
Value negHalf = chlo::getConstantLike(rewriter, loc, -0.5, input); Value negHalf = chlo::getConstantLike(rewriter, loc, -0.5, input);
// Compute // Compute
Value kBeta0 = rewriter.create<mhlo::MulOp>(loc, outType, kAlpha, cstAlpha0); Value kBeta0 =
Value kBeta = rewriter.create<mhlo::MulOp>(loc, outType, kBeta0, half); rewriter.create<stablehlo::MulOp>(loc, outType, kAlpha, cstAlpha0);
Value erfArg = Value kBeta = rewriter.create<stablehlo::MulOp>(loc, outType, kBeta0, half);
rewriter.create<mhlo::MulOp>(loc, outType, kAlpha, adaptor.getSelf()); Value erfArg = rewriter.create<stablehlo::MulOp>(loc, outType, kAlpha,
adaptor.getSelf());
Value erf = rewriter.create<mlir::chlo::ErfOp>(loc, outType, erfArg); Value erf = rewriter.create<mlir::chlo::ErfOp>(loc, outType, erfArg);
Value erfAdd = rewriter.create<mhlo::AddOp>(loc, outType, erf, one); Value erfAdd = rewriter.create<stablehlo::AddOp>(loc, outType, erf, one);
Value cdf = rewriter.create<mhlo::MulOp>(loc, outType, erfAdd, half); Value cdf = rewriter.create<stablehlo::MulOp>(loc, outType, erfAdd, half);
Value inputSquared = rewriter.create<mhlo::MulOp>( Value inputSquared = rewriter.create<stablehlo::MulOp>(
loc, outType, adaptor.getSelf(), adaptor.getSelf()); loc, outType, adaptor.getSelf(), adaptor.getSelf());
Value negHalfInputSquared = Value negHalfInputSquared =
rewriter.create<mhlo::MulOp>(loc, outType, inputSquared, negHalf); rewriter.create<stablehlo::MulOp>(loc, outType, inputSquared, negHalf);
Value expRes = Value expRes =
rewriter.create<mhlo::ExpOp>(loc, outType, negHalfInputSquared); rewriter.create<stablehlo::ExpOp>(loc, outType, negHalfInputSquared);
Value pdf = rewriter.create<mhlo::MulOp>(loc, outType, kBeta, expRes); Value pdf = rewriter.create<stablehlo::MulOp>(loc, outType, kBeta, expRes);
Value pdfTimesInput = Value pdfTimesInput =
rewriter.create<mhlo::MulOp>(loc, outType, pdf, adaptor.getSelf()); rewriter.create<stablehlo::MulOp>(loc, outType, pdf, adaptor.getSelf());
Value pdfTimesInputAddCdf = Value pdfTimesInputAddCdf =
rewriter.create<mhlo::AddOp>(loc, outType, pdfTimesInput, cdf); rewriter.create<stablehlo::AddOp>(loc, outType, pdfTimesInput, cdf);
rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, outType, adaptor.getGradOutput(), rewriter.replaceOpWithNewOp<stablehlo::MulOp>(
pdfTimesInputAddCdf); op, outType, adaptor.getGradOutput(), pdfTimesInputAddCdf);
return success(); return success();
} }
@ -1366,9 +1380,9 @@ public:
}; };
} // namespace } // namespace
void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality( void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns, TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToMhloOptions &options) { ConversionTarget &target, const TorchToStablehloOptions &options) {
MLIRContext *context = patterns.getContext(); MLIRContext *context = patterns.getContext();
target.addIllegalOp<AtenTransposeIntOp>(); target.addIllegalOp<AtenTransposeIntOp>();
@ -1376,23 +1390,24 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
target.addIllegalOp<RuntimeAssertOp>(); target.addIllegalOp<RuntimeAssertOp>();
patterns.add<ConvertRuntimeAssertOp>(typeConverter, context); patterns.add<ConvertRuntimeAssertOp>(typeConverter, context);
#define INSERT_UNARY_PATTERN(AtenOp, MhloOp) \ #define INSERT_UNARY_PATTERN(AtenOp, StablehloOp) \
target.addIllegalOp<AtenOp>(); \ target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenUnaryOp<AtenOp, MhloOp>>(typeConverter, context) patterns.add<ConvertAtenUnaryOp<AtenOp, StablehloOp>>(typeConverter, context)
INSERT_UNARY_PATTERN(AtenCloneOp, mhlo::CopyOp); INSERT_UNARY_PATTERN(AtenCloneOp, stablehlo::ConvertOp);
INSERT_UNARY_PATTERN(AtenNegOp, mhlo::NegOp); INSERT_UNARY_PATTERN(AtenNegOp, stablehlo::NegOp);
INSERT_UNARY_PATTERN(AtenLogicalNotOp, mhlo::NotOp); INSERT_UNARY_PATTERN(AtenLogicalNotOp, stablehlo::NotOp);
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, mhlo::NotOp); INSERT_UNARY_PATTERN(AtenBitwiseNotOp, stablehlo::NotOp);
#undef INSERT_UNARY_PATTERN #undef INSERT_UNARY_PATTERN
#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, MhloOp) \ #define INSERT_UNARY_FPONLY_PATTERN(AtenOp, StablehloOp) \
target.addIllegalOp<AtenOp>(); \ target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, MhloOp>>(typeConverter, context) patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, StablehloOp>>(typeConverter, \
INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, mhlo::LogOp); context)
INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, mhlo::ExpOp); INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, stablehlo::LogOp);
INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, mhlo::SqrtOp); INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, stablehlo::ExpOp);
INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, mhlo::RsqrtOp); INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, stablehlo::SqrtOp);
INSERT_UNARY_FPONLY_PATTERN(AtenSigmoidOp, mhlo::LogisticOp); INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, stablehlo::RsqrtOp);
INSERT_UNARY_FPONLY_PATTERN(AtenSigmoidOp, stablehlo::LogisticOp);
#undef INSERT_UNARY_FPONLY_PATTERN #undef INSERT_UNARY_FPONLY_PATTERN
#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \ #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
@ -1482,10 +1497,10 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenWhereSelfOp); INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
#undef INSERT_ATENOP_PATTERN #undef INSERT_ATENOP_PATTERN
#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, MhloOp) \ #define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \
target.addIllegalOp<AtenOp>(); \ target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenBinaryBroadcastOp<AtenOp, MhloOp>>(typeConverter, \ patterns.add<ConvertAtenBinaryBroadcastOp<AtenOp, StablehloOp>>( \
context) typeConverter, context)
INSERT_BINARY_BROADCAST_PATTERN(AtenMaximumOp, chlo::BroadcastMaxOp); INSERT_BINARY_BROADCAST_PATTERN(AtenMaximumOp, chlo::BroadcastMaxOp);
INSERT_BINARY_BROADCAST_PATTERN(AtenMinimumOp, chlo::BroadcastMinOp); INSERT_BINARY_BROADCAST_PATTERN(AtenMinimumOp, chlo::BroadcastMinOp);
INSERT_BINARY_BROADCAST_PATTERN(Aten__And__TensorOp, chlo::BroadcastAndOp); INSERT_BINARY_BROADCAST_PATTERN(Aten__And__TensorOp, chlo::BroadcastAndOp);

View File

@ -0,0 +1,29 @@
add_mlir_conversion_library(TorchMLIRTorchToStablehlo
TorchToStablehlo.cpp
StablehloLegalizeUtils.cpp
Basic.cpp
Gather.cpp
Linear.cpp
ViewLike.cpp
Reduction.cpp
Pooling.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToStablehlo
DEPENDS
TorchMLIRConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRBufferTransforms
StablehloOps
TorchMLIRTorchDialect
TorchMLIRConversionUtils
)
torch_mlir_target_includes(TorchMLIRTorchToStablehlo)

View File

@ -7,14 +7,15 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "../PassDetail.h" #include "../PassDetail.h"
#include "./MhloLegalizeUtils.h" #include "PopulatePatterns.h"
#include "./PopulatePatterns.h" #include "StablehloLegalizeUtils.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@ -24,7 +25,7 @@
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
using namespace mlir::torch::Torch; using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_mhlo; using namespace mlir::torch::torch_to_stablehlo;
namespace { namespace {
Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op, Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
@ -69,7 +70,7 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
SmallVector<int64_t, 4> startIndexMap(1, axis); SmallVector<int64_t, 4> startIndexMap(1, axis);
// indexVecDim // indexVecDim
int64_t indexVecDim = indicesRank; int64_t indexVecDim = indicesRank;
auto dimsAttr = mhlo::GatherDimensionNumbersAttr::get( auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get(
rewriter.getContext(), rewriter.getContext(),
/*offsetDims=*/offsetDims, /*offsetDims=*/offsetDims,
/*collapsedSliceDims=*/collapsedSliceDims, /*collapsedSliceDims=*/collapsedSliceDims,
@ -91,17 +92,18 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
auto outputTy = auto outputTy =
RankedTensorType::get(outputShape, inputRankTy.getElementType()); RankedTensorType::get(outputShape, inputRankTy.getElementType());
return rewriter return rewriter
.create<mhlo::DynamicGatherOp>(loc, outputTy, input, indices, .create<stablehlo::DynamicGatherOp>(loc, outputTy, input, indices,
sliceSizesTensor, dimsAttr) sliceSizesTensor, dimsAttr)
.getResult(); .getResult();
} }
} // namespace } // namespace
// Ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html // Ref:
// https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html
// padding_idx (int, optional) // padding_idx (int, optional)
// If specified, the entries at padding_idx do not contribute to the gradient; // If specified, the entries at padding_idx do not contribute to the
// therefore, the embedding vector at padding_idx is not updated during training, // gradient; therefore, the embedding vector at padding_idx is not updated
// i.e. it remains as a fixed “pad”. // during training, i.e. it remains as a fixed “pad”.
// scale_grad_by_freq (boolean, optional) // scale_grad_by_freq (boolean, optional)
// If given, this will scale gradients by the inverse of frequency of the // If given, this will scale gradients by the inverse of frequency of the
// words in the mini-batch. Default False. // words in the mini-batch. Default False.
@ -139,7 +141,7 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
Value output = gatherTensorAlongSingleAxis( Value output = gatherTensorAlongSingleAxis(
rewriter, op, weight, adaptor.getIndices(), 0, options.dimSizeIndexBits); rewriter, op, weight, adaptor.getIndices(), 0, options.dimSizeIndexBits);
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>( rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(
op, getTypeConverter()->convertType(op.getType()), output); op, getTypeConverter()->convertType(op.getType()), output);
return success(); return success();
@ -161,7 +163,7 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
Value output = gatherTensorAlongSingleAxis( Value output = gatherTensorAlongSingleAxis(
rewriter, op, self, adaptor.getIndex(), dim, options.dimSizeIndexBits); rewriter, op, self, adaptor.getIndex(), dim, options.dimSizeIndexBits);
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>( rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(
op, getTypeConverter()->convertType(op.getType()), output); op, getTypeConverter()->convertType(op.getType()), output);
return success(); return success();
@ -200,7 +202,7 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
auto options = getOptions(); auto options = getOptions();
auto indexShapeInfo = auto indexShapeInfo =
mhlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits); hlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits);
if (failed(indexShapeInfo)) { if (failed(indexShapeInfo)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "failed to get dim sizes of `index` param"); op, "failed to get dim sizes of `index` param");
@ -223,15 +225,15 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
SmallVector<Value> toConcat; SmallVector<Value> toConcat;
for (int64_t i = 0; i < inputType.getRank(); ++i) { for (int64_t i = 0; i < inputType.getRank(); ++i) {
if (i == dim) { if (i == dim) {
toConcat.push_back(rewriter.create<mhlo::DynamicReshapeOp>( toConcat.push_back(rewriter.create<stablehlo::DynamicReshapeOp>(
loc, toConcatIndexType, index, toConcatIndexShape)); loc, toConcatIndexType, index, toConcatIndexShape));
} else { } else {
toConcat.push_back(rewriter.create<mhlo::DynamicIotaOp>( toConcat.push_back(rewriter.create<stablehlo::DynamicIotaOp>(
loc, toConcatIndexType, toConcatIndexShape, loc, toConcatIndexType, toConcatIndexShape,
rewriter.getI64IntegerAttr(i))); rewriter.getI64IntegerAttr(i)));
} }
} }
auto gatherIndicies = rewriter.create<mhlo::ConcatenateOp>( auto gatherIndicies = rewriter.create<stablehlo::ConcatenateOp>(
loc, toConcat, static_cast<uint64_t>(inputType.getRank())); loc, toConcat, static_cast<uint64_t>(inputType.getRank()));
SmallVector<int64_t> sliceSizes(inputType.getRank(), 1); SmallVector<int64_t> sliceSizes(inputType.getRank(), 1);
@ -243,22 +245,22 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
startIndexMap.push_back(i); startIndexMap.push_back(i);
} }
auto dimsAttr = mhlo::GatherDimensionNumbersAttr::get( auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get(
rewriter.getContext(), rewriter.getContext(),
/*offsetDims=*/{}, /*offsetDims=*/{},
/*collapsedSliceDims=*/collapsedDims, /*collapsedSliceDims=*/collapsedDims,
/*startIndexMap=*/startIndexMap, /*startIndexMap=*/startIndexMap,
/*indexVecDim=*/indexVecDim); /*indexVecDim=*/indexVecDim);
rewriter.replaceOpWithNewOp<mhlo::GatherOp>( rewriter.replaceOpWithNewOp<stablehlo::GatherOp>(
op, input, gatherIndicies, dimsAttr, op, input, gatherIndicies, dimsAttr,
rewriter.getI64TensorAttr(sliceSizes)); rewriter.getI64TensorAttr(sliceSizes));
return success(); return success();
} }
void mlir::torch::torch_to_mhlo::populateGatherOpPatternsAndLegality( void mlir::torch::torch_to_stablehlo::populateGatherOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns, TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToMhloOptions &options) { ConversionTarget &target, const TorchToStablehloOptions &options) {
MLIRContext *context = patterns.getContext(); MLIRContext *context = patterns.getContext();
#define INSERT_ATENOP_PATTERN(AtenOp) \ #define INSERT_ATENOP_PATTERN(AtenOp) \

View File

@ -7,15 +7,16 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "../PassDetail.h" #include "../PassDetail.h"
#include "./MhloLegalizeUtils.h" #include "PopulatePatterns.h"
#include "./PopulatePatterns.h" #include "StablehloLegalizeUtils.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@ -25,7 +26,7 @@
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
using namespace mlir::torch::Torch; using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_mhlo; using namespace mlir::torch::torch_to_stablehlo;
namespace { namespace {
Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor, Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
@ -33,7 +34,7 @@ Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
ArrayRef<int64_t> broadcastDims) { ArrayRef<int64_t> broadcastDims) {
auto tensorTy = tensor.getType().dyn_cast<RankedTensorType>(); auto tensorTy = tensor.getType().dyn_cast<RankedTensorType>();
auto loc = op->getLoc(); auto loc = op->getLoc();
Value mhloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes); Value stablehloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
RankedTensorType outTy = RankedTensorType outTy =
RankedTensorType::get(shape, tensorTy.getElementType()); RankedTensorType::get(shape, tensorTy.getElementType());
@ -43,8 +44,8 @@ Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
rewriter.getIntegerType(64)); rewriter.getIntegerType(64));
auto broadcastAttr = DenseIntElementsAttr::get(attrTy, broadcastDims); auto broadcastAttr = DenseIntElementsAttr::get(attrTy, broadcastDims);
auto broadcast = rewriter.create<mhlo::DynamicBroadcastInDimOp>( auto broadcast = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
loc, outTy, tensor, mhloShape, broadcastAttr); loc, outTy, tensor, stablehloShape, broadcastAttr);
return broadcast; return broadcast;
} }
@ -52,7 +53,7 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input,
ArrayRef<int64_t> inpTransDims) { ArrayRef<int64_t> inpTransDims) {
auto inputTy = input.getType().dyn_cast<RankedTensorType>(); auto inputTy = input.getType().dyn_cast<RankedTensorType>();
auto rank = inputTy.getRank(); auto rank = inputTy.getRank();
auto transDims = mhlo::toPositiveDims(inpTransDims, rank); auto transDims = hlo::toPositiveDims(inpTransDims, rank);
auto inpShape = inputTy.getShape(); auto inpShape = inputTy.getShape();
std::vector<int64_t> newShape; std::vector<int64_t> newShape;
newShape.reserve(rank); newShape.reserve(rank);
@ -66,8 +67,8 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input,
auto permuteAttr = DenseIntElementsAttr::get(attrTy, transDims); auto permuteAttr = DenseIntElementsAttr::get(attrTy, transDims);
auto outTy = RankedTensorType::get(newShape, inputTy.getElementType()); auto outTy = RankedTensorType::get(newShape, inputTy.getElementType());
auto result = rewriter.create<mhlo::TransposeOp>(op->getLoc(), outTy, input, auto result = rewriter.create<stablehlo::TransposeOp>(op->getLoc(), outTy,
permuteAttr); input, permuteAttr);
return result.getResult(); return result.getResult();
} }
@ -119,10 +120,12 @@ RankedTensorType castContractingDim(PatternRewriter &rewriter, Operation *op,
} }
// set result dimensions // set result dimensions
if (lhsResultDim < static_cast<int64_t>(lhsShape.size()) && lhsResultDim >= 0) { if (lhsResultDim < static_cast<int64_t>(lhsShape.size()) &&
lhsResultDim >= 0) {
outShape.push_back(lhsShape[lhsResultDim]); outShape.push_back(lhsShape[lhsResultDim]);
} }
if (rhsResultDim < static_cast<int64_t>(rhsShape.size()) && rhsResultDim >= 0) { if (rhsResultDim < static_cast<int64_t>(rhsShape.size()) &&
rhsResultDim >= 0) {
outShape.push_back(rhsShape[rhsResultDim]); outShape.push_back(rhsShape[rhsResultDim]);
} }
return RankedTensorType::get(outShape, lhsTy.getElementType()); return RankedTensorType::get(outShape, lhsTy.getElementType());
@ -151,10 +154,10 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
std::vector<int64_t> newShape(rhsShape.begin(), std::vector<int64_t> newShape(rhsShape.begin(),
rhsShape.begin() + leadingRank); rhsShape.begin() + leadingRank);
newShape.insert(newShape.end(), lhsShape.begin(), lhsShape.end()); newShape.insert(newShape.end(), lhsShape.begin(), lhsShape.end());
auto newDimSizes = *mhlo::getDimSizesOfTensor( auto newDimSizes = *hlo::getDimSizesOfTensor(rewriter, op, rhs, leadingDims,
rewriter, op, rhs, leadingDims, dimSizeIndexBits); dimSizeIndexBits);
auto lhsDimSizes = auto lhsDimSizes =
*mhlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits); *hlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits);
newDimSizes.insert(newDimSizes.end(), lhsDimSizes.begin(), newDimSizes.insert(newDimSizes.end(), lhsDimSizes.begin(),
lhsDimSizes.end()); lhsDimSizes.end());
lhs = getBroadcastTensor(rewriter, op, lhs, newShape, newDimSizes, lhs = getBroadcastTensor(rewriter, op, lhs, newShape, newDimSizes,
@ -163,10 +166,10 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
std::vector<int64_t> newShape(lhsShape.begin(), std::vector<int64_t> newShape(lhsShape.begin(),
lhsShape.begin() + leadingRank); lhsShape.begin() + leadingRank);
newShape.insert(newShape.end(), rhsShape.begin(), rhsShape.end()); newShape.insert(newShape.end(), rhsShape.begin(), rhsShape.end());
auto newDimSizes = *mhlo::getDimSizesOfTensor( auto newDimSizes = *hlo::getDimSizesOfTensor(rewriter, op, lhs, leadingDims,
rewriter, op, lhs, leadingDims, dimSizeIndexBits); dimSizeIndexBits);
auto rhsDimSizes = auto rhsDimSizes =
*mhlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits); *hlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits);
newDimSizes.insert(newDimSizes.end(), rhsDimSizes.begin(), newDimSizes.insert(newDimSizes.end(), rhsDimSizes.begin(),
rhsDimSizes.end()); rhsDimSizes.end());
rhs = getBroadcastTensor(rewriter, op, rhs, newShape, newDimSizes, rhs = getBroadcastTensor(rewriter, op, rhs, newShape, newDimSizes,
@ -218,8 +221,8 @@ public:
if (lhsRank <= 2 && rhsRank <= 2) { if (lhsRank <= 2 && rhsRank <= 2) {
auto tensorType = auto tensorType =
ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType()); ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType());
output = rewriter.create<mhlo::DotOp>(op->getLoc(), tensorType, lhs, rhs, output = rewriter.create<stablehlo::DotOp>(op->getLoc(), tensorType, lhs,
nullptr); rhs, nullptr);
return success(); return success();
} }
@ -253,8 +256,8 @@ public:
lhsContractingDim = nBatchDims; lhsContractingDim = nBatchDims;
} }
mhlo::DotDimensionNumbersAttr dotDimensionNumbers = stablehlo::DotDimensionNumbersAttr dotDimensionNumbers =
mhlo::DotDimensionNumbersAttr::get( stablehlo::DotDimensionNumbersAttr::get(
rewriter.getContext(), rewriter.getContext(),
/*lhsBatchingDimensions=*/batchDims, /*lhsBatchingDimensions=*/batchDims,
/*rhsBatchingDimensions=*/batchDims, /*rhsBatchingDimensions=*/batchDims,
@ -264,8 +267,8 @@ public:
castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim, castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim,
lhsContractingDim, rhsContractingDim); lhsContractingDim, rhsContractingDim);
output = rewriter output = rewriter
.create<mhlo::DotGeneralOp>(op->getLoc(), outTy, lhs, rhs, .create<stablehlo::DotGeneralOp>(op->getLoc(), outTy, lhs, rhs,
dotDimensionNumbers, nullptr) dotDimensionNumbers, nullptr)
.getResult(); .getResult();
return success(); return success();
} }
@ -312,7 +315,7 @@ public:
if (!lhsTy || !rhsTy) if (!lhsTy || !rhsTy)
return op.emitError( return op.emitError(
"only ranked tensor types are supported in MHLO matmul"); "only ranked tensor types are supported in StableHLO matmul");
return success(); return success();
} }
@ -335,7 +338,7 @@ public:
if (!lhsTy || !rhsTy) if (!lhsTy || !rhsTy)
return op.emitError( return op.emitError(
"only ranked tensor types are supported in MHLO matmul"); "only ranked tensor types are supported in StableHLO matmul");
auto lhsRank = lhsTy.getRank(); auto lhsRank = lhsTy.getRank();
auto rhsRank = rhsTy.getRank(); auto rhsRank = rhsTy.getRank();
@ -371,7 +374,7 @@ public:
if (!lhsTy || !rhsTy) if (!lhsTy || !rhsTy)
return op.emitError( return op.emitError(
"only ranked tensor types are supported in MHLO matmul"); "only ranked tensor types are supported in StableHLO matmul");
auto lhsRank = lhsTy.getRank(); auto lhsRank = lhsTy.getRank();
auto rhsRank = rhsTy.getRank(); auto rhsRank = rhsTy.getRank();
@ -398,10 +401,10 @@ public:
auto bias = adaptor.getBias(); auto bias = adaptor.getBias();
auto biasTy = bias.getType(); auto biasTy = bias.getType();
// MHLO does not mandate that elementwise op tensors need to be ranked. // StableHLO does not mandate that elementwise op tensors need to be ranked.
if (!biasTy.template isa<Torch::NoneType>() && if (!biasTy.template isa<Torch::NoneType>() &&
!biasTy.template isa<RankedTensorType>()) !biasTy.template isa<RankedTensorType>())
return op.emitError("only ranked tensor types are supported in MHLO " return op.emitError("only ranked tensor types are supported in StableHLO "
"matmul for bias tensor"); "matmul for bias tensor");
// weight.T // weight.T
@ -427,14 +430,14 @@ public:
auto outTy = auto outTy =
castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim, castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim,
lhsContractingDim, rhsContractingDim); lhsContractingDim, rhsContractingDim);
mhlo::DotDimensionNumbersAttr dotDimensionNumbers = stablehlo::DotDimensionNumbersAttr dotDimensionNumbers =
mhlo::DotDimensionNumbersAttr::get( stablehlo::DotDimensionNumbersAttr::get(
rewriter.getContext(), rewriter.getContext(),
/*lhsBatchingDimensions=*/batchDims, /*lhsBatchingDimensions=*/batchDims,
/*rhsBatchingDimensions=*/batchDims, /*rhsBatchingDimensions=*/batchDims,
/*lhsContractingDimensions=*/{lhsContractingDim}, /*lhsContractingDimensions=*/{lhsContractingDim},
/*rhsContractingDimensions=*/{rhsContractingDim}); /*rhsContractingDimensions=*/{rhsContractingDim});
Value matmulOutput = rewriter.create<mhlo::DotGeneralOp>( Value matmulOutput = rewriter.create<stablehlo::DotGeneralOp>(
op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr); op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr);
Value matmulPlusBias = matmulOutput; Value matmulPlusBias = matmulOutput;
@ -464,7 +467,7 @@ public:
auto weightElemTy = weightTy.getElementType(); auto weightElemTy = weightTy.getElementType();
auto rank = weightTy.getRank(); auto rank = weightTy.getRank();
const auto &options = getOptions(); const auto &options = getOptions();
SmallVector<Value> weightShapeVec = *mhlo::getDimSizesOfTensor( SmallVector<Value> weightShapeVec = *hlo::getDimSizesOfTensor(
rewriter, op, weight, options.dimSizeIndexBits); rewriter, op, weight, options.dimSizeIndexBits);
auto weightShape = weightTy.getShape(); auto weightShape = weightTy.getShape();
SmallVector<int64_t> weightShapeInt(rank); SmallVector<int64_t> weightShapeInt(rank);
@ -488,7 +491,7 @@ public:
} }
Value weightShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>( Value weightShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), weightShapeVec); op->getLoc(), weightShapeVec);
weight = rewriter.create<mhlo::DynamicReshapeOp>( weight = rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), RankedTensorType::get(weightShapeInt, weightElemTy), op->getLoc(), RankedTensorType::get(weightShapeInt, weightElemTy),
weight, weightShapeTensor); weight, weightShapeTensor);
@ -497,7 +500,7 @@ public:
for (int64_t i = 0; i <= rank; i++) for (int64_t i = 0; i <= rank; i++)
transposeDims[i] = i; transposeDims[i] = i;
std::swap(transposeDims[1], transposeDims[0]); std::swap(transposeDims[1], transposeDims[0]);
weight = rewriter.create<mhlo::TransposeOp>( weight = rewriter.create<stablehlo::TransposeOp>(
op->getLoc(), weight, rewriter.getI64TensorAttr(transposeDims)); op->getLoc(), weight, rewriter.getI64TensorAttr(transposeDims));
// 3. [IC//G, G, OC, H, W, ...] => [IC//G, G*OC, H, W, ...] // 3. [IC//G, G, OC, H, W, ...] => [IC//G, G*OC, H, W, ...]
@ -509,7 +512,7 @@ public:
weightShapeVec[1] = OCMulGValue; weightShapeVec[1] = OCMulGValue;
weightShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>( weightShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), weightShapeVec); op->getLoc(), weightShapeVec);
weight = rewriter.create<mhlo::DynamicReshapeOp>( weight = rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), RankedTensorType::get(weightShapeInt, weightElemTy), op->getLoc(), RankedTensorType::get(weightShapeInt, weightElemTy),
weight, weightShapeTensor); weight, weightShapeTensor);
return weight; return weight;
@ -544,25 +547,27 @@ public:
} }
// Prepare for transposed convolution // Prepare for transposed convolution
SmallVector<int64_t> mhloStrideVec(nSpatialDims, 1); SmallVector<int64_t> stablehloStrideVec(nSpatialDims, 1);
DenseIntElementsAttr mhloStride = rewriter.getI64TensorAttr(mhloStrideVec); DenseIntElementsAttr stablehloStride =
SmallVector<int64_t> mhloPaddingVec(nSpatialDims * 2, 0); rewriter.getI64TensorAttr(stablehloStrideVec);
SmallVector<int64_t> stablehloPaddingVec(nSpatialDims * 2, 0);
for (int i = 0; i < nSpatialDims; ++i) { for (int i = 0; i < nSpatialDims; ++i) {
int64_t padInt = dilation[i] * (weightShape[i + 2] - 1) - padding[i]; int64_t padInt = dilation[i] * (weightShape[i + 2] - 1) - padding[i];
mhloPaddingVec[i * 2] = padInt; stablehloPaddingVec[i * 2] = padInt;
mhloPaddingVec[i * 2 + 1] = padInt; stablehloPaddingVec[i * 2 + 1] = padInt;
} }
DenseIntElementsAttr mhloPadding = DenseIntElementsAttr::get( DenseIntElementsAttr stablehloPadding = DenseIntElementsAttr::get(
RankedTensorType::get({nSpatialDims, 2}, rewriter.getI64Type()), RankedTensorType::get({nSpatialDims, 2}, rewriter.getI64Type()),
mhloPaddingVec); stablehloPaddingVec);
SmallVector<int64_t> mhloLhsDilationVec(nSpatialDims); SmallVector<int64_t> stablehloLhsDilationVec(nSpatialDims);
std::copy(stride.begin(), stride.end(), mhloLhsDilationVec.begin()); std::copy(stride.begin(), stride.end(), stablehloLhsDilationVec.begin());
DenseIntElementsAttr mhloLhsDilation = DenseIntElementsAttr stablehloLhsDilation =
rewriter.getI64TensorAttr(mhloLhsDilationVec); rewriter.getI64TensorAttr(stablehloLhsDilationVec);
SmallVector<int64_t> mhloRhsDilationVec(nSpatialDims); SmallVector<int64_t> stablehloRhsDilationVec(nSpatialDims);
std::copy(dilation.begin(), dilation.end(), mhloRhsDilationVec.begin()); std::copy(dilation.begin(), dilation.end(),
DenseIntElementsAttr mhloRhsDilation = stablehloRhsDilationVec.begin());
rewriter.getI64TensorAttr(mhloRhsDilationVec); DenseIntElementsAttr stablehloRhsDilation =
rewriter.getI64TensorAttr(stablehloRhsDilationVec);
DenseElementsAttr windowReversal; DenseElementsAttr windowReversal;
ArrayAttr precisionConfig; ArrayAttr precisionConfig;
@ -571,8 +576,8 @@ public:
for (int i = 0; i < nSpatialDims; ++i) { for (int i = 0; i < nSpatialDims; ++i) {
spatialDims.push_back(i + 2); spatialDims.push_back(i + 2);
} }
mhlo::ConvDimensionNumbersAttr dimensionNumbers = stablehlo::ConvDimensionNumbersAttr dimensionNumbers =
mhlo::ConvDimensionNumbersAttr::get( stablehlo::ConvDimensionNumbersAttr::get(
/*context=*/rewriter.getContext(), /*inputBatchDimension=*/0, /*context=*/rewriter.getContext(), /*inputBatchDimension=*/0,
/*inputFeatureDimension=*/1, /*inputFeatureDimension=*/1,
/*inputSpatialDimensions=*/spatialDims, /*inputSpatialDimensions=*/spatialDims,
@ -583,17 +588,18 @@ public:
/*outputSpatialDimensions=*/spatialDims); /*outputSpatialDimensions=*/spatialDims);
// Reverse and transpose weight // Reverse and transpose weight
weight = rewriter.create<mhlo::ReverseOp>( weight = rewriter.create<stablehlo::ReverseOp>(
op->getLoc(), weight, rewriter.getI64TensorAttr(spatialDims)); op->getLoc(), weight, rewriter.getI64TensorAttr(spatialDims));
if (groups != 1) { if (groups != 1) {
weight = reshapeConvWeight(rewriter, op, weight, groups); weight = reshapeConvWeight(rewriter, op, weight, groups);
} }
// Create transposed convolution // Create transposed convolution
auto transposedConvOp = rewriter.create<mhlo::ConvolutionOp>( auto transposedConvOp = rewriter.create<stablehlo::ConvolutionOp>(
op->getLoc(), convOutTy, input, weight, mhloStride, mhloPadding, op->getLoc(), convOutTy, input, weight, stablehloStride,
mhloLhsDilation, mhloRhsDilation, windowReversal, dimensionNumbers, stablehloPadding, stablehloLhsDilation, stablehloRhsDilation,
static_cast<uint64_t>(groups), 1, precisionConfig); windowReversal, dimensionNumbers, static_cast<uint64_t>(groups), 1,
precisionConfig);
// Handle output padding // Handle output padding
if (!needHandleOutputPadding) { if (!needHandleOutputPadding) {
@ -605,8 +611,8 @@ public:
std::copy(outputPadding.begin(), outputPadding.end(), std::copy(outputPadding.begin(), outputPadding.end(),
edgePaddingHighVec.begin() + 2); edgePaddingHighVec.begin() + 2);
Value paddingValue = Value paddingValue =
mhlo::getConstTensor<float>(rewriter, op, {0.0}, {}).value(); hlo::getConstTensor<float>(rewriter, op, {0.0}, {}).value();
paddingValue = mhlo::promoteType(rewriter, paddingValue, inputTy); paddingValue = hlo::promoteType(rewriter, paddingValue, inputTy);
mlir::DenseIntElementsAttr edgePaddingLow = mlir::DenseIntElementsAttr edgePaddingLow =
rewriter.getI64VectorAttr(edgePaddingLowVec); rewriter.getI64VectorAttr(edgePaddingLowVec);
mlir::DenseIntElementsAttr edgePaddingHigh = mlir::DenseIntElementsAttr edgePaddingHigh =
@ -614,7 +620,7 @@ public:
mlir::DenseIntElementsAttr interiorPadding = mlir::DenseIntElementsAttr interiorPadding =
rewriter.getI64VectorAttr(interiorPaddingVec); rewriter.getI64VectorAttr(interiorPaddingVec);
auto paddedOutput = rewriter.create<mhlo::PadOp>( auto paddedOutput = rewriter.create<stablehlo::PadOp>(
op->getLoc(), outType, transposedConvOp, paddingValue, edgePaddingLow, op->getLoc(), outType, transposedConvOp, paddingValue, edgePaddingLow,
edgePaddingHigh, interiorPadding); edgePaddingHigh, interiorPadding);
@ -628,22 +634,22 @@ public:
ArrayRef<int64_t> dilation, int64_t groups) const { ArrayRef<int64_t> dilation, int64_t groups) const {
int64_t nDims = outType.getRank(); int64_t nDims = outType.getRank();
// Get mhlo::ConvolutionOp attributes // Get stablehlo::ConvolutionOp attributes
DenseIntElementsAttr mhloWindowStride = DenseIntElementsAttr::get( DenseIntElementsAttr stablehloWindowStride = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<long int>(stride.size())}, RankedTensorType::get({static_cast<long int>(stride.size())},
rewriter.getI64Type()), rewriter.getI64Type()),
stride); stride);
std::vector<int64_t> mhloPaddingVec; std::vector<int64_t> stablehloPaddingVec;
for (size_t i = 0; i < padding.size(); i++) { for (size_t i = 0; i < padding.size(); i++) {
mhloPaddingVec.emplace_back(padding[i]); stablehloPaddingVec.emplace_back(padding[i]);
mhloPaddingVec.emplace_back(padding[i]); stablehloPaddingVec.emplace_back(padding[i]);
} }
DenseIntElementsAttr mhloPadding = DenseIntElementsAttr::get( DenseIntElementsAttr stablehloPadding = DenseIntElementsAttr::get(
RankedTensorType::get( RankedTensorType::get(
{static_cast<long int>(padding.size()), static_cast<long int>(2)}, {static_cast<long int>(padding.size()), static_cast<long int>(2)},
rewriter.getI64Type()), rewriter.getI64Type()),
mhloPaddingVec); stablehloPaddingVec);
DenseIntElementsAttr mhloRhsDilation = DenseIntElementsAttr::get( DenseIntElementsAttr stablehloRhsDilation = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<long int>(dilation.size())}, RankedTensorType::get({static_cast<long int>(dilation.size())},
rewriter.getI64Type()), rewriter.getI64Type()),
dilation); dilation);
@ -651,8 +657,8 @@ public:
for (int64_t i = 2; i < nDims; i++) { for (int64_t i = 2; i < nDims; i++) {
spatialDimensions.emplace_back(i); spatialDimensions.emplace_back(i);
} }
mhlo::ConvDimensionNumbersAttr dimensionNumbers = stablehlo::ConvDimensionNumbersAttr dimensionNumbers =
mhlo::ConvDimensionNumbersAttr::get( stablehlo::ConvDimensionNumbersAttr::get(
/*context=*/rewriter.getContext(), /*inputBatchDimension=*/0, /*context=*/rewriter.getContext(), /*inputBatchDimension=*/0,
/*inputFeatureDimension=*/1, /*inputFeatureDimension=*/1,
/*inputSpatialDimensions=*/spatialDimensions, /*inputSpatialDimensions=*/spatialDimensions,
@ -662,17 +668,18 @@ public:
/*outputBatchDimension=*/0, /*outputFeatureDimension=*/1, /*outputBatchDimension=*/0, /*outputFeatureDimension=*/1,
/*outputSpatialDimensions=*/spatialDimensions); /*outputSpatialDimensions=*/spatialDimensions);
// mhlo::ConvolutionOp's optional attributes, leave them as default // stablehlo::ConvolutionOp's optional attributes, leave them as default
DenseIntElementsAttr mhloLhsDilation; DenseIntElementsAttr stablehloLhsDilation;
DenseElementsAttr windowReversal; DenseElementsAttr windowReversal;
ArrayAttr precisionConfig; ArrayAttr precisionConfig;
auto mhloConvOp = rewriter.create<mhlo::ConvolutionOp>( auto stablehloConvOp = rewriter.create<stablehlo::ConvolutionOp>(
op->getLoc(), outType, input, weight, mhloWindowStride, mhloPadding, op->getLoc(), outType, input, weight, stablehloWindowStride,
mhloLhsDilation, mhloRhsDilation, windowReversal, dimensionNumbers, stablehloPadding, stablehloLhsDilation, stablehloRhsDilation,
static_cast<uint64_t>(groups), 1, precisionConfig); windowReversal, dimensionNumbers, static_cast<uint64_t>(groups), 1,
precisionConfig);
return mhloConvOp.getResult(); return stablehloConvOp.getResult();
} }
LogicalResult LogicalResult
@ -754,21 +761,22 @@ public:
} }
} }
Value mhloConvResult; Value stablehloConvResult;
if (transposed) { if (transposed) {
mhloConvResult = convertTransposedConv( stablehloConvResult = convertTransposedConv(
op, rewriter, outTy, input, weight, stride, padding, dilation, op, rewriter, outTy, input, weight, stride, padding, dilation,
outputPadding, groups, needHandleOutputPadding); outputPadding, groups, needHandleOutputPadding);
} else { } else {
mhloConvResult = convertNormalConv(op, rewriter, outTy, input, weight, stablehloConvResult =
stride, padding, dilation, groups); convertNormalConv(op, rewriter, outTy, input, weight, stride, padding,
dilation, groups);
} }
auto bias = adaptor.getBias(); auto bias = adaptor.getBias();
// No bias provided // No bias provided
if (failed(checkNotNone(rewriter, op, op.getBias()))) { if (failed(checkNotNone(rewriter, op, op.getBias()))) {
rewriter.replaceOp(op, mhloConvResult); rewriter.replaceOp(op, stablehloConvResult);
return success(); return success();
} }
@ -790,21 +798,21 @@ public:
llvm::to_vector<4>(llvm::seq<int64_t>(-nSpatialDims, 0)); llvm::to_vector<4>(llvm::seq<int64_t>(-nSpatialDims, 0));
const auto &options = getOptions(); const auto &options = getOptions();
bias = *mhlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims, bias = *hlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims,
options.dimSizeIndexBits); options.dimSizeIndexBits);
bias = mhlo::promoteType(rewriter, bias, outTy); bias = hlo::promoteType(rewriter, bias, outTy);
DenseIntElementsAttr bcastDimensions; DenseIntElementsAttr bcastDimensions;
rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(op, outTy, mhloConvResult, rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
bias, bcastDimensions); op, outTy, stablehloConvResult, bias, bcastDimensions);
return success(); return success();
} }
}; };
} // namespace } // namespace
void mlir::torch::torch_to_mhlo::populateLinearOpPatternsAndLegality( void mlir::torch::torch_to_stablehlo::populateLinearOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns, TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToMhloOptions &options) { ConversionTarget &target, const TorchToStablehloOptions &options) {
MLIRContext *context = patterns.getContext(); MLIRContext *context = patterns.getContext();
#define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \ #define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \

View File

@ -7,15 +7,16 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "../PassDetail.h" #include "../PassDetail.h"
#include "./MhloLegalizeUtils.h" #include "PopulatePatterns.h"
#include "./PopulatePatterns.h" #include "StablehloLegalizeUtils.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@ -28,7 +29,7 @@
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
using namespace mlir::torch::Torch; using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_mhlo; using namespace mlir::torch::torch_to_stablehlo;
static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy, static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
PatternRewriter &rewriter) { PatternRewriter &rewriter) {
@ -40,14 +41,14 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
constType, {APFloat::getZero( constType, {APFloat::getZero(
elementTy.cast<mlir::FloatType>().getFloatSemantics(), elementTy.cast<mlir::FloatType>().getFloatSemantics(),
/*negative=*/false)}); /*negative=*/false)});
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr); constAttr);
} else if (elementTy.isa<mlir::IntegerType>() && } else if (elementTy.isa<mlir::IntegerType>() &&
elementTy.getIntOrFloatBitWidth() != 8) { elementTy.getIntOrFloatBitWidth() != 8) {
auto constAttr = DenseElementsAttr::get( auto constAttr = DenseElementsAttr::get(
constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())}); constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr); constAttr);
} }
} }
@ -58,15 +59,15 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
constType, {APFloat::getLargest( constType, {APFloat::getLargest(
elementTy.cast<mlir::FloatType>().getFloatSemantics(), elementTy.cast<mlir::FloatType>().getFloatSemantics(),
/*negative=*/true)}); /*negative=*/true)});
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr); constAttr);
} else if (elementTy.isa<mlir::IntegerType>() && } else if (elementTy.isa<mlir::IntegerType>() &&
elementTy.getIntOrFloatBitWidth() != 8) { elementTy.getIntOrFloatBitWidth() != 8) {
auto constAttr = DenseElementsAttr::get( auto constAttr = DenseElementsAttr::get(
constType, constType,
{APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())}); {APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())});
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr); constAttr);
} }
} }
op->emitError("unimplemented lowering in AtenPoolingOp"); op->emitError("unimplemented lowering in AtenPoolingOp");
@ -116,42 +117,43 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dOp>::matchAndRewrite(
// prepend 1 to kernelSize, stride, dilation until they are of same rank as // prepend 1 to kernelSize, stride, dilation until they are of same rank as
// input // input
SmallVector<int64_t> mhloStride(inputRank, 1); SmallVector<int64_t> stablehloStride(inputRank, 1);
SmallVector<int64_t> mhloDilation(inputRank, 1); SmallVector<int64_t> stablehloDilation(inputRank, 1);
SmallVector<int64_t> mhloKernelSize(inputRank, 1); SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
SmallVector<int64_t> mhloPadding(inputRank * 2, 0); SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
std::copy(dilation.begin(), dilation.end(), std::copy(dilation.begin(), dilation.end(),
mhloDilation.begin() + inputRank - 2); stablehloDilation.begin() + inputRank - 2);
std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2); std::copy(stride.begin(), stride.end(),
stablehloStride.begin() + inputRank - 2);
std::copy(kernelSize.begin(), kernelSize.end(), std::copy(kernelSize.begin(), kernelSize.end(),
mhloKernelSize.begin() + inputRank - 2); stablehloKernelSize.begin() + inputRank - 2);
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
mhloPadding[mhloPadding.size() - 4] = padding[0]; stablehloPadding[stablehloPadding.size() - 4] = padding[0];
mhloPadding[mhloPadding.size() - 3] = padding[0]; stablehloPadding[stablehloPadding.size() - 3] = padding[0];
mhloPadding[mhloPadding.size() - 2] = padding[1]; stablehloPadding[stablehloPadding.size() - 2] = padding[1];
mhloPadding[mhloPadding.size() - 1] = padding[1]; stablehloPadding[stablehloPadding.size() - 1] = padding[1];
DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())}, RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
rewriter.getI64Type()), rewriter.getI64Type()),
mhloKernelSize); stablehloKernelSize);
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(mhloStride.size())}, RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
rewriter.getI64Type()), rewriter.getI64Type()),
mhloStride); stablehloStride);
DenseIntElementsAttr baseDilations; DenseIntElementsAttr baseDilations;
DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())}, RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
rewriter.getI64Type()), rewriter.getI64Type()),
mhloDilation); stablehloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get( DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get( RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)}, {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
rewriter.getI64Type()), rewriter.getI64Type()),
mhloPadding); stablehloPadding);
auto reduceWindowOp = rewriter.create<mhlo::ReduceWindowOp>( auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides, op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
baseDilations, windowDilations, pad); baseDilations, windowDilations, pad);
@ -168,8 +170,8 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dOp>::matchAndRewrite(
OpBuilder::InsertionGuard guard(rewriter); OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block); rewriter.setInsertionPointToStart(&block);
Value result = Value result =
rewriter.create<mhlo::MaxOp>(op->getLoc(), *firstArg, *secondArg); rewriter.create<stablehlo::MaxOp>(op->getLoc(), *firstArg, *secondArg);
rewriter.create<mhlo::ReturnOp>(op->getLoc(), result); rewriter.create<stablehlo::ReturnOp>(op->getLoc(), result);
} }
rewriter.replaceOp(op, reduceWindowOp.getResults()); rewriter.replaceOp(op, reduceWindowOp.getResults());
@ -221,45 +223,46 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
// prepend 1 to kernelSize, stride, dilation until they are of same rank as // prepend 1 to kernelSize, stride, dilation until they are of same rank as
// input // input
SmallVector<int64_t> mhloStride(inputRank, 1); SmallVector<int64_t> stablehloStride(inputRank, 1);
SmallVector<int64_t> mhloDilation(inputRank, 1); SmallVector<int64_t> stablehloDilation(inputRank, 1);
SmallVector<int64_t> mhloKernelSize(inputRank, 1); SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
SmallVector<int64_t> mhloPadding(inputRank * 2, 0); SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
std::copy(dilation.begin(), dilation.end(), std::copy(dilation.begin(), dilation.end(),
mhloDilation.begin() + inputRank - 2); stablehloDilation.begin() + inputRank - 2);
std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2); std::copy(stride.begin(), stride.end(),
stablehloStride.begin() + inputRank - 2);
std::copy(kernelSize.begin(), kernelSize.end(), std::copy(kernelSize.begin(), kernelSize.end(),
mhloKernelSize.begin() + inputRank - 2); stablehloKernelSize.begin() + inputRank - 2);
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
mhloPadding[mhloPadding.size() - 4] = padding[0]; stablehloPadding[stablehloPadding.size() - 4] = padding[0];
mhloPadding[mhloPadding.size() - 3] = padding[0]; stablehloPadding[stablehloPadding.size() - 3] = padding[0];
mhloPadding[mhloPadding.size() - 2] = padding[1]; stablehloPadding[stablehloPadding.size() - 2] = padding[1];
mhloPadding[mhloPadding.size() - 1] = padding[1]; stablehloPadding[stablehloPadding.size() - 1] = padding[1];
DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())}, RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
rewriter.getI64Type()), rewriter.getI64Type()),
mhloKernelSize); stablehloKernelSize);
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(mhloStride.size())}, RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
rewriter.getI64Type()), rewriter.getI64Type()),
mhloStride); stablehloStride);
DenseIntElementsAttr baseDilations; DenseIntElementsAttr baseDilations;
DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())}, RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
rewriter.getI64Type()), rewriter.getI64Type()),
mhloDilation); stablehloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get( DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get( RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)}, {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
rewriter.getI64Type()), rewriter.getI64Type()),
mhloPadding); stablehloPadding);
const auto &options = getOptions(); const auto &options = getOptions();
auto inputShapeInfo = auto inputShapeInfo =
mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
if (failed(inputShapeInfo)) { if (failed(inputShapeInfo)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input"); op, "failed to get dimension sizes of the input");
@ -289,7 +292,7 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
auto initIndexTensor = auto initIndexTensor =
rewriter rewriter
.create<mhlo::DynamicIotaOp>( .create<stablehlo::DynamicIotaOp>(
op->getLoc(), op->getLoc(),
RankedTensorType::get(initIndexShapeForType, RankedTensorType::get(initIndexShapeForType,
rewriter.getI64Type()), rewriter.getI64Type()),
@ -298,15 +301,15 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
auto indexTensor = auto indexTensor =
rewriter rewriter
.create<mhlo::DynamicReshapeOp>( .create<stablehlo::DynamicReshapeOp>(
op->getLoc(), op->getLoc(),
RankedTensorType::get(inputShape, rewriter.getI64Type()), RankedTensorType::get(inputShape, rewriter.getI64Type()),
initIndexTensor, inputShapeTensor) initIndexTensor, inputShapeTensor)
.getResult(); .getResult();
Value initIdx = mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value(); Value initIdx = hlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
auto reduceWindowOp = rewriter.create<mhlo::ReduceWindowOp>( auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
op->getLoc(), mlir::TypeRange{outValTy, outIdxTy}, op->getLoc(), mlir::TypeRange{outValTy, outIdxTy},
mlir::ValueRange{input, indexTensor}, mlir::ValueRange{initVal, initIdx}, mlir::ValueRange{input, indexTensor}, mlir::ValueRange{initVal, initIdx},
windowDimensions, windowStrides, baseDilations, windowDilations, pad); windowDimensions, windowStrides, baseDilations, windowDilations, pad);
@ -326,43 +329,43 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
auto *secondValArg = std::next(firstIdxArg); auto *secondValArg = std::next(firstIdxArg);
auto *secondIdxArg = std::next(secondValArg); auto *secondIdxArg = std::next(secondValArg);
mhlo::ComparisonTypeAttr compareTypeAttr; stablehlo::ComparisonTypeAttr compareTypeAttr;
if (inputTy.getElementType().isa<mlir::FloatType>()) { if (inputTy.getElementType().isa<mlir::FloatType>()) {
compareTypeAttr = mhlo::ComparisonTypeAttr::get( compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
rewriter.getContext(), mhlo::ComparisonType::FLOAT); rewriter.getContext(), stablehlo::ComparisonType::FLOAT);
} else if (inputTy.getElementType().isa<mlir::IntegerType>()) { } else if (inputTy.getElementType().isa<mlir::IntegerType>()) {
compareTypeAttr = mhlo::ComparisonTypeAttr::get( compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
rewriter.getContext(), mhlo::ComparisonType::SIGNED); rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
} }
mhlo::ComparisonDirectionAttr compareGeDirectionAttr = stablehlo::ComparisonDirectionAttr compareGeDirectionAttr =
mhlo::ComparisonDirectionAttr::get(rewriter.getContext(), stablehlo::ComparisonDirectionAttr::get(
mhlo::ComparisonDirection::GE); rewriter.getContext(), stablehlo::ComparisonDirection::GE);
mhlo::ComparisonDirectionAttr compareEqDirectionAttr = stablehlo::ComparisonDirectionAttr compareEqDirectionAttr =
mhlo::ComparisonDirectionAttr::get(rewriter.getContext(), stablehlo::ComparisonDirectionAttr::get(
mhlo::ComparisonDirection::EQ); rewriter.getContext(), stablehlo::ComparisonDirection::EQ);
{ {
OpBuilder::InsertionGuard guard(rewriter); OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block); rewriter.setInsertionPointToStart(&block);
Value compareGeResult = rewriter.create<mhlo::CompareOp>( Value compareGeResult = rewriter.create<stablehlo::CompareOp>(
op->getLoc(), compareResultType, *firstValArg, *secondValArg, op->getLoc(), compareResultType, *firstValArg, *secondValArg,
compareGeDirectionAttr, compareTypeAttr); compareGeDirectionAttr, compareTypeAttr);
Value retValResult = rewriter.create<mhlo::SelectOp>( Value retValResult = rewriter.create<stablehlo::SelectOp>(
op->getLoc(), compareGeResult, *firstValArg, *secondValArg); op->getLoc(), compareGeResult, *firstValArg, *secondValArg);
// Get smaller index if compared values are equal. // Get smaller index if compared values are equal.
Value compareEqResult = rewriter.create<mhlo::CompareOp>( Value compareEqResult = rewriter.create<stablehlo::CompareOp>(
op->getLoc(), compareResultType, *firstValArg, *secondValArg, op->getLoc(), compareResultType, *firstValArg, *secondValArg,
compareEqDirectionAttr, compareTypeAttr); compareEqDirectionAttr, compareTypeAttr);
Value minIdx = Value minIdx = rewriter.create<stablehlo::MinOp>(op->getLoc(), *firstIdxArg,
rewriter.create<mhlo::MinOp>(op->getLoc(), *firstIdxArg, *secondIdxArg); *secondIdxArg);
Value idxWithGeVal = rewriter.create<mhlo::SelectOp>( Value idxWithGeVal = rewriter.create<stablehlo::SelectOp>(
op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg); op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg);
Value retIdxResult = rewriter.create<mhlo::SelectOp>( Value retIdxResult = rewriter.create<stablehlo::SelectOp>(
op->getLoc(), compareEqResult, minIdx, idxWithGeVal); op->getLoc(), compareEqResult, minIdx, idxWithGeVal);
rewriter.create<mhlo::ReturnOp>( rewriter.create<stablehlo::ReturnOp>(
op->getLoc(), mlir::ValueRange{retValResult, retIdxResult}); op->getLoc(), mlir::ValueRange{retValResult, retIdxResult});
} }
@ -419,41 +422,42 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
// prepend 1 to kernelSize, stride, dilation until they are of same rank as // prepend 1 to kernelSize, stride, dilation until they are of same rank as
// input // input
SmallVector<int64_t> mhloStride(inputRank, 1); SmallVector<int64_t> stablehloStride(inputRank, 1);
SmallVector<int64_t> mhloDilation(inputRank, 1); SmallVector<int64_t> stablehloDilation(inputRank, 1);
SmallVector<int64_t> mhloKernelSize(inputRank, 1); SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
SmallVector<int64_t> mhloPadding(inputRank * 2, 0); SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2); std::copy(stride.begin(), stride.end(),
stablehloStride.begin() + inputRank - 2);
std::copy(kernelSize.begin(), kernelSize.end(), std::copy(kernelSize.begin(), kernelSize.end(),
mhloKernelSize.begin() + inputRank - 2); stablehloKernelSize.begin() + inputRank - 2);
mhloPadding[mhloPadding.size() - 4] = padding[0]; stablehloPadding[stablehloPadding.size() - 4] = padding[0];
mhloPadding[mhloPadding.size() - 3] = padding[0]; stablehloPadding[stablehloPadding.size() - 3] = padding[0];
mhloPadding[mhloPadding.size() - 2] = padding[1]; stablehloPadding[stablehloPadding.size() - 2] = padding[1];
mhloPadding[mhloPadding.size() - 1] = padding[1]; stablehloPadding[stablehloPadding.size() - 1] = padding[1];
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())}, RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
rewriter.getI64Type()), rewriter.getI64Type()),
mhloKernelSize); stablehloKernelSize);
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(mhloStride.size())}, RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
rewriter.getI64Type()), rewriter.getI64Type()),
mhloStride); stablehloStride);
DenseIntElementsAttr baseDilations; DenseIntElementsAttr baseDilations;
DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())}, RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
rewriter.getI64Type()), rewriter.getI64Type()),
mhloDilation); stablehloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get( DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get( RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)}, {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
rewriter.getI64Type()), rewriter.getI64Type()),
mhloPadding); stablehloPadding);
auto reduceWindowSum = rewriter.create<mhlo::ReduceWindowOp>( auto reduceWindowSum = rewriter.create<stablehlo::ReduceWindowOp>(
op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides, op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
baseDilations, windowDilations, pad); baseDilations, windowDilations, pad);
@ -471,39 +475,39 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
rewriter.setInsertionPointToStart(&sumBlock); rewriter.setInsertionPointToStart(&sumBlock);
Value sumResult = Value sumResult =
rewriter.create<mhlo::AddOp>(op->getLoc(), *firstArg, *secondArg); rewriter.create<stablehlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
rewriter.create<mhlo::ReturnOp>(op->getLoc(), sumResult); rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult);
} }
// Use kernel size as the divisor // Use kernel size as the divisor
if (countIncludePad) { if (countIncludePad) {
Value divisor = mhlo::getConstTensor<int64_t>( Value divisor = hlo::getConstTensor<int64_t>(
rewriter, op, {kernelSize[0] * kernelSize[1]}, {}) rewriter, op, {kernelSize[0] * kernelSize[1]}, {})
.value(); .value();
divisor = mhlo::promoteType(rewriter, divisor, outTy); divisor = hlo::promoteType(rewriter, divisor, outTy);
DenseIntElementsAttr bcastDimensions; DenseIntElementsAttr bcastDimensions;
rewriter.replaceOpWithNewOp<mlir::chlo::BroadcastDivOp>( rewriter.replaceOpWithNewOp<mlir::chlo::BroadcastDivOp>(
op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions); op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions);
return success(); return success();
} }
// Use another mhlo.ReduceWindowOp to get the divisor // Use another stablehlo.ReduceWindowOp to get the divisor
Value windowSizeConst = Value windowSizeConst =
mhlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value(); hlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
windowSizeConst = mhlo::promoteType(rewriter, windowSizeConst, outTy); windowSizeConst = hlo::promoteType(rewriter, windowSizeConst, outTy);
const auto &options = getOptions(); const auto &options = getOptions();
auto inputShapeVec = auto inputShapeVec =
*mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); *hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>( auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), inputShapeVec); op->getLoc(), inputShapeVec);
windowSizeConst = rewriter.create<mhlo::DynamicBroadcastInDimOp>( windowSizeConst = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
op->getLoc(), op->getLoc(),
RankedTensorType::get(inputTy.getShape(), outTy.getElementType()), RankedTensorType::get(inputTy.getShape(), outTy.getElementType()),
windowSizeConst, inputShapeTensor, rewriter.getI64TensorAttr({})); windowSizeConst, inputShapeTensor, rewriter.getI64TensorAttr({}));
Value zero = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); Value zero = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
auto reduceWindowSize = rewriter.create<mhlo::ReduceWindowOp>( auto reduceWindowSize = rewriter.create<stablehlo::ReduceWindowOp>(
op->getLoc(), RankedTensorType::get(outShape, inputElemTy), op->getLoc(), RankedTensorType::get(outShape, inputElemTy),
windowSizeConst, zero, windowDimensions, windowStrides, baseDilations, windowSizeConst, zero, windowDimensions, windowStrides, baseDilations,
windowDilations, pad); windowDilations, pad);
@ -522,11 +526,11 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
rewriter.setInsertionPointToStart(&sizeBlock); rewriter.setInsertionPointToStart(&sizeBlock);
Value sumResult = Value sumResult =
rewriter.create<mhlo::AddOp>(op->getLoc(), *firstArg, *secondArg); rewriter.create<stablehlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
rewriter.create<mhlo::ReturnOp>(op->getLoc(), sumResult); rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult);
} }
rewriter.replaceOpWithNewOp<mhlo::DivOp>( rewriter.replaceOpWithNewOp<stablehlo::DivOp>(
op, outTy, reduceWindowSum.getResult(0), reduceWindowSize.getResult(0)); op, outTy, reduceWindowSum.getResult(0), reduceWindowSize.getResult(0));
return success(); return success();
} }
@ -560,33 +564,33 @@ LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite(
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
SmallVector<int64_t> mhloKernelSize(inputRank, 1); SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
mhloKernelSize[dim] = inputShape[dim]; stablehloKernelSize[dim] = inputShape[dim];
SmallVector<int64_t> mhloStride(inputRank, 1); SmallVector<int64_t> stablehloStride(inputRank, 1);
SmallVector<int64_t> mhloDilation(inputRank, 1); SmallVector<int64_t> stablehloDilation(inputRank, 1);
SmallVector<int64_t> mhloPadding(inputRank * 2, 0); SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
mhloPadding[dim * 2] = inputShape[dim] - 1; stablehloPadding[dim * 2] = inputShape[dim] - 1;
DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())}, RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
rewriter.getI64Type()), rewriter.getI64Type()),
mhloKernelSize); stablehloKernelSize);
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(mhloStride.size())}, RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
rewriter.getI64Type()), rewriter.getI64Type()),
mhloStride); stablehloStride);
DenseIntElementsAttr baseDilations; DenseIntElementsAttr baseDilations;
DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())}, RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
rewriter.getI64Type()), rewriter.getI64Type()),
mhloDilation); stablehloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get( DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get( RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)}, {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
rewriter.getI64Type()), rewriter.getI64Type()),
mhloPadding); stablehloPadding);
auto reduceWindowSum = rewriter.create<mhlo::ReduceWindowOp>( auto reduceWindowSum = rewriter.create<stablehlo::ReduceWindowOp>(
op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides, op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
baseDilations, windowDilations, pad); baseDilations, windowDilations, pad);
@ -604,17 +608,17 @@ LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite(
rewriter.setInsertionPointToStart(&sumBlock); rewriter.setInsertionPointToStart(&sumBlock);
Value sumResult = Value sumResult =
rewriter.create<mhlo::AddOp>(op->getLoc(), *firstArg, *secondArg); rewriter.create<stablehlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
rewriter.create<mhlo::ReturnOp>(op->getLoc(), sumResult); rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult);
} }
rewriter.replaceOp(op, reduceWindowSum.getResults()); rewriter.replaceOp(op, reduceWindowSum.getResults());
return success(); return success();
} }
void mlir::torch::torch_to_mhlo::populatePoolingOpPatternsAndLegality( void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns, TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToMhloOptions &options) { ConversionTarget &target, const TorchToStablehloOptions &options) {
MLIRContext *context = patterns.getContext(); MLIRContext *context = patterns.getContext();
target.addIllegalOp<AtenMaxPool2dOp>(); target.addIllegalOp<AtenMaxPool2dOp>();
patterns.add<ConvertAtenOp<AtenMaxPool2dOp>>(typeConverter, context, options); patterns.add<ConvertAtenOp<AtenMaxPool2dOp>>(typeConverter, context, options);

View File

@ -0,0 +1,69 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_LIB_CONVERSION_TORCHTOSTABLEHLO_POPULATEPATTERNS_H
#define TORCHMLIR_LIB_CONVERSION_TORCHTOSTABLEHLO_POPULATEPATTERNS_H
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace torch {
namespace torch_to_stablehlo {
struct TorchToStablehloOptions {
bool enableStaticShape = false;
size_t dimSizeIndexBits = 64;
};
template <typename AtenOpT>
class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
public:
using OpAdaptor = typename AtenOpT::Adaptor;
ConvertAtenOp(TypeConverter &typeConverter, MLIRContext *context,
const TorchToStablehloOptions &options)
: OpConversionPattern<AtenOpT>(typeConverter, context) {
this->options = options;
}
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
return rewriter.notifyMatchFailure(op, "haven't been implemented");
}
const TorchToStablehloOptions &getOptions() const { return options; }
private:
TorchToStablehloOptions options;
};
void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target,
const TorchToStablehloOptions &options);
void populateViewLikeOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options);
void populateGatherOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options);
void populateReductionOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options);
void populateLinearOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options);
void populatePoolingOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options);
} // namespace torch_to_stablehlo
} // namespace torch
} // namespace mlir
#endif // TORCHMLIR_LIB_CONVERSION_TORCHTOSTABLEHLO_POPULATEPATTERNS_H

View File

@ -7,14 +7,15 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "../PassDetail.h" #include "../PassDetail.h"
#include "./MhloLegalizeUtils.h" #include "PopulatePatterns.h"
#include "./PopulatePatterns.h" #include "StablehloLegalizeUtils.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@ -25,7 +26,7 @@
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
using namespace mlir::torch::Torch; using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_mhlo; using namespace mlir::torch::torch_to_stablehlo;
static Value createInitialValueForReduceOp(Operation *op, Type elementTy, static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
PatternRewriter &rewriter) { PatternRewriter &rewriter) {
@ -36,14 +37,14 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
constType, {APFloat::getZero( constType, {APFloat::getZero(
elementTy.cast<mlir::FloatType>().getFloatSemantics(), elementTy.cast<mlir::FloatType>().getFloatSemantics(),
/*negative=*/false)}); /*negative=*/false)});
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr); constAttr);
} else if (elementTy.isa<mlir::IntegerType>() && } else if (elementTy.isa<mlir::IntegerType>() &&
elementTy.getIntOrFloatBitWidth() != 8) { elementTy.getIntOrFloatBitWidth() != 8) {
auto constAttr = DenseElementsAttr::get( auto constAttr = DenseElementsAttr::get(
constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())}); constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr); constAttr);
} }
} }
@ -53,15 +54,15 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
constType, {APFloat::getLargest( constType, {APFloat::getLargest(
elementTy.cast<mlir::FloatType>().getFloatSemantics(), elementTy.cast<mlir::FloatType>().getFloatSemantics(),
/*negative=*/true)}); /*negative=*/true)});
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr); constAttr);
} else if (elementTy.isa<mlir::IntegerType>() && } else if (elementTy.isa<mlir::IntegerType>() &&
elementTy.getIntOrFloatBitWidth() != 8) { elementTy.getIntOrFloatBitWidth() != 8) {
auto constAttr = DenseElementsAttr::get( auto constAttr = DenseElementsAttr::get(
constType, constType,
{APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())}); {APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())});
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr); constAttr);
} }
} }
@ -90,9 +91,9 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
return std::nullopt; return std::nullopt;
Value initIndex; Value initIndex;
if (dimSizeIndexBits == 32) { if (dimSizeIndexBits == 32) {
initIndex = mhlo::getConstTensor<int32_t>(rewriter, op, {0}, {}).value(); initIndex = hlo::getConstTensor<int32_t>(rewriter, op, {0}, {}).value();
} else { } else {
initIndex = mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value(); initIndex = hlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
} }
DenseIntElementsAttr dimensions = DenseIntElementsAttr::get( DenseIntElementsAttr dimensions = DenseIntElementsAttr::get(
@ -100,13 +101,13 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>( auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), inputShapeVec); op->getLoc(), inputShapeVec);
auto indexTensor = rewriter.create<mhlo::DynamicIotaOp>( auto indexTensor = rewriter.create<stablehlo::DynamicIotaOp>(
op->getLoc(), op->getLoc(),
RankedTensorType::get(inputShape, RankedTensorType::get(inputShape,
rewriter.getIntegerType(dimSizeIndexBits)), rewriter.getIntegerType(dimSizeIndexBits)),
inputShapeTensor, static_cast<uint64_t>(dim)); inputShapeTensor, static_cast<uint64_t>(dim));
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>( auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op->getLoc(), ValueRange{input, indexTensor}, op->getLoc(), ValueRange{input, indexTensor},
ValueRange{ ValueRange{
initValue, initValue,
@ -114,7 +115,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
}, },
dimensions); dimensions);
Block &block = mhloReduceOp.getBody().emplaceBlock(); Block &block = stablehloReduceOp.getBody().emplaceBlock();
// Add block arguments // Add block arguments
auto blockValArgumentType = auto blockValArgumentType =
@ -133,46 +134,46 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
auto *secondValArg = std::next(firstIdxArg); auto *secondValArg = std::next(firstIdxArg);
auto *secondIdxArg = std::next(secondValArg); auto *secondIdxArg = std::next(secondValArg);
mhlo::ComparisonTypeAttr compareTypeAttr; stablehlo::ComparisonTypeAttr compareTypeAttr;
if (inputTy.getElementType().isa<mlir::FloatType>()) { if (inputTy.getElementType().isa<mlir::FloatType>()) {
compareTypeAttr = mhlo::ComparisonTypeAttr::get( compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
rewriter.getContext(), mhlo::ComparisonType::FLOAT); rewriter.getContext(), stablehlo::ComparisonType::FLOAT);
} else if (inputTy.getElementType().isa<mlir::IntegerType>()) { } else if (inputTy.getElementType().isa<mlir::IntegerType>()) {
compareTypeAttr = mhlo::ComparisonTypeAttr::get( compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
rewriter.getContext(), mhlo::ComparisonType::SIGNED); rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
} }
mhlo::ComparisonDirectionAttr compareGeDirectionAttr = stablehlo::ComparisonDirectionAttr compareGeDirectionAttr =
mhlo::ComparisonDirectionAttr::get(rewriter.getContext(), stablehlo::ComparisonDirectionAttr::get(
mhlo::ComparisonDirection::GE); rewriter.getContext(), stablehlo::ComparisonDirection::GE);
mhlo::ComparisonDirectionAttr compareEqDirectionAttr = stablehlo::ComparisonDirectionAttr compareEqDirectionAttr =
mhlo::ComparisonDirectionAttr::get(rewriter.getContext(), stablehlo::ComparisonDirectionAttr::get(
mhlo::ComparisonDirection::EQ); rewriter.getContext(), stablehlo::ComparisonDirection::EQ);
{ {
OpBuilder::InsertionGuard guard(rewriter); OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block); rewriter.setInsertionPointToStart(&block);
Value compareGeResult = rewriter.create<mhlo::CompareOp>( Value compareGeResult = rewriter.create<stablehlo::CompareOp>(
op->getLoc(), compareResultType, *firstValArg, *secondValArg, op->getLoc(), compareResultType, *firstValArg, *secondValArg,
compareGeDirectionAttr, compareTypeAttr); compareGeDirectionAttr, compareTypeAttr);
Value retValResult = rewriter.create<mhlo::SelectOp>( Value retValResult = rewriter.create<stablehlo::SelectOp>(
op->getLoc(), compareGeResult, *firstValArg, *secondValArg); op->getLoc(), compareGeResult, *firstValArg, *secondValArg);
// get smaller index value if compared nums are equal. // get smaller index value if compared nums are equal.
Value compareEqResult = rewriter.create<mhlo::CompareOp>( Value compareEqResult = rewriter.create<stablehlo::CompareOp>(
op->getLoc(), compareResultType, *firstValArg, *secondValArg, op->getLoc(), compareResultType, *firstValArg, *secondValArg,
compareEqDirectionAttr, compareTypeAttr); compareEqDirectionAttr, compareTypeAttr);
Value minIdx = Value minIdx = rewriter.create<stablehlo::MinOp>(op->getLoc(), *firstIdxArg,
rewriter.create<mhlo::MinOp>(op->getLoc(), *firstIdxArg, *secondIdxArg); *secondIdxArg);
Value idxWithGeVal = rewriter.create<mhlo::SelectOp>( Value idxWithGeVal = rewriter.create<stablehlo::SelectOp>(
op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg); op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg);
Value retIdxResult = rewriter.create<mhlo::SelectOp>( Value retIdxResult = rewriter.create<stablehlo::SelectOp>(
op->getLoc(), compareEqResult, minIdx, idxWithGeVal); op->getLoc(), compareEqResult, minIdx, idxWithGeVal);
rewriter.create<mhlo::ReturnOp>( rewriter.create<stablehlo::ReturnOp>(
op->getLoc(), mlir::ValueRange{retValResult, retIdxResult}); op->getLoc(), mlir::ValueRange{retValResult, retIdxResult});
} }
return mhloReduceOp.getResults(); return stablehloReduceOp.getResults();
} }
namespace { namespace {
@ -196,7 +197,8 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
auto inputTy = input.getType().template cast<RankedTensorType>(); auto inputTy = input.getType().template cast<RankedTensorType>();
if (!inputTy) { if (!inputTy) {
return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO"); return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
} }
auto inputElemTy = inputTy.getElementType(); auto inputElemTy = inputTy.getElementType();
@ -209,7 +211,7 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
inputElemTy.getIntOrFloatBitWidth() == 8) { inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from " op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenArgmaxOp to MHLO"); "AtenArgmaxOp to StableHLO");
} }
int64_t dim; int64_t dim;
@ -228,15 +230,15 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
const auto &options = getOptions(); const auto &options = getOptions();
auto inputShapeInfo = auto inputShapeInfo =
mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
if (failed(inputShapeInfo)) { if (failed(inputShapeInfo)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input"); op, "failed to get dimension sizes of the input");
} }
auto inputShapeVec = *inputShapeInfo; auto inputShapeVec = *inputShapeInfo;
auto mhloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, dim, auto stablehloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec,
options.dimSizeIndexBits) dim, options.dimSizeIndexBits)
.value(); .value();
if (keepDim) { if (keepDim) {
auto outShapeVec = inputShapeVec; auto outShapeVec = inputShapeVec;
@ -247,13 +249,13 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>( auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec); op->getLoc(), outShapeVec);
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>( rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
op, typeConverter->convertType(op.getType()), mhloReduceResults[1], op, typeConverter->convertType(op.getType()), stablehloReduceResults[1],
outShapeTensor); outShapeTensor);
return success(); return success();
} }
rewriter.replaceOp(op, mhloReduceResults[1]); rewriter.replaceOp(op, stablehloReduceResults[1]);
return success(); return success();
} }
} // namespace } // namespace
@ -267,7 +269,8 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
auto inputTy = input.getType().template dyn_cast<RankedTensorType>(); auto inputTy = input.getType().template dyn_cast<RankedTensorType>();
if (!inputTy) { if (!inputTy) {
return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO"); return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
} }
auto inputElemTy = inputTy.getElementType(); auto inputElemTy = inputTy.getElementType();
if (!inputElemTy.isIntOrFloat()) { if (!inputElemTy.isIntOrFloat()) {
@ -279,7 +282,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
inputElemTy.getIntOrFloatBitWidth() == 8) { inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from " op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenMaxDimOp to MHLO"); "AtenMaxDimOp to StableHLO");
} }
RankedTensorType valResultType = getTypeConverter() RankedTensorType valResultType = getTypeConverter()
@ -308,15 +311,15 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
const auto &options = getOptions(); const auto &options = getOptions();
auto inputShapeInfo = auto inputShapeInfo =
mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
if (failed(inputShapeInfo)) { if (failed(inputShapeInfo)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input"); op, "failed to get dimension sizes of the input");
} }
auto inputShapeVec = *inputShapeInfo; auto inputShapeVec = *inputShapeInfo;
auto mhloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, dim, auto stablehloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec,
options.dimSizeIndexBits) dim, options.dimSizeIndexBits)
.value(); .value();
if (keepDim) { if (keepDim) {
auto outShapeVec = inputShapeVec; auto outShapeVec = inputShapeVec;
@ -327,15 +330,21 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>( auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec); op->getLoc(), outShapeVec);
auto mhloReduceValueResult = rewriter.create<mhlo::DynamicReshapeOp>( auto stablehloReduceValueResult =
op->getLoc(), valResultType, mhloReduceResults[0], outShapeTensor); rewriter.create<stablehlo::DynamicReshapeOp>(
auto mhloReduceIndexResult = rewriter.create<mhlo::DynamicReshapeOp>( op->getLoc(), valResultType, stablehloReduceResults[0],
op->getLoc(), idxResultType, mhloReduceResults[1], outShapeTensor); outShapeTensor);
rewriter.replaceOp(op, {mhloReduceValueResult, mhloReduceIndexResult}); auto stablehloReduceIndexResult =
rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), idxResultType, stablehloReduceResults[1],
outShapeTensor);
rewriter.replaceOp(
op, {stablehloReduceValueResult, stablehloReduceIndexResult});
return success(); return success();
} }
rewriter.replaceOp(op, {mhloReduceResults[0], mhloReduceResults[1]}); rewriter.replaceOp(op,
{stablehloReduceResults[0], stablehloReduceResults[1]});
return success(); return success();
} }
} // namespace } // namespace
@ -352,12 +361,14 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
->convertType(op.getType()) ->convertType(op.getType())
.template dyn_cast<RankedTensorType>(); .template dyn_cast<RankedTensorType>();
if (!inputTy) { if (!inputTy) {
return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO"); return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
} }
if (inputTy.getElementType() != outTy.getElementType()) { if (inputTy.getElementType() != outTy.getElementType()) {
// Use output element type as computation type. // Use output element type as computation type.
auto dstElemTy = outTy.getElementType(); auto dstElemTy = outTy.getElementType();
input = rewriter.create<mhlo::ConvertOp>(op->getLoc(), input, dstElemTy); input =
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
inputTy = input.getType().dyn_cast<RankedTensorType>(); inputTy = input.getType().dyn_cast<RankedTensorType>();
} }
auto inputElemTy = inputTy.getElementType(); auto inputElemTy = inputTy.getElementType();
@ -370,7 +381,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
inputElemTy.getIntOrFloatBitWidth() == 8) { inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from " op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenSumOp to MHLO"); "AtenSumOp to StableHLO");
} }
SmallVector<int64_t> dims; SmallVector<int64_t> dims;
@ -379,13 +390,14 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
} }
Value initValue = Value initValue =
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue) return failure(); if (!initValue)
return failure();
llvm::sort(dims.begin(), dims.end()); llvm::sort(dims.begin(), dims.end());
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>( auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims)); op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
Block &block = mhloReduceOp.getBody().emplaceBlock(); Block &block = stablehloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
block.addArgument(blockArgumentTy, op->getLoc()); block.addArgument(blockArgumentTy, op->getLoc());
@ -397,13 +409,13 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
{ {
OpBuilder::InsertionGuard guard(rewriter); OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block); rewriter.setInsertionPointToStart(&block);
Value addResult = rewriter.create<mhlo::AddOp>( Value addResult = rewriter.create<stablehlo::AddOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
rewriter.create<mhlo::ReturnOp>(op->getLoc(), addResult); rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult);
} }
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy, rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy,
mhloReduceOp.getResults()); stablehloReduceOp.getResults());
return success(); return success();
} }
} // namespace } // namespace
@ -417,7 +429,8 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
auto inputTy = input.getType().dyn_cast<RankedTensorType>(); auto inputTy = input.getType().dyn_cast<RankedTensorType>();
if (!inputTy) { if (!inputTy) {
return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO"); return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
} }
auto inputElemTy = inputTy.getElementType(); auto inputElemTy = inputTy.getElementType();
if (!inputElemTy.isIntOrFloat()) { if (!inputElemTy.isIntOrFloat()) {
@ -429,7 +442,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
inputElemTy.getIntOrFloatBitWidth() == 8) { inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from " op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenMaxOp to MHLO"); "AtenMaxOp to StableHLO");
} }
SmallVector<int64_t> dims; SmallVector<int64_t> dims;
@ -439,12 +452,13 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
Value initValue = Value initValue =
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue) return failure(); if (!initValue)
return failure();
llvm::sort(dims.begin(), dims.end()); llvm::sort(dims.begin(), dims.end());
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>( auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims)); op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
Block &block = mhloReduceOp.getBody().emplaceBlock(); Block &block = stablehloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
block.addArgument(blockArgumentTy, op->getLoc()); block.addArgument(blockArgumentTy, op->getLoc());
@ -456,14 +470,14 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
{ {
OpBuilder::InsertionGuard guard(rewriter); OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block); rewriter.setInsertionPointToStart(&block);
Value maxResult = rewriter.create<mhlo::MaxOp>( Value maxResult = rewriter.create<stablehlo::MaxOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
rewriter.create<mhlo::ReturnOp>(op->getLoc(), maxResult); rewriter.create<stablehlo::ReturnOp>(op->getLoc(), maxResult);
} }
rewriter.replaceOpWithNewOp<tensor::CastOp>( rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, getTypeConverter()->convertType(op.getType()), op, getTypeConverter()->convertType(op.getType()),
mhloReduceOp.getResults()); stablehloReduceOp.getResults());
return success(); return success();
} }
} // namespace } // namespace
@ -480,12 +494,14 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
->convertType(op.getType()) ->convertType(op.getType())
.template dyn_cast<RankedTensorType>(); .template dyn_cast<RankedTensorType>();
if (!inputTy) { if (!inputTy) {
return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO"); return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
} }
if (inputTy.getElementType() != outTy.getElementType()) { if (inputTy.getElementType() != outTy.getElementType()) {
// Use output element type as computation type. // Use output element type as computation type.
auto dstElemTy = outTy.getElementType(); auto dstElemTy = outTy.getElementType();
input = rewriter.create<mhlo::ConvertOp>(op->getLoc(), input, dstElemTy); input =
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
inputTy = input.getType().dyn_cast<RankedTensorType>(); inputTy = input.getType().dyn_cast<RankedTensorType>();
} }
auto inputElemTy = inputTy.getElementType(); auto inputElemTy = inputTy.getElementType();
@ -499,7 +515,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
inputElemTy.getIntOrFloatBitWidth() == 8) { inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from " op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenSumDimIntListOp to MHLO"); "AtenSumDimIntListOp to StableHLO");
} }
SmallVector<int64_t> inputDims; SmallVector<int64_t> inputDims;
@ -525,13 +541,14 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
} }
Value initValue = Value initValue =
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue) return failure(); if (!initValue)
return failure();
llvm::sort(dims.begin(), dims.end()); llvm::sort(dims.begin(), dims.end());
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>( auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims)); op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
Region &region = mhloReduceOp.getBody(); Region &region = stablehloReduceOp.getBody();
Block &block = region.emplaceBlock(); Block &block = region.emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
@ -544,15 +561,15 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
{ {
OpBuilder::InsertionGuard guard(rewriter); OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block); rewriter.setInsertionPointToStart(&block);
Value addResult = rewriter.create<mhlo::AddOp>( Value addResult = rewriter.create<stablehlo::AddOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
rewriter.create<mhlo::ReturnOp>(op->getLoc(), addResult); rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult);
} }
if (keepDim) { if (keepDim) {
const auto &options = getOptions(); const auto &options = getOptions();
auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input, auto outShapeInfo =
options.dimSizeIndexBits); hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
if (failed(outShapeInfo)) { if (failed(outShapeInfo)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input"); op, "failed to get dimension sizes of the input");
@ -567,26 +584,27 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
} }
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>( auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec); op->getLoc(), outShapeVec);
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>( rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), op, getTypeConverter()->convertType(op.getType()),
mhloReduceOp.getResult(0), outShapeTensor); stablehloReduceOp.getResult(0), outShapeTensor);
return success(); return success();
} }
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy, rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy,
mhloReduceOp.getResults()); stablehloReduceOp.getResults());
return success(); return success();
} }
} // namespace } // namespace
// AtenFrobeniusNormDimOp // AtenFrobeniusNormDimOp
// aten.frobenius_norm.dim => mhlo.reduce(calculate square sum along given dims) // aten.frobenius_norm.dim => stablehlo.reduce(calculate square sum along given
// + mhlo.sqrt // dims)
// + stablehlo.sqrt
namespace { namespace {
template <> template <>
LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite( LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
AtenFrobeniusNormDimOp op, OpAdaptor adaptor, AtenFrobeniusNormDimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
const TorchToMhloOptions &options = getOptions(); const TorchToStablehloOptions &options = getOptions();
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
auto inputType = input.getType().dyn_cast<RankedTensorType>(); auto inputType = input.getType().dyn_cast<RankedTensorType>();
@ -614,7 +632,7 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
} }
} }
// Sort the dims in ascending order, making the conversion // Sort the dims in ascending order, making the conversion
// stable with unordered dims. // stable with unordered dims.
std::sort(dims.begin(), dims.end()); std::sort(dims.begin(), dims.end());
@ -624,14 +642,14 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
op, "non-const bool `keepdim` is not supported"); op, "non-const bool `keepdim` is not supported");
} }
auto squareOp = rewriter.create<mhlo::MulOp>(op->getLoc(), input, input); auto squareOp = rewriter.create<stablehlo::MulOp>(op->getLoc(), input, input);
auto initValue = createInitialValueForReduceOp(op, inputElemType, rewriter); auto initValue = createInitialValueForReduceOp(op, inputElemType, rewriter);
if (!initValue) { if (!initValue) {
return failure(); return failure();
} }
auto reduceOp = rewriter.create<mhlo::ReduceOp>( auto reduceOp = rewriter.create<stablehlo::ReduceOp>(
op->getLoc(), squareOp.getResult(), initValue, op->getLoc(), squareOp.getResult(), initValue,
rewriter.getI64TensorAttr(dims)); rewriter.getI64TensorAttr(dims));
@ -649,30 +667,32 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
OpBuilder::InsertionGuard guard(rewriter); OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block); rewriter.setInsertionPointToStart(&block);
auto addResult = rewriter.create<mhlo::AddOp>(op->getLoc(), firstArgument, auto addResult = rewriter.create<stablehlo::AddOp>(
secondArgument); op->getLoc(), firstArgument, secondArgument);
rewriter.create<mhlo::ReturnOp>(op->getLoc(), addResult.getResult()); rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult.getResult());
} }
auto output = auto output =
rewriter.create<mhlo::SqrtOp>(op->getLoc(), reduceOp.getResult(0)); rewriter.create<stablehlo::SqrtOp>(op->getLoc(), reduceOp.getResult(0));
if (keepDim) { if (keepDim) {
auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); auto outShapeInfo =
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
if (failed(outShapeInfo)) { if (failed(outShapeInfo)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input"); op, "failed to get dimension sizes of the input");
} }
auto outShapeVec = *outShapeInfo; auto outShapeVec = *outShapeInfo;
auto one = rewriter.create<mlir::arith::ConstantOp>( auto one = rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(), rewriter.getIntegerAttr( op->getLoc(),
rewriter.getIntegerType(options.dimSizeIndexBits), 1)); rewriter.getIntegerAttr(
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
for (int64_t i : dims) { for (int64_t i : dims) {
outShapeVec[i] = one; outShapeVec[i] = one;
} }
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>( auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec); op->getLoc(), outShapeVec);
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>( rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), output, op, getTypeConverter()->convertType(op.getType()), output,
outShapeTensor); outShapeTensor);
return success(); return success();
@ -682,9 +702,9 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
} }
} // namespace } // namespace
void mlir::torch::torch_to_mhlo::populateReductionOpPatternsAndLegality( void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns, TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToMhloOptions &options) { ConversionTarget &target, const TorchToStablehloOptions &options) {
MLIRContext *context = patterns.getContext(); MLIRContext *context = patterns.getContext();
#define INSERT_ATEN_REDUCTION_OP_PATTERN(AtenOp) \ #define INSERT_ATEN_REDUCTION_OP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \ target.addIllegalOp<AtenOp>(); \

View File

@ -7,11 +7,12 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "./MhloLegalizeUtils.h" #include "StablehloLegalizeUtils.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include <numeric> #include <numeric>
@ -21,27 +22,27 @@ using namespace mlir::torch;
using namespace mlir::torch::Torch; using namespace mlir::torch::Torch;
namespace mlir { namespace mlir {
namespace mhlo { namespace hlo {
// Create a 32-bit float constant operator from a float // Create a 32-bit float constant operator from a float
Value getMhloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op, Value getStablehloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
float val) { float val) {
auto const_type = RankedTensorType::get({}, rewriter.getF32Type()); auto const_type = RankedTensorType::get({}, rewriter.getF32Type());
auto const_attr = DenseElementsAttr::get(const_type, val); auto const_attr = DenseElementsAttr::get(const_type, val);
auto const_op = auto const_op = rewriter.create<stablehlo::ConstantOp>(
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr); op->getLoc(), const_type, const_attr);
return const_op.getResult(); return const_op.getResult();
} }
// Create a 64-bit float constant operator from a double // Create a 64-bit float constant operator from a double
Value getMhloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op, Value getStablehloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op,
double val) { double val) {
auto const_type = RankedTensorType::get({}, rewriter.getF64Type()); auto const_type = RankedTensorType::get({}, rewriter.getF64Type());
auto const_attr = DenseElementsAttr::get(const_type, val); auto const_attr = DenseElementsAttr::get(const_type, val);
auto const_op = auto const_op = rewriter.create<stablehlo::ConstantOp>(
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr); op->getLoc(), const_type, const_attr);
return const_op.getResult(); return const_op.getResult();
} }
@ -65,8 +66,8 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
RankedTensorType::get(shape, rewriter.getIntegerType(sizeof(T) * 8)); RankedTensorType::get(shape, rewriter.getIntegerType(sizeof(T) * 8));
auto const_attr = DenseElementsAttr::get(const_type, vec); auto const_attr = DenseElementsAttr::get(const_type, vec);
auto const_op = auto const_op = rewriter.create<stablehlo::ConstantOp>(
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr); op->getLoc(), const_type, const_attr);
return const_op.getResult(); return const_op.getResult();
} }
@ -88,8 +89,8 @@ std::optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,
shape, rewriter.getIntegerType(vec[0].getBitWidth())); shape, rewriter.getIntegerType(vec[0].getBitWidth()));
auto const_attr = DenseElementsAttr::get(const_type, vec); auto const_attr = DenseElementsAttr::get(const_type, vec);
auto const_op = auto const_op = rewriter.create<stablehlo::ConstantOp>(
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr); op->getLoc(), const_type, const_attr);
return const_op.getResult(); return const_op.getResult();
} }
@ -111,8 +112,8 @@ std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
auto const_type = RankedTensorType::get(shape, rewriter.getF32Type()); auto const_type = RankedTensorType::get(shape, rewriter.getF32Type());
auto const_attr = DenseElementsAttr::get(const_type, vec); auto const_attr = DenseElementsAttr::get(const_type, vec);
auto const_op = auto const_op = rewriter.create<stablehlo::ConstantOp>(
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr); op->getLoc(), const_type, const_attr);
return const_op.getResult(); return const_op.getResult();
} }
@ -133,8 +134,8 @@ std::optional<Value> getConstTensor<double>(PatternRewriter &rewriter,
auto const_type = RankedTensorType::get(shape, rewriter.getF64Type()); auto const_type = RankedTensorType::get(shape, rewriter.getF64Type());
auto const_attr = DenseElementsAttr::get(const_type, vec); auto const_attr = DenseElementsAttr::get(const_type, vec);
auto const_op = auto const_op = rewriter.create<stablehlo::ConstantOp>(
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr); op->getLoc(), const_type, const_attr);
return const_op.getResult(); return const_op.getResult();
} }
@ -169,18 +170,18 @@ Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
T val, Type dtype, llvm::ArrayRef<int64_t> dshape) { T val, Type dtype, llvm::ArrayRef<int64_t> dshape) {
auto const_type = RankedTensorType::get(dshape, dtype); auto const_type = RankedTensorType::get(dshape, dtype);
auto const_attr = SplatElementsAttr::get(const_type, val); auto const_attr = SplatElementsAttr::get(const_type, val);
auto const_op = auto const_op = rewriter.create<stablehlo::ConstantOp>(
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr); op->getLoc(), const_type, const_attr);
return const_op.getResult(); return const_op.getResult();
} }
Value scalarToMhloTensor(ConversionPatternRewriter &rewriter, Operation *op, Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
Value scalarValue, Type dtype) { Operation *op, Value scalarValue, Type dtype) {
auto tensor = rewriter.create<tensor::FromElementsOp>( auto tensor = rewriter.create<tensor::FromElementsOp>(
op->getLoc(), ArrayRef<Value>{scalarValue}); op->getLoc(), ArrayRef<Value>{scalarValue});
auto dtype_tensor = auto dtype_tensor =
rewriter.create<mhlo::ConvertOp>(op->getLoc(), tensor, dtype); rewriter.create<stablehlo::ConvertOp>(op->getLoc(), tensor, dtype);
return rewriter.create<mhlo::ReshapeOp>( return rewriter.create<stablehlo::ReshapeOp>(
op->getLoc(), RankedTensorType::get(mlir::ArrayRef<int64_t>{}, dtype), op->getLoc(), RankedTensorType::get(mlir::ArrayRef<int64_t>{}, dtype),
dtype_tensor); dtype_tensor);
} }
@ -192,7 +193,8 @@ Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) {
if (in_type.getElementType() != outType.getElementType()) { if (in_type.getElementType() != outType.getElementType()) {
TensorType promotedType = TensorType promotedType =
in_type.cloneWith(in_type.getShape(), outType.getElementType()); in_type.cloneWith(in_type.getShape(), outType.getElementType());
return rewriter.create<mhlo::ConvertOp>(op->getLoc(), promotedType, input); return rewriter.create<stablehlo::ConvertOp>(op->getLoc(), promotedType,
input);
} }
return input; return input;
} }
@ -210,8 +212,8 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
if (in_type.getElementType() != outType.getElementType()) { if (in_type.getElementType() != outType.getElementType()) {
TensorType promoted_type = TensorType promoted_type =
in_type.cloneWith(in_type.getShape(), outType.getElementType()); in_type.cloneWith(in_type.getShape(), outType.getElementType());
input = input = rewriter.create<stablehlo::ConvertOp>(op->getLoc(), promoted_type,
rewriter.create<mhlo::ConvertOp>(op->getLoc(), promoted_type, input); input);
} }
ArrayRef<int64_t> inShape = in_type.getShape(); ArrayRef<int64_t> inShape = in_type.getShape();
@ -245,8 +247,8 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
RankedTensorType::get({static_cast<long int>(bcastDims.size())}, RankedTensorType::get({static_cast<long int>(bcastDims.size())},
rewriter.getI64Type()), rewriter.getI64Type()),
bcastDims); bcastDims);
auto bcast_op = rewriter.create<mhlo::BroadcastInDimOp>(op->getLoc(), outType, auto bcast_op = rewriter.create<stablehlo::BroadcastInDimOp>(
input, bcast_attr); op->getLoc(), outType, input, bcast_attr);
return bcast_op.getResult(); return bcast_op.getResult();
} }
@ -348,8 +350,8 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
} }
auto outTy = RankedTensorType::get(newShape, rankTy.getElementType()); auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
auto mhloShape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes); auto shape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
return rewriter.create<mhlo::DynamicReshapeOp>(loc, outTy, tensor, mhloShape) return rewriter.create<stablehlo::DynamicReshapeOp>(loc, outTy, tensor, shape)
.getResult(); .getResult();
} }
@ -357,11 +359,11 @@ Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
const APFloat &constant, Value shape, const APFloat &constant, Value shape,
TensorType outType) { TensorType outType) {
auto constAttr = rewriter.getFloatAttr(outType.getElementType(), constant); auto constAttr = rewriter.getFloatAttr(outType.getElementType(), constant);
auto constTensor = rewriter.create<mhlo::ConstantOp>(loc, constAttr); auto constTensor = rewriter.create<stablehlo::ConstantOp>(loc, constAttr);
return rewriter return rewriter
.create<mhlo::DynamicBroadcastInDimOp>(loc, outType, constTensor, shape, .create<stablehlo::DynamicBroadcastInDimOp>(
rewriter.getI64TensorAttr({})) loc, outType, constTensor, shape, rewriter.getI64TensorAttr({}))
.getResult(); .getResult();
} }
} // namespace mhlo } // namespace hlo
} // namespace mlir } // namespace mlir

View File

@ -7,8 +7,8 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H #ifndef TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_STABLEHLOLEGALIZEUTILS_H
#define TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H #define TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_STABLEHLOLEGALIZEUTILS_H
#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
@ -18,22 +18,22 @@
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
namespace mlir { namespace mlir {
namespace mhlo { namespace hlo {
using mlir::ConversionPatternRewriter; using mlir::ConversionPatternRewriter;
// Create a 32-bit float constant operator from a float // Create a 32-bit float constant operator from a float
Value getMhloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op, Value getStablehloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
float val); float val);
// Create a 64-bit float constant operator from a double // Create a 64-bit float constant operator from a double
Value getMhloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op, Value getStablehloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op,
double val); double val);
// Templated function to create a constant op for given type and shape. // Templated function to create a constant op for given type and shape.
// T: storage C type. // T: storage C type.
// Default template creates a constant tensor in T. // Default template creates a constant tensor in T.
// To create INT48 MHLO constant, need to pass in llvm::APInt instead. // To create INT48 StableHLO constant, need to pass in llvm::APInt instead.
template <typename T> template <typename T>
std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op, std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
ArrayRef<T> vec, ArrayRef<int64_t> shape); ArrayRef<T> vec, ArrayRef<int64_t> shape);
@ -42,8 +42,8 @@ template <typename T>
Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op, Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
T val, Type dtype, llvm::ArrayRef<int64_t> dshape); T val, Type dtype, llvm::ArrayRef<int64_t> dshape);
Value scalarToMhloTensor(ConversionPatternRewriter &rewriter, Operation *op, Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
Value scalarValue, Type dtype); Operation *op, Value scalarValue, Type dtype);
Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType); Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType);
@ -71,7 +71,7 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
Value getConstantOfShape(PatternRewriter &rewriter, Location loc, Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
const APFloat &constant, Value shape, const APFloat &constant, Value shape,
TensorType outType); TensorType outType);
} // namespace mhlo } // namespace hlo
} // namespace mlir } // namespace mlir
#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H #endif // TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_STABLEHLOLEGALIZEUTILS_H

View File

@ -7,17 +7,18 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "../PassDetail.h" #include "../PassDetail.h"
#include "./PopulatePatterns.h" #include "PopulatePatterns.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h" #include "mlir/Dialect/Traits.h"
#include "mlir/IR/Matchers.h" #include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
@ -30,17 +31,18 @@ using namespace mlir::torch::Torch;
namespace { namespace {
class ConvertTorchToMhlo : public ConvertTorchToMhloBase<ConvertTorchToMhlo> { class ConvertTorchToStablehlo
: public ConvertTorchToStablehloBase<ConvertTorchToStablehlo> {
public: public:
ConvertTorchToMhlo() = default; ConvertTorchToStablehlo() = default;
ConvertTorchToMhlo(bool enableStaticShape, bool enableI32Index) { ConvertTorchToStablehlo(bool enableStaticShape, bool enableI32Index) {
this->enableStaticShape = enableStaticShape; this->enableStaticShape = enableStaticShape;
this->enableI32Index = enableI32Index; this->enableI32Index = enableI32Index;
} }
void getDependentDialects(DialectRegistry &registry) const override { void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<chlo::ChloDialect>(); registry.insert<chlo::ChloDialect>();
registry.insert<mhlo::MhloDialect>(); registry.insert<stablehlo::StablehloDialect>();
registry.insert<tensor::TensorDialect>(); registry.insert<tensor::TensorDialect>();
registry.insert<arith::ArithDialect>(); registry.insert<arith::ArithDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry); TorchConversion::getBackendTypeConversionDependentDialects(registry);
@ -48,7 +50,7 @@ public:
void runOnOperation() override { void runOnOperation() override {
MLIRContext *context = &getContext(); MLIRContext *context = &getContext();
ConversionTarget target(*context); ConversionTarget target(*context);
target.addLegalDialect<chlo::ChloDialect, mhlo::MhloDialect, target.addLegalDialect<chlo::ChloDialect, stablehlo::StablehloDialect,
tensor::TensorDialect, arith::ArithDialect>(); tensor::TensorDialect, arith::ArithDialect>();
TypeConverter typeConverter; TypeConverter typeConverter;
@ -57,20 +59,20 @@ public:
RewritePatternSet patterns(context); RewritePatternSet patterns(context);
torch_to_mhlo::TorchToMhloOptions options{enableStaticShape, torch_to_stablehlo::TorchToStablehloOptions options{
enableI32Index ? 32u : 64u}; enableStaticShape, enableI32Index ? 32u : 64u};
torch_to_mhlo::populateBasicOpPatternsAndLegality(typeConverter, patterns, torch_to_stablehlo::populateBasicOpPatternsAndLegality(
target, options);
torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
typeConverter, patterns, target, options); typeConverter, patterns, target, options);
torch_to_mhlo::populateGatherOpPatternsAndLegality(typeConverter, patterns, torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
target, options); typeConverter, patterns, target, options);
torch_to_mhlo::populateReductionOpPatternsAndLegality( torch_to_stablehlo::populateGatherOpPatternsAndLegality(
typeConverter, patterns, target, options);
torch_to_stablehlo::populateReductionOpPatternsAndLegality(
typeConverter, patterns, target, options);
torch_to_stablehlo::populateLinearOpPatternsAndLegality(
typeConverter, patterns, target, options);
torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
typeConverter, patterns, target, options); typeConverter, patterns, target, options);
torch_to_mhlo::populateLinearOpPatternsAndLegality(typeConverter, patterns,
target, options);
torch_to_mhlo::populatePoolingOpPatternsAndLegality(typeConverter, patterns,
target, options);
if (failed(applyPartialConversion(getOperation(), target, if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) { std::move(patterns)))) {
@ -82,13 +84,13 @@ public:
} // namespace } // namespace
std::unique_ptr<OperationPass<func::FuncOp>> std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToMhloPass() { mlir::torch::createConvertTorchToStablehloPass() {
return std::make_unique<ConvertTorchToMhlo>(false, false); return std::make_unique<ConvertTorchToStablehlo>(false, false);
} }
std::unique_ptr<OperationPass<func::FuncOp>> std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToMhloPass(bool enableStaticShape, mlir::torch::createConvertTorchToStablehloPass(bool enableStaticShape,
bool enableI32Index) { bool enableI32Index) {
return std::make_unique<ConvertTorchToMhlo>(enableStaticShape, return std::make_unique<ConvertTorchToStablehlo>(enableStaticShape,
enableI32Index); enableI32Index);
} }

View File

@ -7,14 +7,15 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "../PassDetail.h" #include "../PassDetail.h"
#include "./MhloLegalizeUtils.h" #include "PopulatePatterns.h"
#include "./PopulatePatterns.h" #include "StablehloLegalizeUtils.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@ -28,7 +29,7 @@ using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
using namespace mlir::torch::Torch; using namespace mlir::torch::Torch;
using namespace mlir::torch::TorchConversion; using namespace mlir::torch::TorchConversion;
using namespace mlir::torch::torch_to_mhlo; using namespace mlir::torch::torch_to_stablehlo;
namespace { namespace {
// A dimension index from torch.dialect might outside the range [0, dimSize]. // A dimension index from torch.dialect might outside the range [0, dimSize].
@ -100,7 +101,7 @@ Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op,
auto stridesTensor = auto stridesTensor =
rewriter.create<tensor::FromElementsOp>(loc, strides).getResult(); rewriter.create<tensor::FromElementsOp>(loc, strides).getResult();
return rewriter.create<mhlo::RealDynamicSliceOp>( return rewriter.create<stablehlo::RealDynamicSliceOp>(
loc, outTy, input, startTensor, endTensor, stridesTensor); loc, outTy, input, startTensor, endTensor, stridesTensor);
} }
@ -144,7 +145,7 @@ FailureOr<Value> getDynamicSlice(PatternRewriter &rewriter, Operation *op,
step = rewriter.create<arith::TruncIOp>(loc, intType, step); step = rewriter.create<arith::TruncIOp>(loc, intType, step);
} }
FailureOr<SmallVector<Value, 4>> dimSizesInfo = FailureOr<SmallVector<Value, 4>> dimSizesInfo =
mhlo::getDimSizesOfTensor(rewriter, op, input, dimSizeIndexBits); hlo::getDimSizesOfTensor(rewriter, op, input, dimSizeIndexBits);
if (failed(dimSizesInfo)) if (failed(dimSizesInfo))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input"); op, "failed to get dimension sizes of the input");
@ -179,7 +180,7 @@ public:
auto loc = op.getLoc(); auto loc = op.getLoc();
auto newRank = dimSizes.size(); auto newRank = dimSizes.size();
if (newRank == 0 || rankType.getRank() == 0) { if (newRank == 0 || rankType.getRank() == 0) {
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>( rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
op, op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType( OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()), op.getType()),
@ -214,17 +215,18 @@ public:
numel); numel);
if (dimSizes.size() == 0) { if (dimSizes.size() == 0) {
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>( rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
op, op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType( OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()), op.getType()),
adaptor.getSelf()); adaptor.getSelf());
return success(); return success();
} }
Value mhloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes); Value stablehloShape =
Value computedShape = rewriter.create<mhlo::ComputeReshapeShapeOp>( rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
loc, mhloShape.getType(), numel, mhloShape); Value computedShape = rewriter.create<stablehlo::ComputeReshapeShapeOp>(
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>( loc, stablehloShape.getType(), numel, stablehloShape);
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
op, op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType( OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()), op.getType()),
@ -315,21 +317,21 @@ LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
dims.push_back(r); dims.push_back(r);
} }
if (dims.size() == 0) { if (dims.size() == 0) {
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>( rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), self); op, getTypeConverter()->convertType(op.getType()), self);
return success(); return success();
} }
auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims, auto newDimSizesInfo = hlo::getDimSizesOfTensor(rewriter, op, self, dims,
options.dimSizeIndexBits); options.dimSizeIndexBits);
if (failed(newDimSizesInfo)) if (failed(newDimSizesInfo))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input"); op, "failed to get dimension sizes of the input");
auto newDimSizes = *newDimSizesInfo; auto newDimSizes = *newDimSizesInfo;
auto mhloShape = auto stablehloShape =
rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes); rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>( rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), self, mhloShape); op, getTypeConverter()->convertType(op.getType()), self, stablehloShape);
return success(); return success();
} }
@ -365,20 +367,20 @@ LogicalResult ConvertAtenOp<AtenSqueezeDimOp>::matchAndRewrite(
std::iota(dims.begin(), dims.end(), 0); std::iota(dims.begin(), dims.end(), 0);
dims.erase(dims.begin() + dim); dims.erase(dims.begin() + dim);
if (dims.size() == 0) { if (dims.size() == 0) {
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>( rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), self); op, getTypeConverter()->convertType(op.getType()), self);
return success(); return success();
} }
auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims, auto newDimSizesInfo = hlo::getDimSizesOfTensor(rewriter, op, self, dims,
options.dimSizeIndexBits); options.dimSizeIndexBits);
if (failed(newDimSizesInfo)) if (failed(newDimSizesInfo))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input"); op, "failed to get dimension sizes of the input");
auto newDimSizes = *newDimSizesInfo; auto newDimSizes = *newDimSizesInfo;
auto mhloShape = auto stablehloShape =
rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes); rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>( rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), self, mhloShape); op, getTypeConverter()->convertType(op.getType()), self, stablehloShape);
return success(); return success();
} }
@ -395,8 +397,8 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return op->emitError("dim must be a Scalar constant"); return op->emitError("dim must be a Scalar constant");
auto unsqzTensorInfo = mhlo::unsqueezeTensor(rewriter, op, adaptor.getSelf(), auto unsqzTensorInfo = hlo::unsqueezeTensor(rewriter, op, adaptor.getSelf(),
{dim}, options.dimSizeIndexBits); {dim}, options.dimSizeIndexBits);
if (failed(unsqzTensorInfo)) if (failed(unsqzTensorInfo))
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"failed to create unsqueezed tensor"); "failed to create unsqueezed tensor");
@ -405,9 +407,9 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
return success(); return success();
} }
void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality( void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns, TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToMhloOptions &options) { ConversionTarget &target, const TorchToStablehloOptions &options) {
MLIRContext *context = patterns.getContext(); MLIRContext *context = patterns.getContext();
#define INSERT_ATENOP_PATTERN(AtenOp) \ #define INSERT_ATENOP_PATTERN(AtenOp) \

View File

@ -11,7 +11,7 @@ set(LinkedLibs MLIRIR
TorchMLIRTorchConversionToMLProgram TorchMLIRTorchConversionToMLProgram
MLIRMemRefTransforms) MLIRMemRefTransforms)
if(TORCH_MLIR_ENABLE_MHLO) if(TORCH_MLIR_ENABLE_STABLEHLO)
list(APPEND LinkedLibs ChloPasses) list(APPEND LinkedLibs ChloPasses)
endif() endif()
@ -21,7 +21,7 @@ add_mlir_library(TorchMLIRTorchConversionPasses
Passes.cpp Passes.cpp
VerifyLinalgOnTensorsBackendContract.cpp VerifyLinalgOnTensorsBackendContract.cpp
VerifyTosaBackendContract.cpp VerifyTosaBackendContract.cpp
VerifyMhloBackendContract.cpp VerifyStablehloBackendContract.cpp
ADDITIONAL_HEADER_DIRS ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/TorchConversion/Transforms ${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/TorchConversion/Transforms

View File

@ -21,9 +21,8 @@
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" #include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
#ifdef TORCH_MLIR_ENABLE_MHLO #ifdef TORCH_MLIR_ENABLE_STABLEHLO
#include "mhlo/transforms/passes.h" #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#endif #endif
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
@ -53,12 +52,13 @@ void mlir::torch::registerTorchConversionPasses() {
"Pipeline lowering torch backend contract to TOSA backend " "Pipeline lowering torch backend contract to TOSA backend "
"contract.", "contract.",
TorchConversion::createTorchBackendToTosaBackendPipeline); TorchConversion::createTorchBackendToTosaBackendPipeline);
#ifdef TORCH_MLIR_ENABLE_MHLO #ifdef TORCH_MLIR_ENABLE_STABLEHLO
mlir::PassPipelineRegistration<TorchConversion::MhloBackendPipelineOptions>( mlir::PassPipelineRegistration<
"torch-backend-to-mhlo-backend-pipeline", TorchConversion::StablehloBackendPipelineOptions>(
"Pipeline lowering torch backend contract to MHLO backend " "torch-backend-to-stablehlo-backend-pipeline",
"Pipeline lowering torch backend contract to StableHLO backend "
"contract.", "contract.",
TorchConversion::createTorchBackendToMhloBackendPipeline); TorchConversion::createTorchBackendToStablehloBackendPipeline);
#endif #endif
} }
@ -121,11 +121,12 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
pm.addPass(TorchConversion::createVerifyTosaBackendContractPass()); pm.addPass(TorchConversion::createVerifyTosaBackendContractPass());
} }
#ifdef TORCH_MLIR_ENABLE_MHLO #ifdef TORCH_MLIR_ENABLE_STABLEHLO
void TorchConversion::createTorchBackendToMhloBackendPipeline( void TorchConversion::createTorchBackendToStablehloBackendPipeline(
OpPassManager &pm, OpPassManager &pm,
const TorchConversion::MhloBackendPipelineOptions &options) { const TorchConversion::StablehloBackendPipelineOptions &options) {
pm.addNestedPass<func::FuncOp>(createConvertTorchToMhloPass( // Generate Stablehlo ops.
pm.addNestedPass<func::FuncOp>(createConvertTorchToStablehloPass(
options.enableStaticShape, options.enableI32Index)); options.enableStaticShape, options.enableI32Index));
// Clean up any non-canonical code introduced above.. // Clean up any non-canonical code introduced above..
@ -133,21 +134,13 @@ void TorchConversion::createTorchBackendToMhloBackendPipeline(
// The resolution of `dim` ops tends to create identical ops. CSE them. // The resolution of `dim` ops tends to create identical ops. CSE them.
pm.addNestedPass<func::FuncOp>(createCSEPass()); pm.addNestedPass<func::FuncOp>(createCSEPass());
// Convert CHLO ops to MHLO ops
pm.addNestedPass<func::FuncOp>(mhlo::createChloLegalizeToHloPass());
// Clean up any non-canonical code introduced above..
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// The resolution of `dim` ops tends to create identical ops. CSE them.
pm.addNestedPass<func::FuncOp>(createCSEPass());
// Finish the type conversion from `torch` types to the types of the // Finish the type conversion from `torch` types to the types of the
// MHLO backend contract. // StableHLO backend contract.
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass()); pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
pm.addNestedPass<func::FuncOp>( pm.addNestedPass<func::FuncOp>(
TorchConversion::createFinalizingBackendTypeConversionPass()); TorchConversion::createFinalizingBackendTypeConversionPass());
// Verify that we have lowered to the form that MHLO backends
// expect. This fails compilation (signalPassFailure) if the IR is not in the // Verify that we have lowered to Stablehlo and Chlo ops.
// correct form. pm.addPass(TorchConversion::createVerifyStablehloBackendContractPass());
pm.addPass(TorchConversion::createVerifyMhloBackendContractPass());
} }
#endif #endif

View File

@ -6,10 +6,9 @@
// Also available under a BSD-style license. See LICENSE. // Also available under a BSD-style license. See LICENSE.
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#ifdef TORCH_MLIR_ENABLE_MHLO #ifdef TORCH_MLIR_ENABLE_STABLEHLO
#include "PassDetail.h" #include "PassDetail.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Shape/IR/Shape.h"
@ -18,6 +17,7 @@
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
using namespace mlir; using namespace mlir;
@ -25,17 +25,15 @@ using namespace mlir::torch;
using namespace mlir::torch::TorchConversion; using namespace mlir::torch::TorchConversion;
namespace { namespace {
class VerifyMhloBackendContractPass class VerifyStablehloBackendContractPass
: public VerifyMhloBackendContractBase<VerifyMhloBackendContractPass> { : public VerifyStablehloBackendContractBase<
VerifyStablehloBackendContractPass> {
void runOnOperation() override { void runOnOperation() override {
MLIRContext *context = &getContext();
auto module = getOperation();
TypeConverter converter; TypeConverter converter;
converter.addConversion([](Type type) -> Type { converter.addConversion([](Type type) -> Type {
auto elemTy = type; auto elemTy = type;
if (isa<TensorType>(type)) { if (isa<TensorType>(type))
elemTy = type.cast<TensorType>().getElementType(); elemTy = type.cast<TensorType>().getElementType();
}
if (BaseMemRefType::isValidElementType(elemTy)) if (BaseMemRefType::isValidElementType(elemTy))
return type; return type;
return nullptr; return nullptr;
@ -43,6 +41,7 @@ class VerifyMhloBackendContractPass
auto opHasLegalTypes = [&](Operation *op) { return converter.isLegal(op); }; auto opHasLegalTypes = [&](Operation *op) { return converter.isLegal(op); };
MLIRContext *context = &getContext();
ConversionTarget target(*context); ConversionTarget target(*context);
// Structural operations. // Structural operations.
@ -50,26 +49,16 @@ class VerifyMhloBackendContractPass
// Shape operations. // Shape operations.
target.addDynamicallyLegalOp<shape::ShapeOfOp>(opHasLegalTypes); target.addDynamicallyLegalOp<shape::ShapeOfOp>(opHasLegalTypes);
target.addLegalDialect<mhlo::MhloDialect>();
target.addLegalDialect<chlo::ChloDialect>(); target.addLegalDialect<chlo::ChloDialect>();
target.addLegalDialect<stablehlo::StablehloDialect>();
target.addLegalDialect<tensor::TensorDialect>(); target.addLegalDialect<tensor::TensorDialect>();
target.addLegalDialect<arith::ArithDialect>(); target.addLegalDialect<arith::ArithDialect>();
RewritePatternSet patterns(context);
if (failed(applyFullConversion(module, target, std::move(patterns)))) {
// We avoid `module.emitError()` so that mlir-print-op-on-diagnostics
// doesn't unnecessarily spew out the entire module.
emitError(module.getLoc())
<< "Module does not conform to the MHLO backend contract. "
"See dialect conversion legality information above.";
return signalPassFailure();
}
} }
}; };
} // namespace } // namespace
std::unique_ptr<OperationPass<ModuleOp>> std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::TorchConversion::createVerifyMhloBackendContractPass() { mlir::torch::TorchConversion::createVerifyStablehloBackendContractPass() {
return std::make_unique<VerifyMhloBackendContractPass>(); return std::make_unique<VerifyStablehloBackendContractPass>();
} }
#endif // TORCH_MLIR_ENABLE_MHLO #endif // TORCH_MLIR_ENABLE_STABLEHLO

View File

@ -20,6 +20,10 @@
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
#include "torch-mlir/RefBackend/Passes.h" #include "torch-mlir/RefBackend/Passes.h"
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
#include "mhlo/transforms/passes.h"
#endif
void mlir::torch::registerAllDialects(mlir::DialectRegistry &registry) { void mlir::torch::registerAllDialects(mlir::DialectRegistry &registry) {
registry.insert<mlir::func::FuncDialect>(); registry.insert<mlir::func::FuncDialect>();
registry.insert<mlir::torch::Torch::TorchDialect>(); registry.insert<mlir::torch::Torch::TorchDialect>();
@ -34,4 +38,11 @@ void mlir::torch::registerAllPasses() {
mlir::torch::registerConversionPasses(); mlir::torch::registerConversionPasses();
mlir::torch::RefBackend::registerRefBackendPasses(); mlir::torch::RefBackend::registerRefBackendPasses();
mlir::torch::TMTensor::registerPasses(); mlir::torch::TMTensor::registerPasses();
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
mlir::mhlo::registerSymbolicShapeOptimizationPass();
mlir::mhlo::registerStablehloLegalizeToHloPass();
mlir::mhlo::registerChloLegalizeToHloPass();
mlir::mhlo::registerHloLegalizeToLinalgPass();
#endif // TORCH_MLIR_ENABLE_STABLEHLO
} }

View File

@ -44,9 +44,9 @@ class OutputType(Enum):
# as taking the `TORCH` output type and lowering it to TOSA. # as taking the `TORCH` output type and lowering it to TOSA.
TOSA = "tosa" TOSA = "tosa"
# This output type consists of `mhlo` dialect ops. It can be thought of # This output type consists of `stablehlo` dialect ops. It can be thought of
# as taking the `TORCH` output type and lowering it to MHLO. # as taking the `TORCH` output type and lowering it to StableHLO.
MHLO = "mhlo" STABLEHLO = "stablehlo"
# Raw output of the JIT IR importer. This is not expected to be useful # Raw output of the JIT IR importer. This is not expected to be useful
# for end-users, but can be convenient for development or reporting bugs. # for end-users, but can be convenient for development or reporting bugs.
@ -242,7 +242,7 @@ class ExampleArgs:
BACKEND_LEGAL_OPS = { BACKEND_LEGAL_OPS = {
OutputType.TOSA: ['torch.aten.flatten.using_ints', 'torch.aten.native_layer_norm', 'torch.aten.linear'], OutputType.TOSA: ['torch.aten.flatten.using_ints', 'torch.aten.native_layer_norm', 'torch.aten.linear'],
OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints', ], OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints', ],
OutputType.MHLO: [], OutputType.STABLEHLO: [],
} }
@ -290,7 +290,7 @@ def compile(model: torch.nn.Module,
# We only allow `backend_legal_ops` to be specified for the `"torch"` # We only allow `backend_legal_ops` to be specified for the `"torch"`
# output type because the other output types actually invoke their # output type because the other output types actually invoke their
# respective backends (Linalg, TOSA, or MHLO), and those backends have # respective backends (Linalg, TOSA, or STABLEHLO), and those backends have
# very specific requirements about the ops which are legal. # very specific requirements about the ops which are legal.
# See `BACKEND_LEGAL_OPS` for more details. # See `BACKEND_LEGAL_OPS` for more details.
if backend_legal_ops is not None: if backend_legal_ops is not None:
@ -404,14 +404,14 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
print(mb.module) print(mb.module)
return mb.module return mb.module
elif output_type == OutputType.MHLO: elif output_type == OutputType.STABLEHLO:
run_pipeline_with_repro_report( run_pipeline_with_repro_report(
mb.module, mb.module,
"builtin.module(torch-backend-to-mhlo-backend-pipeline)", "builtin.module(torch-backend-to-stablehlo-backend-pipeline)",
"Lowering Torch Backend IR -> MHLO Backend IR") "Lowering Torch Backend IR -> StableHLO Backend IR")
if verbose: if verbose:
print("\n====================") print("\n====================")
print("MHLO Backend IR") print("StableHLO Backend IR")
print(mb.module) print(mb.module)
return mb.module return mb.module
raise Exception(f"Unknown OutputType: {output_type}") raise Exception(f"Unknown OutputType: {output_type}")

View File

@ -7,6 +7,6 @@ from .lazy_tensor_core import LazyTensorCoreTestConfig
from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig
from .native_torch import NativeTorchTestConfig from .native_torch import NativeTorchTestConfig
from .torchscript import TorchScriptTestConfig from .torchscript import TorchScriptTestConfig
from .mhlo_backend import MhloBackendTestConfig from .stablehlo_backend import StablehloBackendTestConfig
from .tosa_backend import TosaBackendTestConfig from .tosa_backend import TosaBackendTestConfig
from .torchdynamo import TorchDynamoTestConfig from .torchdynamo import TorchDynamoTestConfig

View File

@ -8,12 +8,8 @@ from typing import Any
import torch import torch
import torch_mlir import torch_mlir
from torch_mlir_e2e_test.mhlo_backends.abc import MhloBackend from torch_mlir_e2e_test.stablehlo_backends.abc import StablehloBackend
from torch_mlir_e2e_test.framework import ( from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
TestConfig,
Trace,
TraceItem
)
from torch_mlir_e2e_test.utils import convert_annotations_to_placeholders from torch_mlir_e2e_test.utils import convert_annotations_to_placeholders
from .utils import ( from .utils import (
recursively_convert_to_numpy, recursively_convert_to_numpy,
@ -21,20 +17,20 @@ from .utils import (
) )
class MhloBackendTestConfig(TestConfig): class StablehloBackendTestConfig(TestConfig):
"""Base class for TestConfig's that are implemented with linalg-on-tensors. """Base class for TestConfig's that are implemented with linalg-on-tensors.
This class handles all the common lowering that torch-mlir does before This class handles all the common lowering that torch-mlir does before
reaching the linalg-on-tensors abstraction level. reaching the linalg-on-tensors abstraction level.
""" """
def __init__(self, backend: MhloBackend):
def __init__(self, backend: StablehloBackend):
super().__init__() super().__init__()
self.backend = backend self.backend = backend
def compile(self, program: torch.nn.Module) -> Any: def compile(self, program: torch.nn.Module) -> Any:
example_args = convert_annotations_to_placeholders(program.forward) example_args = convert_annotations_to_placeholders(program.forward)
module = torch_mlir.compile( module = torch_mlir.compile(program, example_args, output_type="stablehlo")
program, example_args, output_type="mhlo")
return self.backend.compile(module) return self.backend.compile(module)
@ -46,7 +42,6 @@ class MhloBackendTestConfig(TestConfig):
outputs = getattr(backend_module, item.symbol)(*numpy_inputs) outputs = getattr(backend_module, item.symbol)(*numpy_inputs)
output = recursively_convert_from_numpy(outputs) output = recursively_convert_from_numpy(outputs)
result.append( result.append(
TraceItem(symbol=item.symbol, TraceItem(symbol=item.symbol, inputs=item.inputs, output=output)
inputs=item.inputs, )
output=output))
return result return result

View File

@ -10,29 +10,30 @@ import torch
from torch_mlir.ir import Module from torch_mlir.ir import Module
# A type shared between the result of `MhloBackend.compile` and the # A type shared between the result of `StablehloBackend.compile` and the
# input to `MhloBackend.load`. Each backend will likely have a # input to `StablehloBackend.load`. Each backend will likely have a
# different definition of this type. # different definition of this type.
CompiledArtifact = TypeVar('CompiledArtifact') CompiledArtifact = TypeVar("CompiledArtifact")
# A wrapper around a backend-specific loaded program representation # A wrapper around a backend-specific loaded program representation
# that uniformly translates the `x.method(...)` interface expected of # that uniformly translates the `x.method(...)` interface expected of
# Torch modules into appropriate lower-level operations. # Torch modules into appropriate lower-level operations.
Invoker = TypeVar('Invoker') Invoker = TypeVar("Invoker")
class MhloBackend(abc.ABC): class StablehloBackend(abc.ABC):
"""The interface to an MHLO backend. """The interface to an StableHLO backend.
Backends are recommended to raise meaningful exceptions in case of error, Backends are recommended to raise meaningful exceptions in case of error,
ideally with easy reproduction instructions. ideally with easy reproduction instructions.
""" """
@abc.abstractmethod @abc.abstractmethod
def compile(self, module: Module) -> CompiledArtifact: def compile(self, module: Module) -> CompiledArtifact:
"""Compile the provided MLIR module into a compiled artifact. """Compile the provided MLIR module into a compiled artifact.
The module adheres to the MHLO backend contract The module adheres to the StableHLO backend contract
(see the VerifyMhloBackendContract pass). (see the VerifyStablehloBackendContract pass).
The compiled artifact can be any type, but must be correctly The compiled artifact can be any type, but must be correctly
interpreted by the `load` method. interpreted by the `load` method.

View File

@ -7,28 +7,32 @@ from torch_mlir.ir import *
from torch_mlir.passmanager import * from torch_mlir.passmanager import *
from torch_mlir.compiler_utils import run_pipeline_with_repro_report from torch_mlir.compiler_utils import run_pipeline_with_repro_report
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
RefBackendLinalgOnTensorsBackend,
)
from .abc import MhloBackend from .abc import StablehloBackend
__all__ = [ __all__ = [
"LinalgOnTensorsMhloBackend", "LinalgOnTensorsStablehloBackend",
] ]
class LinalgOnTensorsMhloBackend(MhloBackend):
"""Main entry-point for the linalg-on-tensors based MHLO backend. class LinalgOnTensorsStablehloBackend(StablehloBackend):
"""Main entry-point for the linalg-on-tensors based StableHLO backend.
This currently uses the linalg-on-tensors RefBackend for actual execution. This currently uses the linalg-on-tensors RefBackend for actual execution.
""" """
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.refbackend = RefBackendLinalgOnTensorsBackend() self.refbackend = RefBackendLinalgOnTensorsBackend()
def compile(self, imported_module: Module): def compile(self, imported_module: Module):
"""Compiles an imported module that satisfied the MHLO backend contract. """Compiles an imported module that satisfied the StableHLO backend contract.
Args: Args:
imported_module: The MLIR module consisting of funcs in the MHLO imported_module: The MLIR module consisting of funcs in the StableHLO
dialect. dialect.
Returns: Returns:
An opaque, backend specific compiled artifact object that can be An opaque, backend specific compiled artifact object that can be
@ -36,8 +40,9 @@ class LinalgOnTensorsMhloBackend(MhloBackend):
""" """
run_pipeline_with_repro_report( run_pipeline_with_repro_report(
imported_module, imported_module,
"builtin.module(func.func(symbolic-shape-optimization),func.func(hlo-legalize-to-linalg),func.func(canonicalize))", "builtin.module(func.func(chlo-legalize-to-hlo),stablehlo-legalize-to-hlo,func.func(canonicalize,cse,symbolic-shape-optimization,hlo-legalize-to-linalg,canonicalize))",
"Lowering MLIR-HLO to Linalg-on-Tensors") "Lowering StableHLO to Linalg-on-Tensors",
)
return self.refbackend.compile(imported_module) return self.refbackend.compile(imported_module)
def load(self, module): def load(self, module):

View File

@ -1,7 +1,7 @@
llvm_canonicalize_cmake_booleans( llvm_canonicalize_cmake_booleans(
MLIR_ENABLE_BINDINGS_PYTHON MLIR_ENABLE_BINDINGS_PYTHON
TORCH_MLIR_ENABLE_JIT_IR_IMPORTER TORCH_MLIR_ENABLE_JIT_IR_IMPORTER
TORCH_MLIR_ENABLE_MHLO TORCH_MLIR_ENABLE_STABLEHLO
) )
configure_lit_site_cfg( configure_lit_site_cfg(

View File

@ -1,4 +1,4 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s // RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s
// ----- // -----
@ -7,7 +7,7 @@
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[T1:.*]] = mhlo.copy %[[T0]] : tensor<?x?xf32> // CHECK: %[[T1:.*]] = stablehlo.convert %[[T0]] : tensor<?x?xf32>
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.clone$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.clone$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@ -19,7 +19,7 @@ func.func @torch.aten.clone$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vt
// ----- // -----
// CHECK-LABEL: func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> { // CHECK-LABEL: func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> {
// CHECK: %[[VAL_0:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> // CHECK: %[[VAL_0:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<f32> -> !torch.vtensor<[],f32> // CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<f32> -> !torch.vtensor<[],f32>
// CHECK: return %[[VAL_1]] : !torch.vtensor<[],f32> // CHECK: return %[[VAL_1]] : !torch.vtensor<[],f32>
func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> { func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> {
@ -30,7 +30,7 @@ func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> {
// ----- // -----
// CHECK-LABEL: func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> { // CHECK-LABEL: func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> {
// CHECK: %[[VAL_0:.*]] = mhlo.constant dense<1> : tensor<2xi64> // CHECK: %[[VAL_0:.*]] = stablehlo.constant dense<1> : tensor<2xi64>
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<2xi64> -> !torch.vtensor<[2],si64> // CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<2xi64> -> !torch.vtensor<[2],si64>
// CHECK: return %[[VAL_1]] : !torch.vtensor<[2],si64> // CHECK: return %[[VAL_1]] : !torch.vtensor<[2],si64>
func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> { func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> {
@ -45,8 +45,8 @@ func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> {
// CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[T0:.*]] = torch_c.to_i64 %[[INT1]] // CHECK: %[[T0:.*]] = torch_c.to_i64 %[[INT1]]
// CHECK: %[[T1:.*]] = tensor.from_elements %[[T0]] : tensor<1xi64> // CHECK: %[[T1:.*]] = tensor.from_elements %[[T0]] : tensor<1xi64>
// CHECK: %[[T2:.*]] = mhlo.convert %[[T1]] : tensor<1xi64> // CHECK: %[[T2:.*]] = stablehlo.convert %[[T1]] : tensor<1xi64>
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xi64>) -> tensor<i64> // CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xi64>) -> tensor<i64>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<i64> -> !torch.vtensor<[],si64> // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<i64> -> !torch.vtensor<[],si64>
// CHECK: return %[[T4]] : !torch.vtensor<[],si64> // CHECK: return %[[T4]] : !torch.vtensor<[],si64>
func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[], si64> { func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[], si64> {
@ -75,7 +75,7 @@ func.func @torch.aten.contiguous(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vt
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32> // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
// CHECK: %[[VAL_2:.*]] = "chlo.constant_like"(%[[VAL_1]]) {value = 1.000000e+00 : f32} : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32> // CHECK: %[[VAL_2:.*]] = "chlo.constant_like"(%[[VAL_1]]) {value = 1.000000e+00 : f32} : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: %[[VAL_3:.*]] = mhlo.divide %[[VAL_2]], %[[VAL_1]] : tensor<?x?x?xf32> // CHECK: %[[VAL_3:.*]] = stablehlo.divide %[[VAL_2]], %[[VAL_1]] : tensor<?x?x?xf32>
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32> // CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?,?],f32> // CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?,?],f32>
func.func @torch.aten.reciprocal(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { func.func @torch.aten.reciprocal(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
@ -91,7 +91,7 @@ func.func @torch.aten.reciprocal(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.v
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],f32> -> tensor<4x3xf32> // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],f32> -> tensor<4x3xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 // CHECK: %[[VAL_2:.*]] = torch.constant.int 0
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 // CHECK: %[[VAL_3:.*]] = torch.constant.int 1
// CHECK: %[[VAL_4:.*]] = "mhlo.transpose"(%[[VAL_1]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x3xf32>) -> tensor<3x4xf32> // CHECK: %[[VAL_4:.*]] = stablehlo.transpose %[[VAL_1]], dims = [1, 0] : (tensor<4x3xf32>) -> tensor<3x4xf32>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[3,4],f32> // CHECK: return %[[VAL_5]] : !torch.vtensor<[3,4],f32>
func.func @torch.aten.transpose$basic(%arg0: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,4],f32> { func.func @torch.aten.transpose$basic(%arg0: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,4],f32> {
@ -118,7 +118,7 @@ func.func @torch.aten.transpose$basic(%arg0: !torch.vtensor<[4,3],f32>) -> !torc
// CHECK: %[[VAL_7:.*]] = arith.constant 1 : index // CHECK: %[[VAL_7:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1:.*]], %[[VAL_7]] : tensor<?x?xf32> // CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1:.*]], %[[VAL_7]] : tensor<?x?xf32>
// CHECK: %[[VAL_9:.*]] = tensor.from_elements %[[VAL_4]], %[[VAL_6]], %[[VAL_8]] : tensor<3xindex> // CHECK: %[[VAL_9:.*]] = tensor.from_elements %[[VAL_4]], %[[VAL_6]], %[[VAL_8]] : tensor<3xindex>
// CHECK: %[[VAL_10:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_1]], %[[VAL_9]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<3xindex>) -> tensor<8x4x?xf32> // CHECK: %[[VAL_10:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_1]], %[[VAL_9]], dims = [1, 2] : (tensor<?x?xf32>, tensor<3xindex>) -> tensor<8x4x?xf32>
// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<8x4x?xf32> -> !torch.vtensor<[8,4,?],f32> // CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<8x4x?xf32> -> !torch.vtensor<[8,4,?],f32>
// CHECK: return %[[VAL_11]] : !torch.vtensor<[8,4,?],f32> // CHECK: return %[[VAL_11]] : !torch.vtensor<[8,4,?],f32>
func.func @torch.aten.broadcast_to$dynamic_implicit(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[8,4,?],f32> { func.func @torch.aten.broadcast_to$dynamic_implicit(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[8,4,?],f32> {
@ -135,15 +135,15 @@ func.func @torch.aten.broadcast_to$dynamic_implicit(%arg0: !torch.vtensor<[?,?],
// CHECK-LABEL: func.func @torch.aten.batch_norm$training( // CHECK-LABEL: func.func @torch.aten.batch_norm$training(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> { // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor<?x3x?x?xf32> // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor<?x3x?x?xf32>
// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32> // CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<3xf32>
// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32> // CHECK: %[[VAL_3:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<3xf32>
// CHECK: %true = torch.constant.bool true // CHECK: %true = torch.constant.bool true
// CHECK: %float1.000000e-01 = torch.constant.float 1.000000e-01 // CHECK: %float1.000000e-01 = torch.constant.float 1.000000e-01
// CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05 // CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05
// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index // CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x3x?x?xf32> // CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x3x?x?xf32>
// CHECK: %[[VAL_6:.*]] = tensor.from_elements %[[VAL_5]] : tensor<1xindex> // CHECK: %[[VAL_6:.*]] = tensor.from_elements %[[VAL_5]] : tensor<1xindex>
// CHECK: %[[VAL_7:.*]], %[[VAL_8:.*]], %[[VAL_9:.*]] = "mhlo.batch_norm_training"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) // CHECK: %[[VAL_7:.*]], %[[VAL_8:.*]], %[[VAL_9:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>)
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32> // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,3,?,?],f32> // CHECK: return %[[VAL_8]] : !torch.vtensor<[?,3,?,?],f32>
func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> { func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
@ -161,8 +161,8 @@ func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>)
// CHECK-LABEL: func.func @torch.aten.batch_norm$inference( // CHECK-LABEL: func.func @torch.aten.batch_norm$inference(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor<?x3x?x?xf32> // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor<?x3x?x?xf32>
// CHECK: %[[T1:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32> // CHECK: %[[T1:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<3xf32>
// CHECK: %[[T2:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32> // CHECK: %[[T2:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<3xf32>
// CHECK: %[[TRUE:.*]] = torch.constant.bool true // CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[FALSE:.*]] = torch.constant.bool false // CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[FLOAT1:.*]].000000e-01 = torch.constant.float 1.000000e-01 // CHECK: %[[FLOAT1:.*]].000000e-01 = torch.constant.float 1.000000e-01
@ -171,7 +171,7 @@ func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>)
// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x3x?x?xf32> // CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x3x?x?xf32>
// CHECK: %[[T4:.*]] = tensor.from_elements %[[T3]] : tensor<1xindex> // CHECK: %[[T4:.*]] = tensor.from_elements %[[T3]] : tensor<1xindex>
// CHECK: %[[T5:.*]] = tensor.cast %[[T0]] : tensor<?x3x?x?xf32> to tensor<?x3x?x?xf32> // CHECK: %[[T5:.*]] = tensor.cast %[[T0]] : tensor<?x3x?x?xf32> to tensor<?x3x?x?xf32>
// CHECK: %[[T6:.*]] = "mhlo.batch_norm_inference"(%[[T5]], %[[T2]], %[[T1]], %[[T1]], %[[T2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<?x3x?x?xf32> // CHECK: %[[T6:.*]] = "stablehlo.batch_norm_inference"(%[[T5]], %[[T2]], %[[T1]], %[[T1]], %[[T2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<?x3x?x?xf32>
// CHECK: %[[T7:.*]] = tensor.cast %[[T6]] : tensor<?x3x?x?xf32> to tensor<?x3x?x?xf32> // CHECK: %[[T7:.*]] = tensor.cast %[[T6]] : tensor<?x3x?x?xf32> to tensor<?x3x?x?xf32>
// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32> // CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
// CHECK: return %[[T8]] : !torch.vtensor<[?,3,?,?],f32> // CHECK: return %[[T8]] : !torch.vtensor<[?,3,?,?],f32>
@ -192,19 +192,19 @@ func.func @torch.aten.batch_norm$inference(%arg0: !torch.vtensor<[?,3,?,?],f32>)
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> { // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor<?x3x?x?xf32> // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor<?x3x?x?xf32>
// CHECK: %none = torch.constant.none // CHECK: %none = torch.constant.none
// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32> // CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<3xf32>
// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32> // CHECK: %[[VAL_3:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<3xf32>
// CHECK: %true = torch.constant.bool true // CHECK: %true = torch.constant.bool true
// CHECK: %float1.000000e-01 = torch.constant.float 1.000000e-01 // CHECK: %float1.000000e-01 = torch.constant.float 1.000000e-01
// CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05 // CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05
// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index // CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x3x?x?xf32> // CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x3x?x?xf32>
// CHECK: %[[VAL_6:.*]] = tensor.from_elements %[[VAL_5]] : tensor<1xindex> // CHECK: %[[VAL_6:.*]] = tensor.from_elements %[[VAL_5]] : tensor<1xindex>
// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32> // CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f32>
// CHECK: %[[VAL_8:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_7]], %[[VAL_6]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<3xf32> // CHECK: %[[VAL_8:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_6]], dims = [] : (tensor<f32>, tensor<1xindex>) -> tensor<3xf32>
// CHECK: %[[VAL_9:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> // CHECK: %[[VAL_9:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VAL_10:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_9]], %[[VAL_6]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<3xf32> // CHECK: %[[VAL_10:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_9]], %[[VAL_6]], dims = [] : (tensor<f32>, tensor<1xindex>) -> tensor<3xf32>
// CHECK: %[[VAL_11:.*]], %[[VAL_12:.*]], %[[VAL_13:.*]] = "mhlo.batch_norm_training"(%[[VAL_1]], %[[VAL_8]], %[[VAL_10]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) // CHECK: %[[VAL_11:.*]], %[[VAL_12:.*]], %[[VAL_13:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_8]], %[[VAL_10]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>)
// CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32> // CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
// CHECK: return %[[VAL_14]] : !torch.vtensor<[?,3,?,?],f32> // CHECK: return %[[VAL_14]] : !torch.vtensor<[?,3,?,?],f32>
func.func @torch.aten.batch_norm$no_bias_weight(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> { func.func @torch.aten.batch_norm$no_bias_weight(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
@ -222,28 +222,28 @@ func.func @torch.aten.batch_norm$no_bias_weight(%arg0: !torch.vtensor<[?,3,?,?],
// CHECK-LABEL: func @torch.aten.native_layer_norm( // CHECK-LABEL: func @torch.aten.native_layer_norm(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,7,4,5],f32>) -> !torch.vtensor<[3,7,4,5],f32> { // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,7,4,5],f32>) -> !torch.vtensor<[3,7,4,5],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,7,4,5],f32> -> tensor<3x7x4x5xf32> // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,7,4,5],f32> -> tensor<3x7x4x5xf32>
// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<4x5xf32> // CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<4x5xf32>
// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<4x5xf32> // CHECK: %[[VAL_3:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<4x5xf32>
// CHECK: %int4 = torch.constant.int 4 // CHECK: %int4 = torch.constant.int 4
// CHECK: %int5 = torch.constant.int 5 // CHECK: %int5 = torch.constant.int 5
// CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05 // CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05
// CHECK: %true = torch.constant.bool true // CHECK: %true = torch.constant.bool true
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<[1, 21, 20]> : tensor<3xi64> // CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<[1, 21, 20]> : tensor<3xi64>
// CHECK: %[[VAL_6:.*]] = mhlo.dynamic_reshape %[[VAL_1]], %[[VAL_5]] : (tensor<3x7x4x5xf32>, tensor<3xi64>) -> tensor<1x21x20xf32> // CHECK: %[[VAL_6:.*]] = stablehlo.dynamic_reshape %[[VAL_1]], %[[VAL_5]] : (tensor<3x7x4x5xf32>, tensor<3xi64>) -> tensor<1x21x20xf32>
// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor<21xf32> // CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<21xf32>
// CHECK: %[[VAL_8:.*]] = mhlo.constant dense<0.000000e+00> : tensor<21xf32> // CHECK: %[[VAL_8:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<21xf32>
// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]] = "mhlo.batch_norm_training"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) -> (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) // CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]] = "stablehlo.batch_norm_training"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) -> (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>)
// CHECK: %[[VAL_12:.*]] = mhlo.constant dense<[3, 7, 4, 5]> : tensor<4xi64> // CHECK: %[[VAL_12:.*]] = stablehlo.constant dense<[3, 7, 4, 5]> : tensor<4xi64>
// CHECK: %[[VAL_13:.*]] = mhlo.dynamic_reshape %[[VAL_9]], %[[VAL_12]] : (tensor<1x21x20xf32>, tensor<4xi64>) -> tensor<3x7x4x5xf32> // CHECK: %[[VAL_13:.*]] = stablehlo.dynamic_reshape %[[VAL_9]], %[[VAL_12]] : (tensor<1x21x20xf32>, tensor<4xi64>) -> tensor<3x7x4x5xf32>
// CHECK: %[[VAL_14:.*]] = mhlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64> // CHECK: %[[VAL_14:.*]] = stablehlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64>
// CHECK: %[[VAL_15:.*]] = mhlo.dynamic_reshape %[[VAL_10]], %[[VAL_14]] : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32> // CHECK: %[[VAL_15:.*]] = stablehlo.dynamic_reshape %[[VAL_10]], %[[VAL_14]] : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32>
// CHECK: %[[VAL_16:.*]] = mhlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64> // CHECK: %[[VAL_16:.*]] = stablehlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64>
// CHECK: %[[VAL_17:.*]] = mhlo.dynamic_reshape %[[VAL_11]], %[[VAL_16]] : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32> // CHECK: %[[VAL_17:.*]] = stablehlo.dynamic_reshape %[[VAL_11]], %[[VAL_16]] : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32>
// CHECK: %[[VAL_18:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_3]]) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32> // CHECK: %[[VAL_18:.*]] = stablehlo.broadcast_in_dim %[[VAL_3]], dims = [2, 3] : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32>
// CHECK: %[[VAL_19:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_2]]) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32> // CHECK: %[[VAL_19:.*]] = stablehlo.broadcast_in_dim %[[VAL_2]], dims = [2, 3] : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32>
// CHECK: %[[VAL_20:.*]] = mhlo.multiply %[[VAL_13]], %[[VAL_18]] : tensor<3x7x4x5xf32> // CHECK: %[[VAL_20:.*]] = stablehlo.multiply %[[VAL_13]], %[[VAL_18]] : tensor<3x7x4x5xf32>
// CHECK: %[[VAL_21:.*]] = mhlo.add %[[VAL_20]], %[[VAL_19]] : tensor<3x7x4x5xf32> // CHECK: %[[VAL_21:.*]] = stablehlo.add %[[VAL_20]], %[[VAL_19]] : tensor<3x7x4x5xf32>
// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21:.*]] : tensor<3x7x4x5xf32> -> !torch.vtensor<[3,7,4,5],f32> // CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21:.*]] : tensor<3x7x4x5xf32> -> !torch.vtensor<[3,7,4,5],f32>
// CHECK: return %[[VAL_22]] : !torch.vtensor<[3,7,4,5],f32> // CHECK: return %[[VAL_22]] : !torch.vtensor<[3,7,4,5],f32>
func.func @torch.aten.native_layer_norm(%arg0: !torch.vtensor<[3,7,4,5],f32>) -> !torch.vtensor<[3,7,4,5],f32> { func.func @torch.aten.native_layer_norm(%arg0: !torch.vtensor<[3,7,4,5],f32>) -> !torch.vtensor<[3,7,4,5],f32> {
@ -267,8 +267,8 @@ func.func @torch.aten.native_layer_norm(%arg0: !torch.vtensor<[3,7,4,5],f32>) ->
// CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[ARG0]], %[[ARG1]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],si32>) -> !torch.list<vtensor> // CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[ARG0]], %[[ARG1]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],si32>) -> !torch.list<vtensor>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T2:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32> // CHECK: %[[T2:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : (tensor<?x?xi32>) -> tensor<?x?xf32> // CHECK: %[[T3:.*]] = stablehlo.convert %[[T2]] : (tensor<?x?xi32>) -> tensor<?x?xf32>
// CHECK: %[[T4:.*]] = "mhlo.concatenate"(%[[T1]], %[[T3]]) {dimension = 0 : i64} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: %[[T4:.*]] = stablehlo.concatenate %[[T1]], %[[T3]], dim = 0 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.cat$convert(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.cat$convert(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],f32> {
@ -287,7 +287,7 @@ func.func @torch.aten.cat$convert(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torc
// CHECK: %[[VAL_0:.*]] = torch.prim.ListConstruct %[[ARG_0]], %[[ARG_1]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list<vtensor> // CHECK: %[[VAL_0:.*]] = torch.prim.ListConstruct %[[ARG_0]], %[[ARG_1]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list<vtensor>
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = "mhlo.concatenate"(%[[VAL_1]], %[[VAL_2]]) {dimension = 0 : i64} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: %[[VAL_3:.*]] = stablehlo.concatenate %[[VAL_1]], %[[VAL_2]], dim = 0 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.cat(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.cat(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {

View File

@ -1,4 +1,4 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s // RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s
// CHECK-LABEL: func.func @torch.aten.gelu( // CHECK-LABEL: func.func @torch.aten.gelu(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@ -7,12 +7,12 @@
// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) {value = 1.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) {value = 1.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T2:.*]] = "chlo.constant_like"(%[[T0]]) {value = 2.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: %[[T2:.*]] = "chlo.constant_like"(%[[T0]]) {value = 2.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T3:.*]] = "chlo.constant_like"(%[[T0]]) {value = 5.000000e-01 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: %[[T3:.*]] = "chlo.constant_like"(%[[T0]]) {value = 5.000000e-01 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T4:.*]] = mhlo.rsqrt %[[T2]] : tensor<?x?xf32> // CHECK: %[[T4:.*]] = stablehlo.rsqrt %[[T2]] : tensor<?x?xf32>
// CHECK: %[[T5:.*]] = mhlo.multiply %[[T0]], %[[T4]] : tensor<?x?xf32> // CHECK: %[[T5:.*]] = stablehlo.multiply %[[T0]], %[[T4]] : tensor<?x?xf32>
// CHECK: %[[T6:.*]] = chlo.erf %[[T5]] : tensor<?x?xf32> -> tensor<?x?xf32> // CHECK: %[[T6:.*]] = chlo.erf %[[T5]] : tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[T7:.*]] = mhlo.add %[[T6]], %[[T1]] : tensor<?x?xf32> // CHECK: %[[T7:.*]] = stablehlo.add %[[T6]], %[[T1]] : tensor<?x?xf32>
// CHECK: %[[T8:.*]] = mhlo.multiply %[[T7]], %[[T3]] : tensor<?x?xf32> // CHECK: %[[T8:.*]] = stablehlo.multiply %[[T7]], %[[T3]] : tensor<?x?xf32>
// CHECK: %[[T9:.*]] = mhlo.multiply %[[T0]], %[[T8]] : tensor<?x?xf32> // CHECK: %[[T9:.*]] = stablehlo.multiply %[[T0]], %[[T8]] : tensor<?x?xf32>
// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T10]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[T10]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.gelu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.gelu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@ -26,7 +26,7 @@ func.func @torch.aten.gelu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
// CHECK-LABEL: func.func @torch.aten.tanh$basic( // CHECK-LABEL: func.func @torch.aten.tanh$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = mhlo.tanh %[[T0]] : tensor<?x?xf32> // CHECK: %[[T1:.*]] = stablehlo.tanh %[[T0]] : tensor<?x?xf32>
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@ -39,7 +39,7 @@ func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte
// CHECK-LABEL: func.func @torch.aten.log$basic( // CHECK-LABEL: func.func @torch.aten.log$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = mhlo.log %[[T0]] : tensor<?x?xf32> // CHECK: %[[T1:.*]] = stablehlo.log %[[T0]] : tensor<?x?xf32>
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.log$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.log$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@ -52,7 +52,7 @@ func.func @torch.aten.log$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten
// CHECK-LABEL: func.func @torch.aten.exp$basic( // CHECK-LABEL: func.func @torch.aten.exp$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = mhlo.exponential %[[T0]] : tensor<?x?xf32> // CHECK: %[[T1:.*]] = stablehlo.exponential %[[T0]] : tensor<?x?xf32>
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.exp$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.exp$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@ -65,7 +65,7 @@ func.func @torch.aten.exp$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten
// CHECK-LABEL: func.func @torch.aten.neg$basic( // CHECK-LABEL: func.func @torch.aten.neg$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = mhlo.negate %[[T0]] : tensor<?x?xf32> // CHECK: %[[T1:.*]] = stablehlo.negate %[[T0]] : tensor<?x?xf32>
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.neg$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.neg$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@ -78,7 +78,7 @@ func.func @torch.aten.neg$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten
// CHECK-LABEL: func.func @torch.aten.rsqrt$basic( // CHECK-LABEL: func.func @torch.aten.rsqrt$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = mhlo.rsqrt %[[T0]] : tensor<?x?xf32> // CHECK: %[[T1:.*]] = stablehlo.rsqrt %[[T0]] : tensor<?x?xf32>
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@ -91,7 +91,7 @@ func.func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vt
// CHECK-LABEL: func.func @torch.aten.sigmoid$basic( // CHECK-LABEL: func.func @torch.aten.sigmoid$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = mhlo.logistic %[[T0]] : tensor<?x?xf32> // CHECK: %[[T1:.*]] = stablehlo.logistic %[[T0]] : tensor<?x?xf32>
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.sigmoid$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.sigmoid$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@ -108,8 +108,8 @@ func.func @torch.aten.sigmoid$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]] // CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
// CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32> // CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T4:.*]] = chlo.broadcast_add %[[T0]], %[[T3]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32> // CHECK: %[[T4:.*]] = chlo.broadcast_add %[[T0]], %[[T3]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
@ -130,11 +130,11 @@ func.func @torch.aten.addscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc
// CHECK: %[[INT2:.*]] = torch.constant.int 2 // CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]] // CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
// CHECK: %[[T3:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> // CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32> // CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[FROM_ELEMENTS_0:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64> // CHECK: %[[FROM_ELEMENTS_0:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64>
// CHECK: %[[T5:.*]] = mhlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xi64>) -> tensor<1xf32> // CHECK: %[[T5:.*]] = stablehlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T6:.*]] = mhlo.reshape %[[T5]] : (tensor<1xf32>) -> tensor<f32> // CHECK: %[[T6:.*]] = stablehlo.reshape %[[T5]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T7:.*]] = chlo.broadcast_multiply %[[T4]], %[[T6]] : (tensor<f32>, tensor<f32>) -> tensor<f32> // CHECK: %[[T7:.*]] = chlo.broadcast_multiply %[[T4]], %[[T6]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: %[[T8:.*]] = chlo.broadcast_add %[[T0]], %[[T7]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32> // CHECK: %[[T8:.*]] = chlo.broadcast_add %[[T0]], %[[T7]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
@ -171,8 +171,8 @@ func.func @torch.aten.addtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
// CHECK: %[[INT2:.*]] = torch.constant.int 2 // CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]] // CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64> // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64>
// CHECK: %[[T3:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> // CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32> // CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T5:.*]] = chlo.broadcast_multiply %[[T1]], %[[T4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32> // CHECK: %[[T5:.*]] = chlo.broadcast_multiply %[[T1]], %[[T4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T6:.*]] = chlo.broadcast_add %[[T0]], %[[T5]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: %[[T6:.*]] = chlo.broadcast_add %[[T0]], %[[T5]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
@ -190,7 +190,7 @@ func.func @torch.aten.addtensor$alpha(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32> // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64> // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
// CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[T2:.*]] = mhlo.convert %[[T0]] : (tensor<?x?xi32>) -> tensor<?x?xi64> // CHECK: %[[T2:.*]] = stablehlo.convert %[[T0]] : (tensor<?x?xi32>) -> tensor<?x?xi64>
// CHECK: %[[T3:.*]] = chlo.broadcast_add %[[T2]], %[[T1]] : (tensor<?x?xi64>, tensor<?x?xi64>) -> tensor<?x?xi64> // CHECK: %[[T3:.*]] = chlo.broadcast_add %[[T2]], %[[T1]] : (tensor<?x?xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64> // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
// CHECK: return %[[T4]] : !torch.vtensor<[?,?],si64> // CHECK: return %[[T4]] : !torch.vtensor<[?,?],si64>
@ -209,8 +209,8 @@ func.func @torch.aten.addtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]] // CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
// CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32> // CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T4:.*]] = chlo.broadcast_subtract %[[T0]], %[[T3]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32> // CHECK: %[[T4:.*]] = chlo.broadcast_subtract %[[T0]], %[[T3]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
@ -230,8 +230,8 @@ func.func @torch.aten.subscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]] // CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
// CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32> // CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T4:.*]] = chlo.broadcast_subtract %[[T3]], %[[T0]] : (tensor<f32>, tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: %[[T4:.*]] = chlo.broadcast_subtract %[[T3]], %[[T0]] : (tensor<f32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
@ -252,11 +252,11 @@ func.func @torch.aten.rsubscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !tor
// CHECK: %[[INT2:.*]] = torch.constant.int 2 // CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]] // CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
// CHECK: %[[T3:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> // CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32> // CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[FROM_ELEMENTS_0:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64> // CHECK: %[[FROM_ELEMENTS_0:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64>
// CHECK: %[[T5:.*]] = mhlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xi64>) -> tensor<1xf32> // CHECK: %[[T5:.*]] = stablehlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T6:.*]] = mhlo.reshape %[[T5]] : (tensor<1xf32>) -> tensor<f32> // CHECK: %[[T6:.*]] = stablehlo.reshape %[[T5]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T7:.*]] = chlo.broadcast_multiply %[[T4]], %[[T6]] : (tensor<f32>, tensor<f32>) -> tensor<f32> // CHECK: %[[T7:.*]] = chlo.broadcast_multiply %[[T4]], %[[T6]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: %[[T8:.*]] = chlo.broadcast_subtract %[[T0]], %[[T7]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32> // CHECK: %[[T8:.*]] = chlo.broadcast_subtract %[[T0]], %[[T7]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
@ -293,8 +293,8 @@ func.func @torch.aten.subtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
// CHECK: %[[INT2:.*]] = torch.constant.int 2 // CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]] // CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64> // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64>
// CHECK: %[[T3:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> // CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32> // CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T5:.*]] = chlo.broadcast_multiply %[[T1]], %[[T4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32> // CHECK: %[[T5:.*]] = chlo.broadcast_multiply %[[T1]], %[[T4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T6:.*]] = chlo.broadcast_subtract %[[T0]], %[[T5]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: %[[T6:.*]] = chlo.broadcast_subtract %[[T0]], %[[T5]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
@ -312,7 +312,7 @@ func.func @torch.aten.subtensor$alpha(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32> // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64> // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
// CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[T2:.*]] = mhlo.convert %[[T0]] : (tensor<?x?xi32>) -> tensor<?x?xi64> // CHECK: %[[T2:.*]] = stablehlo.convert %[[T0]] : (tensor<?x?xi32>) -> tensor<?x?xi64>
// CHECK: %[[T3:.*]] = chlo.broadcast_subtract %[[T2]], %[[T1]] : (tensor<?x?xi64>, tensor<?x?xi64>) -> tensor<?x?xi64> // CHECK: %[[T3:.*]] = chlo.broadcast_subtract %[[T2]], %[[T1]] : (tensor<?x?xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64> // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
// CHECK: return %[[T4]] : !torch.vtensor<[?,?],si64> // CHECK: return %[[T4]] : !torch.vtensor<[?,?],si64>
@ -330,8 +330,8 @@ func.func @torch.aten.subtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1
// CHECK: %[[INT9:.*]] = torch.constant.int 9 // CHECK: %[[INT9:.*]] = torch.constant.int 9
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]] // CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32> // CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T4:.*]] = chlo.broadcast_multiply %[[T0]], %[[T3]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32> // CHECK: %[[T4:.*]] = chlo.broadcast_multiply %[[T0]], %[[T3]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
@ -363,8 +363,8 @@ func.func @torch.aten.multensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
// CHECK: %[[INT9:.*]] = torch.constant.int 9 // CHECK: %[[INT9:.*]] = torch.constant.int 9
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]] // CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32> // CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T4:.*]] = chlo.broadcast_divide %[[T0]], %[[T3]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32> // CHECK: %[[T4:.*]] = chlo.broadcast_divide %[[T0]], %[[T3]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
@ -396,8 +396,8 @@ func.func @torch.aten.divtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
// CHECK: %[[INT3:.*]] = torch.constant.int 3 // CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT3]] // CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT3]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32> // CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T4:.*]] = chlo.broadcast_compare %[[T0]], %[[T3]] {compare_type = #chlo<comparison_type FLOAT>, comparison_direction = #chlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1> // CHECK: %[[T4:.*]] = chlo.broadcast_compare %[[T0]], %[[T3]] {compare_type = #chlo<comparison_type FLOAT>, comparison_direction = #chlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1> // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],i1> // CHECK: return %[[T5]] : !torch.vtensor<[?,?],i1>
@ -471,7 +471,7 @@ func.func @torch.aten.ne.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.
// CHECK: %[[INT0:.*]] = torch.constant.int 0 // CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT0]] : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T2:.*]] = "mhlo.transpose"(%[[T0]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x64xf32>) -> tensor<64x4xf32> // CHECK: %[[T2:.*]] = stablehlo.transpose %[[T0]], dims = [1, 0] : (tensor<4x64xf32>) -> tensor<64x4xf32>
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<64x4xf32> -> !torch.vtensor<[64,4],f32> // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<64x4xf32> -> !torch.vtensor<[64,4],f32>
// CHECK: return %[[T3]] : !torch.vtensor<[64,4],f32> // CHECK: return %[[T3]] : !torch.vtensor<[64,4],f32>
func.func @torch.aten.permute$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[64,4],f32> { func.func @torch.aten.permute$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[64,4],f32> {
@ -488,7 +488,7 @@ func.func @torch.aten.permute$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) {value = 0.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) {value = 0.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T2:.*]] = mhlo.maximum %[[T0]], %[[T1]] : tensor<?x?xf32> // CHECK: %[[T2:.*]] = stablehlo.maximum %[[T0]], %[[T1]] : tensor<?x?xf32>
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T3]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[T3]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.relu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.relu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@ -503,11 +503,11 @@ func.func @torch.aten.relu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = torch_c.to_f64 %[[ARG1]] // CHECK: %[[T1:.*]] = torch_c.to_f64 %[[ARG1]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xf64> // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xf64>
// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xf64>) -> tensor<1xf32> // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xf64>) -> tensor<1xf32>
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32> // CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[FROM_ELEMENTS_0:.*]] = tensor.from_elements %[[T1]] : tensor<1xf64> // CHECK: %[[FROM_ELEMENTS_0:.*]] = tensor.from_elements %[[T1]] : tensor<1xf64>
// CHECK: %[[T4:.*]] = mhlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xf64>) -> tensor<1xf32> // CHECK: %[[T4:.*]] = stablehlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xf64>) -> tensor<1xf32>
// CHECK: %[[T5:.*]] = mhlo.reshape %[[T4]] : (tensor<1xf32>) -> tensor<f32> // CHECK: %[[T5:.*]] = stablehlo.reshape %[[T4]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T6:.*]] = chlo.broadcast_multiply %[[T3]], %[[T5]] : (tensor<f32>, tensor<f32>) -> tensor<f32> // CHECK: %[[T6:.*]] = chlo.broadcast_multiply %[[T3]], %[[T5]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: %[[T7:.*]] = chlo.broadcast_add %[[T0]], %[[T6]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32> // CHECK: %[[T7:.*]] = chlo.broadcast_add %[[T0]], %[[T6]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
@ -525,8 +525,8 @@ func.func @torch.aten.addscalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T2:.*]] = torch_c.to_f64 %[[ARG2]] // CHECK: %[[T2:.*]] = torch_c.to_f64 %[[ARG2]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]] : tensor<1xf64> // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]] : tensor<1xf64>
// CHECK: %[[T3:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xf64>) -> tensor<1xf32> // CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xf64>) -> tensor<1xf32>
// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32> // CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T5:.*]] = chlo.broadcast_multiply %[[T1]], %[[T4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32> // CHECK: %[[T5:.*]] = chlo.broadcast_multiply %[[T1]], %[[T4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T6:.*]] = chlo.broadcast_add %[[T0]], %[[T5]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: %[[T6:.*]] = chlo.broadcast_add %[[T0]], %[[T5]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
@ -543,8 +543,8 @@ func.func @torch.aten.addtensor$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[ARG1]] // CHECK: %[[T1:.*]] = torch_c.to_i64 %[[ARG1]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32> // CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T4:.*]] = chlo.broadcast_multiply %[[T0]], %[[T3]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32> // CHECK: %[[T4:.*]] = chlo.broadcast_multiply %[[T0]], %[[T3]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
@ -560,8 +560,8 @@ func.func @torch.aten.mulscalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[ARG1]] // CHECK: %[[T1:.*]] = torch_c.to_i64 %[[ARG1]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32> // CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T4:.*]] = chlo.broadcast_divide %[[T0]], %[[T3]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32> // CHECK: %[[T4:.*]] = chlo.broadcast_divide %[[T0]], %[[T3]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
@ -577,8 +577,8 @@ func.func @torch.aten.divscalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[ARG1]] // CHECK: %[[T1:.*]] = torch_c.to_i64 %[[ARG1]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32> // CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T4:.*]] = chlo.broadcast_compare %[[T0]], %[[T3]] {compare_type = #chlo<comparison_type FLOAT>, comparison_direction = #chlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1> // CHECK: %[[T4:.*]] = chlo.broadcast_compare %[[T0]], %[[T3]] {compare_type = #chlo<comparison_type FLOAT>, comparison_direction = #chlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1> // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],i1> // CHECK: return %[[T5]] : !torch.vtensor<[?,?],i1>
@ -595,10 +595,10 @@ func.func @torch.aten.gt.scalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32> // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[STR:.*]] = torch.constant.str "trunc" // CHECK: %[[STR:.*]] = torch.constant.str "trunc"
// CHECK: %[[T2:.*]] = chlo.broadcast_divide %[[T0]], %[[T1]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> // CHECK: %[[T2:.*]] = chlo.broadcast_divide %[[T0]], %[[T1]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[T3:.*]] = mhlo.sign %[[T2]] : tensor<?x?x?x?xf32> // CHECK: %[[T3:.*]] = stablehlo.sign %[[T2]] : tensor<?x?x?x?xf32>
// CHECK: %[[T4:.*]] = mhlo.abs %[[T2]] : tensor<?x?x?x?xf32> // CHECK: %[[T4:.*]] = stablehlo.abs %[[T2]] : tensor<?x?x?x?xf32>
// CHECK: %[[T5:.*]] = mhlo.floor %[[T4]] : tensor<?x?x?x?xf32> // CHECK: %[[T5:.*]] = stablehlo.floor %[[T4]] : tensor<?x?x?x?xf32>
// CHECK: %[[T6:.*]] = mhlo.multiply %[[T3]], %[[T5]] : tensor<?x?x?x?xf32> // CHECK: %[[T6:.*]] = stablehlo.multiply %[[T3]], %[[T5]] : tensor<?x?x?x?xf32>
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32> // CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[T7]] : !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[T7]] : !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten.div.Tensor_mode$trunc(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { func.func @torch.aten.div.Tensor_mode$trunc(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
@ -615,7 +615,7 @@ func.func @torch.aten.div.Tensor_mode$trunc(%arg0: !torch.vtensor<[?,?,?,?],f32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32> // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[STR:.*]] = torch.constant.str "floor" // CHECK: %[[STR:.*]] = torch.constant.str "floor"
// CHECK: %[[T2:.*]] = chlo.broadcast_divide %[[T0]], %[[T1]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> // CHECK: %[[T2:.*]] = chlo.broadcast_divide %[[T0]], %[[T1]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[T3:.*]] = mhlo.floor %[[T2]] : tensor<?x?x?x?xf32> // CHECK: %[[T3:.*]] = stablehlo.floor %[[T2]] : tensor<?x?x?x?xf32>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32> // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[T4]] : !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[T4]] : !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten.div.Tensor_mode$floor(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { func.func @torch.aten.div.Tensor_mode$floor(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {

View File

@ -1,4 +1,4 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s // RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s
// CHECK-LABEL: func.func @torch.aten.index_select$basic( // CHECK-LABEL: func.func @torch.aten.index_select$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,4],f32>, %[[ARG1:.*]]: !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,4],f32> { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,4],f32>, %[[ARG1:.*]]: !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,4],f32> {
@ -10,8 +10,8 @@
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x4xf32> // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x4xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 // CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64> // CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
// CHECK: %[[T5:.*]] = "mhlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #mhlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false} : (tensor<?x4xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<2x4xf32> // CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false} : (tensor<?x4xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<2x4xf32>
// CHECK: %[[T6:.*]] = mhlo.convert %[[T5]] : tensor<2x4xf32> // CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor<2x4xf32>
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32> // CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32>
// CHECK: return %[[T7]] : !torch.vtensor<[2,4],f32> // CHECK: return %[[T7]] : !torch.vtensor<[2,4],f32>
func.func @torch.aten.index_select$basic(%arg0: !torch.vtensor<[?,4],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,4],f32> { func.func @torch.aten.index_select$basic(%arg0: !torch.vtensor<[?,4],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,4],f32> {
@ -31,8 +31,8 @@ func.func @torch.aten.index_select$basic(%arg0: !torch.vtensor<[?,4],f32>, %arg1
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf32> // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 // CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64> // CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
// CHECK: %[[T5:.*]] = "mhlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #mhlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false} : (tensor<?x?xf32>, tensor<?xi64>, tensor<2xi64>) -> tensor<?x?xf32> // CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false} : (tensor<?x?xf32>, tensor<?xi64>, tensor<2xi64>) -> tensor<?x?xf32>
// CHECK: %[[T6:.*]] = mhlo.convert %[[T5]] : tensor<?x?xf32> // CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor<?x?xf32>
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T7]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[T7]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.embedding$basic(%weight: !torch.vtensor<[?,?],f32>, %indices: !torch.vtensor<[?], si64>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.embedding$basic(%weight: !torch.vtensor<[?,?],f32>, %indices: !torch.vtensor<[?], si64>) -> !torch.vtensor<[?,?],f32> {
@ -53,8 +53,8 @@ func.func @torch.aten.embedding$basic(%weight: !torch.vtensor<[?,?],f32>, %indic
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf32> // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 // CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64> // CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
// CHECK: %[[T5:.*]] = "mhlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #mhlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false} : (tensor<?x?xf32>, tensor<?x1xi64>, tensor<2xi64>) -> tensor<?x1x?xf32> // CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false} : (tensor<?x?xf32>, tensor<?x1xi64>, tensor<2xi64>) -> tensor<?x1x?xf32>
// CHECK: %[[T6:.*]] = mhlo.convert %[[T5]] : tensor<?x1x?xf32> // CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor<?x1x?xf32>
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x1x?xf32> -> !torch.vtensor<[?,1,?],f32> // CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x1x?xf32> -> !torch.vtensor<[?,1,?],f32>
// CHECK: return %[[T7]] : !torch.vtensor<[?,1,?],f32> // CHECK: return %[[T7]] : !torch.vtensor<[?,1,?],f32>
func.func @torch.aten.embedding$rank_two_indices(%weight: !torch.vtensor<[?,?],f32>, %indices: !torch.vtensor<[?,1], si64>) -> !torch.vtensor<[?,1,?],f32> { func.func @torch.aten.embedding$rank_two_indices(%weight: !torch.vtensor<[?,?],f32>, %indices: !torch.vtensor<[?,1], si64>) -> !torch.vtensor<[?,1,?],f32> {

View File

@ -1,10 +1,10 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s // RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s
// CHECK-LABEL: func.func @torch.aten.mm$basic$static( // CHECK-LABEL: func.func @torch.aten.mm$basic$static(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,3],f32>, %[[ARG1:.*]]: !torch.vtensor<[3,3],f32>) -> !torch.vtensor<[2,3],f32> { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,3],f32>, %[[ARG1:.*]]: !torch.vtensor<[3,3],f32>) -> !torch.vtensor<[2,3],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32> // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[3,3],f32> -> tensor<3x3xf32> // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[3,3],f32> -> tensor<3x3xf32>
// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<2x3xf32>, tensor<3x3xf32>) -> tensor<2x3xf32> // CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor<2x3xf32>, tensor<3x3xf32>) -> tensor<2x3xf32>
// CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<2x3xf32> to tensor<2x3xf32> // CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<2x3xf32> to tensor<2x3xf32>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32> // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
// CHECK: return %[[T4]] : !torch.vtensor<[2,3],f32> // CHECK: return %[[T4]] : !torch.vtensor<[2,3],f32>
@ -19,7 +19,7 @@ func.func @torch.aten.mm$basic$static(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3],f32>, %[[ARG1:.*]]: !torch.vtensor<[3,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3],f32>, %[[ARG1:.*]]: !torch.vtensor<[3,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,3],f32> -> tensor<?x3xf32> // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,3],f32> -> tensor<?x3xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[3,?],f32> -> tensor<3x?xf32> // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[3,?],f32> -> tensor<3x?xf32>
// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<?x3xf32>, tensor<3x?xf32>) -> tensor<?x?xf32> // CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor<?x3xf32>, tensor<3x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<?x?xf32> to tensor<?x?xf32> // CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T4]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[T4]] : !torch.vtensor<[?,?],f32>
@ -44,8 +44,8 @@ func.func @torch.aten.mm$basic$dynamic(%arg0: !torch.vtensor<[?,3],f32>, %arg1:
// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C2]] : tensor<10x4x5xf32> // CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C2]] : tensor<10x4x5xf32>
// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 // CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<10x4x5xf32>, tensor<3xi64>) -> tensor<10x4x5xf32> // CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [0, 1, 2] : (tensor<10x4x5xf32>, tensor<3xi64>) -> tensor<10x4x5xf32>
// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<10x3x4xf32>, tensor<10x4x5xf32>) -> tensor<10x3x5xf32> // CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<10x3x4xf32>, tensor<10x4x5xf32>) -> tensor<10x3x5xf32>
// CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<10x3x5xf32> to tensor<10x3x5xf32> // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<10x3x5xf32> to tensor<10x3x5xf32>
// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<10x3x5xf32> -> !torch.vtensor<[10,3,5],f32> // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<10x3x5xf32> -> !torch.vtensor<[10,3,5],f32>
// CHECK: return %[[T12]] : !torch.vtensor<[10,3,5],f32> // CHECK: return %[[T12]] : !torch.vtensor<[10,3,5],f32>
@ -70,8 +70,8 @@ func.func @torch.aten.bmm$basic$static(%arg0: !torch.vtensor<[10,3,4],f32>, %arg
// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C2]] : tensor<?x4x?xf32> // CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C2]] : tensor<?x4x?xf32>
// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 // CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<?x4x?xf32>, tensor<3xi64>) -> tensor<?x4x?xf32> // CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [0, 1, 2] : (tensor<?x4x?xf32>, tensor<3xi64>) -> tensor<?x4x?xf32>
// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<?x?x4xf32>, tensor<?x4x?xf32>) -> tensor<?x?x?xf32> // CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<?x?x4xf32>, tensor<?x4x?xf32>) -> tensor<?x?x?xf32>
// CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<?x?x?xf32> to tensor<?x?x?xf32> // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32> // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[T12]] : !torch.vtensor<[?,?,?],f32> // CHECK: return %[[T12]] : !torch.vtensor<[?,?,?],f32>
@ -96,8 +96,8 @@ func.func @torch.aten.bmm$basic$dynamic(%arg0: !torch.vtensor<[?,?,4],f32>, %arg
// CHECK: %[[T6:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<256x120xf32> // CHECK: %[[T6:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<256x120xf32>
// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 // CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T0]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x120xf32>, tensor<3xi64>) -> tensor<4x256x120xf32> // CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T0]], %[[T8]], dims = [1, 2] : (tensor<256x120xf32>, tensor<3xi64>) -> tensor<4x256x120xf32>
// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T9]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<4x256x120xf32>, tensor<4x120x256xf32>) -> tensor<4x256x256xf32> // CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T9]], %[[T1]]) {dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<4x256x120xf32>, tensor<4x120x256xf32>) -> tensor<4x256x256xf32>
// CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<4x256x256xf32> to tensor<4x256x256xf32> // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<4x256x256xf32> to tensor<4x256x256xf32>
// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x256x256xf32> -> !torch.vtensor<[4,256,256],f32> // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x256x256xf32> -> !torch.vtensor<[4,256,256],f32>
// CHECK: return %[[T12]] : !torch.vtensor<[4,256,256],f32> // CHECK: return %[[T12]] : !torch.vtensor<[4,256,256],f32>
@ -122,8 +122,8 @@ func.func @torch.aten.matmul$basic$static(%arg0: !torch.vtensor<[256,120],f32>,
// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<256x?xf32> // CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<256x?xf32>
// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 // CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x?xf32>, tensor<3xi64>) -> tensor<4x256x?xf32> // CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [1, 2] : (tensor<256x?xf32>, tensor<3xi64>) -> tensor<4x256x?xf32>
// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<4x?x256xf32>, tensor<4x256x?xf32>) -> tensor<4x?x?xf32> // CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<4x?x256xf32>, tensor<4x256x?xf32>) -> tensor<4x?x?xf32>
// CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<4x?x?xf32> to tensor<4x?x?xf32> // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<4x?x?xf32> to tensor<4x?x?xf32>
// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x?x?xf32> -> !torch.vtensor<[4,?,?],f32> // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x?x?xf32> -> !torch.vtensor<[4,?,?],f32>
// CHECK: return %[[T12]] : !torch.vtensor<[4,?,?],f32> // CHECK: return %[[T12]] : !torch.vtensor<[4,?,?],f32>
@ -145,8 +145,8 @@ func.func @torch.aten.matmul$basic$dynamic(%arg0: !torch.vtensor<[4,?,256],f32>,
// CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor<256xf32> // CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor<256xf32>
// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 // CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64
// CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64> // CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64>
// CHECK: %[[T7:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> // CHECK: %[[T7:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T6]], dims = [1] : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32>
// CHECK: %[[T8:.*]] = "mhlo.dot_general"(%[[T0]], %[[T7]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<1x?x256xf32>, tensor<1x256xf32>) -> tensor<1x?xf32> // CHECK: %[[T8:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T7]]) {dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<1x?x256xf32>, tensor<1x256xf32>) -> tensor<1x?xf32>
// CHECK: %[[T9:.*]] = tensor.cast %[[T8]] : tensor<1x?xf32> to tensor<1x?xf32> // CHECK: %[[T9:.*]] = tensor.cast %[[T8]] : tensor<1x?xf32> to tensor<1x?xf32>
// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<1x?xf32> -> !torch.vtensor<[1,?],f32> // CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<1x?xf32> -> !torch.vtensor<[1,?],f32>
// CHECK: return %[[T10]] : !torch.vtensor<[1,?],f32> // CHECK: return %[[T10]] : !torch.vtensor<[1,?],f32>
@ -168,8 +168,8 @@ func.func @torch.aten.matmul$3dx1d(%arg0: !torch.vtensor<[1,?,256],f32>, %arg1:
// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0_0]] : tensor<256xf32> // CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0_0]] : tensor<256xf32>
// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 // CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64
// CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64> // CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64>
// CHECK: %[[T7:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T0]], %[[T6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, tensor<2xi64>) -> tensor<?x256xf32> // CHECK: %[[T7:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T0]], %[[T6]], dims = [1] : (tensor<256xf32>, tensor<2xi64>) -> tensor<?x256xf32>
// CHECK: %[[T8:.*]] = "mhlo.dot_general"(%[[T7]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [1]>} : (tensor<?x256xf32>, tensor<?x256x?xf32>) -> tensor<?x?xf32> // CHECK: %[[T8:.*]] = "stablehlo.dot_general"(%[[T7]], %[[T1]]) {dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [1]>} : (tensor<?x256xf32>, tensor<?x256x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T9:.*]] = tensor.cast %[[T8]] : tensor<?x?xf32> to tensor<?x?xf32> // CHECK: %[[T9:.*]] = tensor.cast %[[T8]] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T10]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[T10]] : !torch.vtensor<[?,?],f32>
@ -184,7 +184,7 @@ func.func @torch.aten.matmul$1dx3d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256],f32>) -> !torch.vtensor<[?],f32> { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256],f32>) -> !torch.vtensor<[?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,256],f32> -> tensor<?x256xf32> // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,256],f32> -> tensor<?x256xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32> // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32>
// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<?x256xf32>, tensor<256xf32>) -> tensor<?xf32> // CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor<?x256xf32>, tensor<256xf32>) -> tensor<?xf32>
// CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<?xf32> to tensor<?xf32> // CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<?xf32> to tensor<?xf32>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?xf32> -> !torch.vtensor<[?],f32> // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?xf32> -> !torch.vtensor<[?],f32>
// CHECK: return %[[T4]] : !torch.vtensor<[?],f32> // CHECK: return %[[T4]] : !torch.vtensor<[?],f32>
@ -199,7 +199,7 @@ func.func @torch.aten.matmul$2dx1d(%arg0: !torch.vtensor<[?,256],f32>, %arg1: !t
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[?],f32> { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32> // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256,?],f32> -> tensor<256x?xf32> // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256,?],f32> -> tensor<256x?xf32>
// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<256xf32>, tensor<256x?xf32>) -> tensor<?xf32> // CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor<256xf32>, tensor<256x?xf32>) -> tensor<?xf32>
// CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<?xf32> to tensor<?xf32> // CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<?xf32> to tensor<?xf32>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?xf32> -> !torch.vtensor<[?],f32> // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?xf32> -> !torch.vtensor<[?],f32>
// CHECK: return %[[T4]] : !torch.vtensor<[?],f32> // CHECK: return %[[T4]] : !torch.vtensor<[?],f32>
@ -214,7 +214,7 @@ func.func @torch.aten.matmul$1dx2d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256],f32>) -> !torch.vtensor<[],f32> { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256],f32>) -> !torch.vtensor<[],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32> // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32> // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32>
// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<256xf32>, tensor<256xf32>) -> tensor<f32> // CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor<256xf32>, tensor<256xf32>) -> tensor<f32>
// CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<f32> to tensor<f32> // CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<f32> to tensor<f32>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<f32> -> !torch.vtensor<[],f32> // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<f32> -> !torch.vtensor<[],f32>
// CHECK: return %[[T4]] : !torch.vtensor<[],f32> // CHECK: return %[[T4]] : !torch.vtensor<[],f32>
@ -228,7 +228,7 @@ func.func @torch.aten.matmul$1dx1d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor
// CHECK-LABEL: func.func @torch.aten.matmul$proj( // CHECK-LABEL: func.func @torch.aten.matmul$proj(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,256],f32>) -> !torch.vtensor<[?,?,256],f32> { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,256],f32>) -> !torch.vtensor<[?,?,256],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,256],f32> -> tensor<?x?x256xf32> // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,256],f32> -> tensor<?x?x256xf32>
// CHECK: %[[T1:.*]] = mhlo.constant dense<1.000000e+00> : tensor<256x256xf32> // CHECK: %[[T1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<256x256xf32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x256xf32> // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x256xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 // CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
@ -239,8 +239,8 @@ func.func @torch.aten.matmul$1dx1d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor
// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<256x256xf32> // CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<256x256xf32>
// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 // CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x256xf32>, tensor<3xi64>) -> tensor<?x256x256xf32> // CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [1, 2] : (tensor<256x256xf32>, tensor<3xi64>) -> tensor<?x256x256xf32>
// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<?x?x256xf32>, tensor<?x256x256xf32>) -> tensor<?x?x256xf32> // CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<?x?x256xf32>, tensor<?x256x256xf32>) -> tensor<?x?x256xf32>
// CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<?x?x256xf32> to tensor<?x?x256xf32> // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<?x?x256xf32> to tensor<?x?x256xf32>
// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<?x?x256xf32> -> !torch.vtensor<[?,?,256],f32> // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<?x?x256xf32> -> !torch.vtensor<[?,?,256],f32>
// CHECK: return %[[T12]] : !torch.vtensor<[?,?,256],f32> // CHECK: return %[[T12]] : !torch.vtensor<[?,?,256],f32>
@ -255,8 +255,8 @@ func.func @torch.aten.matmul$proj(%arg0: !torch.vtensor<[?,?,256],f32>) -> !torc
// CHECK-LABEL: func.func @torch.aten.mm$proj( // CHECK-LABEL: func.func @torch.aten.mm$proj(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,256],f32>) -> !torch.vtensor<[?,256],f32> { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,256],f32>) -> !torch.vtensor<[?,256],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,256],f32> -> tensor<?x256xf32> // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,256],f32> -> tensor<?x256xf32>
// CHECK: %[[T1:.*]] = mhlo.constant dense<1.000000e+00> : tensor<256x256xf32> // CHECK: %[[T1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<256x256xf32>
// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<?x256xf32>, tensor<256x256xf32>) -> tensor<?x256xf32> // CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor<?x256xf32>, tensor<256x256xf32>) -> tensor<?x256xf32>
// CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<?x256xf32> to tensor<?x256xf32> // CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<?x256xf32> to tensor<?x256xf32>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x256xf32> -> !torch.vtensor<[?,256],f32> // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x256xf32> -> !torch.vtensor<[?,256],f32>
// CHECK: return %[[T4]] : !torch.vtensor<[?,256],f32> // CHECK: return %[[T4]] : !torch.vtensor<[?,256],f32>
@ -284,7 +284,7 @@ func.func @torch.aten.mm$proj(%arg0: !torch.vtensor<[?,256],f32>) -> !torch.vten
// CHECK: %[[T_11:.*]] = torch.prim.ListConstruct %[[T_7]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T_11:.*]] = torch.prim.ListConstruct %[[T_7]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_12:.*]] = torch.prim.ListConstruct : () -> !torch.list<int> // CHECK: %[[T_12:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[T_13:.*]] = torch.constant.bool false // CHECK: %[[T_13:.*]] = torch.constant.bool false
// CHECK: %[[T_14:.*]] = mhlo.convolution(%[[T_0]], %[[T_1]]) // CHECK: %[[T_14:.*]] = stablehlo.convolution(%[[T_0]], %[[T_1]])
// CHECK-SAME{LITERAL}: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 1], pad = [[4, 4], [2, 2]], rhs_dilate = [3, 1]} {batch_group_count = 1 : i64, feature_group_count = 3 : i64} : (tensor<?x?x?x?xf32>, tensor<?x?x3x3xf32>) -> tensor<?x?x?x?xf32> // CHECK-SAME{LITERAL}: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 1], pad = [[4, 4], [2, 2]], rhs_dilate = [3, 1]} {batch_group_count = 1 : i64, feature_group_count = 3 : i64} : (tensor<?x?x?x?xf32>, tensor<?x?x3x3xf32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[T_15:.*]] = torch_c.from_builtin_tensor %[[T_14]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32> // CHECK: %[[T_15:.*]] = torch_c.from_builtin_tensor %[[T_14]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[T_15]] : !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[T_15]] : !torch.vtensor<[?,?,?,?],f32>
@ -321,14 +321,14 @@ func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !
// CHECK: %[[T_6:.*]] = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T_6:.*]] = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_7:.*]] = torch.prim.ListConstruct : () -> !torch.list<int> // CHECK: %[[T_7:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %false = torch.constant.bool false // CHECK: %false = torch.constant.bool false
// CHECK: %[[T_8:.*]] = mhlo.convolution(%[[T_0]], %[[T_1]]) // CHECK: %[[T_8:.*]] = stablehlo.convolution(%[[T_0]], %[[T_1]])
// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 1], pad = [[4, 4], [2, 2]], rhs_dilate = [3, 1]} {batch_group_count = 1 : i64, feature_group_count = 3 : i64} : (tensor<?x?x?x?xf32>, tensor<?x?x3x3xf32>) -> tensor<?x?x?x?xf32> // CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 1], pad = [[4, 4], [2, 2]], rhs_dilate = [3, 1]} {batch_group_count = 1 : i64, feature_group_count = 3 : i64} : (tensor<?x?x?x?xf32>, tensor<?x?x3x3xf32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index // CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
// CHECK: %[[T_9:.*]] = tensor.dim %[[T_2]], %[[IDX_0]] : tensor<?xf32> // CHECK: %[[T_9:.*]] = tensor.dim %[[T_2]], %[[IDX_0]] : tensor<?xf32>
// CHECK: %[[T_10:.*]] = arith.index_cast %[[T_9]] : index to i64 // CHECK: %[[T_10:.*]] = arith.index_cast %[[T_9]] : index to i64
// CHECK: %[[VAL_0:.*]] = arith.constant 1 : i64 // CHECK: %[[VAL_0:.*]] = arith.constant 1 : i64
// CHECK: %[[T_11:.*]] = tensor.from_elements %[[T_10]], %[[VAL_0]], %[[VAL_0]] : tensor<3xi64> // CHECK: %[[T_11:.*]] = tensor.from_elements %[[T_10]], %[[VAL_0]], %[[VAL_0]] : tensor<3xi64>
// CHECK: %[[T_12:.*]] = mhlo.dynamic_reshape %[[T_2]], %[[T_11]] : (tensor<?xf32>, tensor<3xi64>) -> tensor<?x1x1xf32> // CHECK: %[[T_12:.*]] = stablehlo.dynamic_reshape %[[T_2]], %[[T_11]] : (tensor<?xf32>, tensor<3xi64>) -> tensor<?x1x1xf32>
// CHECK: %[[T_13:.*]] = chlo.broadcast_add %[[T_8]], %[[T_12]] : (tensor<?x?x?x?xf32>, tensor<?x1x1xf32>) -> tensor<?x?x?x?xf32> // CHECK: %[[T_13:.*]] = chlo.broadcast_add %[[T_8]], %[[T_12]] : (tensor<?x?x?x?xf32>, tensor<?x1x1xf32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[T_14:.*]] = torch_c.from_builtin_tensor %[[T_13]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32> // CHECK: %[[T_14:.*]] = torch_c.from_builtin_tensor %[[T_13]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[T_14]] : !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[T_14]] : !torch.vtensor<[?,?,?,?],f32>
@ -360,8 +360,8 @@ func.func @torch.aten.convolution$bias(%arg0: !torch.vtensor<[?,?,?,?],f32>, %ar
// CHECK: %[[T_2:.*]] = torch_c.to_i64 %int1 // CHECK: %[[T_2:.*]] = torch_c.to_i64 %int1
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_5:.*]] = "mhlo.reverse"(%[[T_1]]) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x4x3x3xf32>) -> tensor<2x4x3x3xf32> // CHECK: %[[T_5:.*]] = stablehlo.reverse %[[T_1]], dims = [2, 3] : tensor<2x4x3x3xf32>
// CHECK: %[[T_6:.*]] = mhlo.convolution(%[[T_0]], %[[T_5]]) // CHECK: %[[T_6:.*]] = stablehlo.convolution(%[[T_0]], %[[T_5]])
// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<2x4x3x3xf32>) -> tensor<1x4x9x9xf32> // CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<2x4x3x3xf32>) -> tensor<1x4x9x9xf32>
// CHECK: %[[T_7:.*]] = torch_c.from_builtin_tensor %[[T_6]] : tensor<1x4x9x9xf32> -> !torch.vtensor<[1,4,9,9],f32> // CHECK: %[[T_7:.*]] = torch_c.from_builtin_tensor %[[T_6]] : tensor<1x4x9x9xf32> -> !torch.vtensor<[1,4,9,9],f32>
// CHECK: return %[[T_7]] : !torch.vtensor<[1,4,9,9],f32> // CHECK: return %[[T_7]] : !torch.vtensor<[1,4,9,9],f32>
@ -392,8 +392,8 @@ func.func @torch.aten.convolution$transposed_basic(%arg0: !torch.vtensor<[1,2,7,
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_6:.*]] = "mhlo.reverse"(%1) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x4x3x3xf32>) -> tensor<2x4x3x3xf32> // CHECK: %[[T_6:.*]] = stablehlo.reverse %1, dims = [2, 3] : tensor<2x4x3x3xf32>
// CHECK: %[[T_7:.*]] = mhlo.convolution(%[[T_0]], %[[T_6]]) // CHECK: %[[T_7:.*]] = stablehlo.convolution(%[[T_0]], %[[T_6]])
// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<2x4x3x3xf32>) -> tensor<1x4x15x15xf32> // CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<2x4x3x3xf32>) -> tensor<1x4x15x15xf32>
// CHECK: %[[T_8:.*]] = torch_c.from_builtin_tensor %[[T_7]] : tensor<1x4x15x15xf32> -> !torch.vtensor<[1,4,15,15],f32> // CHECK: %[[T_8:.*]] = torch_c.from_builtin_tensor %[[T_7]] : tensor<1x4x15x15xf32> -> !torch.vtensor<[1,4,15,15],f32>
// CHECK: return %[[T_8]] : !torch.vtensor<[1,4,15,15],f32> // CHECK: return %[[T_8]] : !torch.vtensor<[1,4,15,15],f32>
@ -426,11 +426,11 @@ func.func @torch.aten.convolution$transposed_stride(%arg0: !torch.vtensor<[1,2,7
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_6:.*]] = "mhlo.reverse"(%[[T_1]]) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x4x3x3xf32>) -> tensor<2x4x3x3xf32> // CHECK: %[[T_6:.*]] = stablehlo.reverse %[[T_1]], dims = [2, 3] : tensor<2x4x3x3xf32>
// CHECK: %[[T_7:.*]] = mhlo.convolution(%[[T_0]], %[[T_6]]) // CHECK: %[[T_7:.*]] = stablehlo.convolution(%[[T_0]], %[[T_6]])
// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<2x4x3x3xf32>) -> tensor<1x4x15x15xf32> // CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<2x4x3x3xf32>) -> tensor<1x4x15x15xf32>
// CHECK: %[[T_8:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> // CHECK: %[[T_8:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[T_9:.*]] = "mhlo.pad"(%[[T_7]], %[[T_8]]) {edge_padding_high = dense<[0, 0, 1, 1]> : vector<4xi64>, edge_padding_low = dense<0> : vector<4xi64>, interior_padding = dense<0> : vector<4xi64>} : (tensor<1x4x15x15xf32>, tensor<f32>) -> tensor<1x4x16x16xf32> // CHECK: %[[T_9:.*]] = stablehlo.pad %[[T_7]], %[[T_8]], low = [0, 0, 0, 0], high = [0, 0, 1, 1], interior = [0, 0, 0, 0] : (tensor<1x4x15x15xf32>, tensor<f32>) -> tensor<1x4x16x16xf32>
// CHECK: %[[T_10:.*]] = torch_c.from_builtin_tensor %[[T_9:.*]] : tensor<1x4x16x16xf32> -> !torch.vtensor<[1,4,16,16],f32> // CHECK: %[[T_10:.*]] = torch_c.from_builtin_tensor %[[T_9:.*]] : tensor<1x4x16x16xf32> -> !torch.vtensor<[1,4,16,16],f32>
// CHECK: return %[[T_10]] : !torch.vtensor<[1,4,16,16],f32> // CHECK: return %[[T_10]] : !torch.vtensor<[1,4,16,16],f32>
func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor<[1,2,7,7],f32>, %arg1: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,16,16],f32> { func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor<[1,2,7,7],f32>, %arg1: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,16,16],f32> {
@ -462,7 +462,7 @@ func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_6:.*]] = "mhlo.reverse"(%1) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x2x3x3xf32>) -> tensor<2x2x3x3xf32> // CHECK: %[[T_6:.*]] = stablehlo.reverse %1, dims = [2, 3] : tensor<2x2x3x3xf32>
// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index // CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
// CHECK: %[[T_7:.*]] = tensor.dim %[[T_6]], %[[IDX_0]] : tensor<2x2x3x3xf32> // CHECK: %[[T_7:.*]] = tensor.dim %[[T_6]], %[[IDX_0]] : tensor<2x2x3x3xf32>
// CHECK: %[[T_8:.*]] = arith.index_cast %[[T_7]] : index to i64 // CHECK: %[[T_8:.*]] = arith.index_cast %[[T_7]] : index to i64
@ -479,11 +479,11 @@ func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor
// CHECK: %[[T_15:.*]] = arith.divsi %[[T_8]], %[[T_24]] : i64 // CHECK: %[[T_15:.*]] = arith.divsi %[[T_8]], %[[T_24]] : i64
// CHECK: %[[T_16:.*]] = arith.muli %[[T_10]], %[[T_24]] : i64 // CHECK: %[[T_16:.*]] = arith.muli %[[T_10]], %[[T_24]] : i64
// CHECK: %[[T_17:.*]] = tensor.from_elements %[[T_24]], %[[T_15]], %[[T_10]], %[[T_12]], %[[T_14]] : tensor<5xi64> // CHECK: %[[T_17:.*]] = tensor.from_elements %[[T_24]], %[[T_15]], %[[T_10]], %[[T_12]], %[[T_14]] : tensor<5xi64>
// CHECK: %[[T_18:.*]] = mhlo.dynamic_reshape %[[T_6]], %[[T_17]] : (tensor<2x2x3x3xf32>, tensor<5xi64>) -> tensor<2x1x2x3x3xf32> // CHECK: %[[T_18:.*]] = stablehlo.dynamic_reshape %[[T_6]], %[[T_17]] : (tensor<2x2x3x3xf32>, tensor<5xi64>) -> tensor<2x1x2x3x3xf32>
// CHECK: %[[T_19:.*]] = "mhlo.transpose"(%[[T_18]]) {permutation = dense<[1, 0, 2, 3, 4]> : tensor<5xi64>} : (tensor<2x1x2x3x3xf32>) -> tensor<1x2x2x3x3xf32> // CHECK: %[[T_19:.*]] = stablehlo.transpose %[[T_18]], dims = [1, 0, 2, 3, 4] : (tensor<2x1x2x3x3xf32>) -> tensor<1x2x2x3x3xf32>
// CHECK: %[[T_20:.*]] = tensor.from_elements %[[T_15]], %[[T_16]], %[[T_12]], %[[T_14]] : tensor<4xi64> // CHECK: %[[T_20:.*]] = tensor.from_elements %[[T_15]], %[[T_16]], %[[T_12]], %[[T_14]] : tensor<4xi64>
// CHECK: %[[T_21:.*]] = mhlo.dynamic_reshape %[[T_19]], %[[T_20]] : (tensor<1x2x2x3x3xf32>, tensor<4xi64>) -> tensor<1x4x3x3xf32> // CHECK: %[[T_21:.*]] = stablehlo.dynamic_reshape %[[T_19]], %[[T_20]] : (tensor<1x2x2x3x3xf32>, tensor<4xi64>) -> tensor<1x4x3x3xf32>
// CHECK: %[[T_22:.*]] = mhlo.convolution(%[[T_0]], %[[T_21]]) // CHECK: %[[T_22:.*]] = stablehlo.convolution(%[[T_0]], %[[T_21]])
// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 2 : i64} : (tensor<1x2x7x7xf32>, tensor<1x4x3x3xf32>) -> tensor<1x4x15x15xf32> // CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 2 : i64} : (tensor<1x2x7x7xf32>, tensor<1x4x3x3xf32>) -> tensor<1x4x15x15xf32>
// CHECK: %[[T_23:.*]] = torch_c.from_builtin_tensor %[[T_22]] : tensor<1x4x15x15xf32> -> !torch.vtensor<[1,4,15,15],f32> // CHECK: %[[T_23:.*]] = torch_c.from_builtin_tensor %[[T_22]] : tensor<1x4x15x15xf32> -> !torch.vtensor<[1,4,15,15],f32>
// CHECK: return %[[T_23]] : !torch.vtensor<[1,4,15,15],f32> // CHECK: return %[[T_23]] : !torch.vtensor<[1,4,15,15],f32>

View File

@ -1,2 +1,2 @@
if not config.enable_mhlo: if not config.enable_stablehlo:
config.unsupported = True config.unsupported = True

View File

@ -1,4 +1,4 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s // RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s
// ----- // -----
@ -13,11 +13,11 @@
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_6:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32> // CHECK: %[[VAL_6:.*]] = stablehlo.constant dense<-3.40282347E+38> : tensor<f32>
// CHECK: %[[VAL_7:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_6]]) ({ // CHECK: %[[VAL_7:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_6]]) ({
// CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>): // CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
// CHECK: %[[VAL_10:.*]] = mhlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32> // CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
// CHECK: mhlo.return %[[VAL_10]] : tensor<f32> // CHECK: stablehlo.return %[[VAL_10]] : tensor<f32>
// CHECK: }) {padding = dense<0> : tensor<4x2xi64>, window_dilations = dense<[1, 1, 2, 1]> : tensor<4xi64>, window_dimensions = dense<[1, 1, 2, 2]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32> // CHECK: }) {padding = dense<0> : tensor<4x2xi64>, window_dilations = dense<[1, 1, 2, 1]> : tensor<4xi64>, window_dimensions = dense<[1, 1, 2, 2]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32> // CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[VAL_11]] : !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[VAL_11]] : !torch.vtensor<[?,?,?,?],f32>
@ -45,11 +45,11 @@ func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32> // CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<-3.40282347E+38> : tensor<f32>
// CHECK: %[[VAL_6:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({ // CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
// CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>): // CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
// CHECK: %[[VAL_10:.*]] = mhlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32> // CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
// CHECK: mhlo.return %[[VAL_10]] : tensor<f32> // CHECK: stablehlo.return %[[VAL_10]] : tensor<f32>
// CHECK: }) // CHECK: })
// CHECK-SAME{LITERAL}: {padding = dense<[[0, 0], [0, 0], [2, 2], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<[1, 1, 2, 1]> : tensor<4xi64>, window_dimensions = dense<[1, 1, 2, 2]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32> // CHECK-SAME{LITERAL}: {padding = dense<[[0, 0], [0, 0], [2, 2], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<[1, 1, 2, 1]> : tensor<4xi64>, window_dimensions = dense<[1, 1, 2, 2]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32> // CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
@ -80,7 +80,7 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -
// CHECK: %[[T2:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2]] : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T2:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT0]] : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T4:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1]] : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T4:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T5:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32> // CHECK: %[[T5:.*]] = stablehlo.constant dense<-3.40282347E+38> : tensor<f32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x?xf32> // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x?xf32>
// CHECK: %[[T6:.*]] = arith.index_cast %[[DIM]] : index to i64 // CHECK: %[[T6:.*]] = arith.index_cast %[[DIM]] : index to i64
@ -93,18 +93,18 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T6]], %[[T7]], %[[T8]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T6]], %[[T7]], %[[T8]] : tensor<3xi64>
// CHECK: %[[T9:.*]] = arith.muli %[[T8]], %[[T7]] : i64 // CHECK: %[[T9:.*]] = arith.muli %[[T8]], %[[T7]] : i64
// CHECK: %[[FROM_ELEMENTS_2:.*]] = tensor.from_elements %[[T6]], %[[T9]] : tensor<2xi64> // CHECK: %[[FROM_ELEMENTS_2:.*]] = tensor.from_elements %[[T6]], %[[T9]] : tensor<2xi64>
// CHECK: %[[T10:.*]] = "mhlo.dynamic_iota"(%[[FROM_ELEMENTS_2]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64> // CHECK: %[[T10:.*]] = stablehlo.dynamic_iota %[[FROM_ELEMENTS_2]], dim = 1 : (tensor<2xi64>) -> tensor<?x?xi64>
// CHECK: %[[T11:.*]] = mhlo.dynamic_reshape %[[T10]], %[[FROM_ELEMENTS]] : (tensor<?x?xi64>, tensor<3xi64>) -> tensor<?x?x?xi64> // CHECK: %[[T11:.*]] = stablehlo.dynamic_reshape %[[T10]], %[[FROM_ELEMENTS]] : (tensor<?x?xi64>, tensor<3xi64>) -> tensor<?x?x?xi64>
// CHECK: %[[T12:.*]] = mhlo.constant dense<0> : tensor<i64> // CHECK: %[[T12:.*]] = stablehlo.constant dense<0> : tensor<i64>
// CHECK: %[[T13:.*]]:2 = "mhlo.reduce_window"(%[[T0]], %[[T11]], %[[T5]], %[[T12]]) ({ // CHECK: %[[T13:.*]]:2 = "stablehlo.reduce_window"(%[[T0]], %[[T11]], %[[T5]], %[[T12]]) ({
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<i64>, %[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<i64>): // CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<i64>, %[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<i64>):
// CHECK: %[[T16:.*]] = mhlo.compare GE, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1> // CHECK: %[[T16:.*]] = stablehlo.compare GE, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[T17:.*]] = mhlo.select %[[T16]], %[[ARG1]], %[[ARG3]] : tensor<i1>, tensor<f32> // CHECK: %[[T17:.*]] = stablehlo.select %[[T16]], %[[ARG1]], %[[ARG3]] : tensor<i1>, tensor<f32>
// CHECK: %[[T18:.*]] = mhlo.compare EQ, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1> // CHECK: %[[T18:.*]] = stablehlo.compare EQ, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[T19:.*]] = mhlo.minimum %[[ARG2]], %[[ARG4]] : tensor<i64> // CHECK: %[[T19:.*]] = stablehlo.minimum %[[ARG2]], %[[ARG4]] : tensor<i64>
// CHECK: %[[T20:.*]] = mhlo.select %[[T16]], %[[ARG2]], %[[ARG4]] : tensor<i1>, tensor<i64> // CHECK: %[[T20:.*]] = stablehlo.select %[[T16]], %[[ARG2]], %[[ARG4]] : tensor<i1>, tensor<i64>
// CHECK: %[[T21:.*]] = mhlo.select %[[T18]], %[[T19]], %[[T20]] : tensor<i1>, tensor<i64> // CHECK: %[[T21:.*]] = stablehlo.select %[[T18]], %[[T19]], %[[T20]] : tensor<i1>, tensor<i64>
// CHECK: mhlo.return %[[T17]], %[[T21]] : tensor<f32>, tensor<i64> // CHECK: stablehlo.return %[[T17]], %[[T21]] : tensor<f32>, tensor<i64>
// CHECK: }) {padding = dense<0> : tensor<3x2xi64>, window_dilations = dense<1> : tensor<3xi64>, window_dimensions = dense<[1, 3, 3]> : tensor<3xi64>, window_strides = dense<[1, 2, 2]> : tensor<3xi64>} : (tensor<?x?x?xf32>, tensor<?x?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?x?x?xf32>, tensor<?x?x?xi64>) // CHECK: }) {padding = dense<0> : tensor<3x2xi64>, window_dilations = dense<1> : tensor<3xi64>, window_dimensions = dense<[1, 3, 3]> : tensor<3xi64>, window_strides = dense<[1, 2, 2]> : tensor<3xi64>} : (tensor<?x?x?xf32>, tensor<?x?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?x?x?xf32>, tensor<?x?x?xi64>)
// CHECK: %[[T14:.*]] = torch_c.from_builtin_tensor %[[T13]]#0 : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32> // CHECK: %[[T14:.*]] = torch_c.from_builtin_tensor %[[T13]]#0 : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[T15:.*]] = torch_c.from_builtin_tensor %[[T13]]#1 : tensor<?x?x?xi64> -> !torch.vtensor<[?,?,?],si64> // CHECK: %[[T15:.*]] = torch_c.from_builtin_tensor %[[T13]]#1 : tensor<?x?x?xi64> -> !torch.vtensor<[?,?,?],si64>
@ -136,13 +136,13 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32>
// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> // CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VAL_6:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({ // CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
// CHECK: ^bb0(%[[IVAL_0:.*]]: tensor<f32>, %[[IVAL_1:.*]]: tensor<f32>): // CHECK: ^bb0(%[[IVAL_0:.*]]: tensor<f32>, %[[IVAL_1:.*]]: tensor<f32>):
// CHECK: %[[IVAL_2:.*]] = mhlo.add %[[IVAL_0]], %[[IVAL_1]] : tensor<f32> // CHECK: %[[IVAL_2:.*]] = stablehlo.add %[[IVAL_0]], %[[IVAL_1]] : tensor<f32>
// CHECK: mhlo.return %[[IVAL_2]] : tensor<f32> // CHECK: stablehlo.return %[[IVAL_2]] : tensor<f32>
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32> // CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32> // CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f32>
// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index // CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor<?x?x?x?xf32> // CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor<?x?x?x?xf32>
// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : index to i64 // CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : index to i64
@ -156,14 +156,14 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32>
// CHECK: %[[VAL_14:.*]] = tensor.dim %[[VAL_1]], %[[IDX_3]] : tensor<?x?x?x?xf32> // CHECK: %[[VAL_14:.*]] = tensor.dim %[[VAL_1]], %[[IDX_3]] : tensor<?x?x?x?xf32>
// CHECK: %[[VAL_15:.*]] = arith.index_cast %[[VAL_14]] : index to i64 // CHECK: %[[VAL_15:.*]] = arith.index_cast %[[VAL_14]] : index to i64
// CHECK: %[[VAL_16:.*]] = tensor.from_elements %[[VAL_9]], %[[VAL_11]], %[[VAL_13]], %[[VAL_15]] : tensor<4xi64> // CHECK: %[[VAL_16:.*]] = tensor.from_elements %[[VAL_9]], %[[VAL_11]], %[[VAL_13]], %[[VAL_15]] : tensor<4xi64>
// CHECK: %[[VAL_17:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_7]], %[[VAL_16]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<4xi64>) -> tensor<?x?x?x?xf32> // CHECK: %[[VAL_17:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_16]], dims = [] : (tensor<f32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_18:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> // CHECK: %[[VAL_18:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VAL_19:.*]] = "mhlo.reduce_window"(%[[VAL_17]], %[[VAL_18]]) ({ // CHECK: %[[VAL_19:.*]] = "stablehlo.reduce_window"(%[[VAL_17]], %[[VAL_18]]) ({
// CHECK: ^bb0(%[[IVAL_3:.*]]: tensor<f32>, %[[IVAL_4:.*]]: tensor<f32>): // CHECK: ^bb0(%[[IVAL_3:.*]]: tensor<f32>, %[[IVAL_4:.*]]: tensor<f32>):
// CHECK: %[[IVAL_5:.*]] = mhlo.add %[[IVAL_3]], %[[IVAL_4]] : tensor<f32> // CHECK: %[[IVAL_5:.*]] = stablehlo.add %[[IVAL_3]], %[[IVAL_4]] : tensor<f32>
// CHECK: mhlo.return %[[IVAL_5]] : tensor<f32> // CHECK: stablehlo.return %[[IVAL_5]] : tensor<f32>
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32> // CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_20:.*]] = mhlo.divide %[[VAL_6]], %[[VAL_19]] : tensor<?x?x?x?xf32> // CHECK: %[[VAL_20:.*]] = stablehlo.divide %[[VAL_6]], %[[VAL_19]] : tensor<?x?x?x?xf32>
// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32> // CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[VAL_21]] : !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[VAL_21]] : !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten.avg_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { func.func @torch.aten.avg_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
@ -193,14 +193,14 @@ func.func @torch.aten.avg_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T2:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2]] : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T2:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1]] : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T4:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> // CHECK: %[[T4:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[T5:.*]] = "mhlo.reduce_window"(%[[T0]], %[[T4]]) ({ // CHECK: %[[T5:.*]] = "stablehlo.reduce_window"(%[[T0]], %[[T4]]) ({
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>): // CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK: %[[T10:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32> // CHECK: %[[T10:.*]] = stablehlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK: mhlo.return %[[T10]] : tensor<f32> // CHECK: stablehlo.return %[[T10]] : tensor<f32>
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32> // CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[T6:.*]] = mhlo.constant dense<9> : tensor<i64> // CHECK: %[[T6:.*]] = stablehlo.constant dense<9> : tensor<i64>
// CHECK: %[[T7:.*]] = mhlo.convert %[[T6]] : (tensor<i64>) -> tensor<f32> // CHECK: %[[T7:.*]] = stablehlo.convert %[[T6]] : (tensor<i64>) -> tensor<f32>
// CHECK: %[[T8:.*]] = chlo.broadcast_divide %[[T5]], %[[T7]] : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32> // CHECK: %[[T8:.*]] = chlo.broadcast_divide %[[T5]], %[[T7]] : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32> // CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[T9]] : !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[T9]] : !torch.vtensor<[?,?,?,?],f32>

View File

@ -1,4 +1,4 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s // RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s
// CHECK-LABEL: func.func @torch.aten.slice.strided$slice_like( // CHECK-LABEL: func.func @torch.aten.slice.strided$slice_like(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
@ -42,7 +42,7 @@
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T10]], %[[C0_I64_5]], %[[C0_I64_5]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T10]], %[[C0_I64_5]], %[[C0_I64_5]] : tensor<3xi64>
// CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T21]], %[[T18]], %[[T19]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T21]], %[[T18]], %[[T19]] : tensor<3xi64>
// CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64>
// CHECK: %[[T22:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32> // CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32>
// CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32> // CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[T23]] : !torch.vtensor<[?,?,?],f32> // CHECK: return %[[T23]] : !torch.vtensor<[?,?,?],f32>
func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
@ -97,7 +97,7 @@ func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T10]], %[[C0_I64_5]], %[[C0_I64_5]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T10]], %[[C0_I64_5]], %[[C0_I64_5]] : tensor<3xi64>
// CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T21]], %[[T18]], %[[T19]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T21]], %[[T18]], %[[T19]] : tensor<3xi64>
// CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64>
// CHECK: %[[T22:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<2x65x256xf32> // CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<2x65x256xf32>
// CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<2x65x256xf32> -> !torch.vtensor<[2,65,256],f32> // CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<2x65x256xf32> -> !torch.vtensor<[2,65,256],f32>
// CHECK: return %[[T23]] : !torch.vtensor<[2,65,256],f32> // CHECK: return %[[T23]] : !torch.vtensor<[2,65,256],f32>
func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[2,65,256],f32> { func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[2,65,256],f32> {
@ -152,7 +152,7 @@ func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,6
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_5]], %[[T10]], %[[C0_I64_5]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_5]], %[[T10]], %[[C0_I64_5]] : tensor<3xi64>
// CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T17]], %[[T21]], %[[T19]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T17]], %[[T21]], %[[T19]] : tensor<3xi64>
// CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64>
// CHECK: %[[T22:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x1x?xf32> // CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x1x?xf32>
// CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<?x1x?xf32> -> !torch.vtensor<[?,1,?],f32> // CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<?x1x?xf32> -> !torch.vtensor<[?,1,?],f32>
// CHECK: return %[[T23]] : !torch.vtensor<[?,1,?],f32> // CHECK: return %[[T23]] : !torch.vtensor<[?,1,?],f32>
func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,1,?],f32> { func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,1,?],f32> {
@ -207,7 +207,7 @@ func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>)
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_5]], %[[T10]], %[[C0_I64_5]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_5]], %[[T10]], %[[C0_I64_5]] : tensor<3xi64>
// CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T17]], %[[T21]], %[[T19]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T17]], %[[T21]], %[[T19]] : tensor<3xi64>
// CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64>
// CHECK: %[[T22:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x1x256xf32> // CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x1x256xf32>
// CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<4x1x256xf32> -> !torch.vtensor<[4,1,256],f32> // CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<4x1x256xf32> -> !torch.vtensor<[4,1,256],f32>
// CHECK: return %[[T23]] : !torch.vtensor<[4,1,256],f32> // CHECK: return %[[T23]] : !torch.vtensor<[4,1,256],f32>
func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,1,256],f32> { func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,1,256],f32> {
@ -247,7 +247,7 @@ func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,2
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_4]], %[[C0_I64]], %[[C0_I64_4]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_4]], %[[C0_I64]], %[[C0_I64_4]] : tensor<3xi64>
// CHECK: %[[FROM_ELEMENTS_5:.*]] = tensor.from_elements %[[T3]], %[[T7]], %[[T5]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_5:.*]] = tensor.from_elements %[[T3]], %[[T7]], %[[T5]] : tensor<3xi64>
// CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64>
// CHECK: %[[T8:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS]]_6 : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32> // CHECK: %[[T8:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS]]_6 : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32>
// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32> // CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[T9]] : !torch.vtensor<[?,?,?],f32> // CHECK: return %[[T9]] : !torch.vtensor<[?,?,?],f32>
func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
@ -287,7 +287,7 @@ func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>)
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_4]], %[[C0_I64]], %[[C0_I64_4]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_4]], %[[C0_I64]], %[[C0_I64_4]] : tensor<3xi64>
// CHECK: %[[FROM_ELEMENTS_5:.*]] = tensor.from_elements %[[T3]], %[[T7]], %[[T5]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_5:.*]] = tensor.from_elements %[[T3]], %[[T7]], %[[T5]] : tensor<3xi64>
// CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64>
// CHECK: %[[T8:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS]]_6 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x33x256xf32> // CHECK: %[[T8:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS]]_6 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x33x256xf32>
// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<4x33x256xf32> -> !torch.vtensor<[4,33,256],f32> // CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<4x33x256xf32> -> !torch.vtensor<[4,33,256],f32>
// CHECK: return %[[T9]] : !torch.vtensor<[4,33,256],f32> // CHECK: return %[[T9]] : !torch.vtensor<[4,33,256],f32>
func.func @torch.aten.slice.none.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,33,256],f32> { func.func @torch.aten.slice.none.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,33,256],f32> {
@ -313,8 +313,8 @@ func.func @torch.aten.slice.none.static$slice_like(%arg0: !torch.vtensor<[4,65,2
// CHECK: %[[T5:.*]] = arith.muli %[[T4]], %[[T3]] : i64 // CHECK: %[[T5:.*]] = arith.muli %[[T4]], %[[T3]] : i64
// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : i64 to index // CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : i64 to index
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]], %[[T3]] : tensor<2xi64> // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]], %[[T3]] : tensor<2xi64>
// CHECK: %[[T7:.*]] = mhlo.compute_reshape_shape %[[T6]], %[[FROM_ELEMENTS]] : (index, tensor<2xi64>) -> tensor<2xi64> // CHECK: %[[T7:.*]] = stablehlo.compute_reshape_shape %[[T6]], %[[FROM_ELEMENTS]] : (index, tensor<2xi64>) -> tensor<2xi64>
// CHECK: %[[T8:.*]] = mhlo.dynamic_reshape %[[T0]], %[[T7]] : (tensor<?x?x?x?xf32>, tensor<2xi64>) -> tensor<?x224xf32> // CHECK: %[[T8:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[T7]] : (tensor<?x?x?x?xf32>, tensor<2xi64>) -> tensor<?x224xf32>
// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x224xf32> -> !torch.vtensor<[?,224],f32> // CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x224xf32> -> !torch.vtensor<[?,224],f32>
// CHECK: return %[[T9]] : !torch.vtensor<[?,224],f32> // CHECK: return %[[T9]] : !torch.vtensor<[?,224],f32>
func.func @torch.aten.view$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,224],f32> { func.func @torch.aten.view$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,224],f32> {
@ -346,8 +346,8 @@ func.func @torch.aten.view$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
// CHECK: %[[T9:.*]] = arith.muli %[[T8]], %[[T5]] : i64 // CHECK: %[[T9:.*]] = arith.muli %[[T8]], %[[T5]] : i64
// CHECK: %[[T10:.*]] = arith.index_cast %[[T9]] : i64 to index // CHECK: %[[T10:.*]] = arith.index_cast %[[T9]] : i64 to index
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]], %[[T3]], %[[T4]], %[[T5]] : tensor<4xi64> // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]], %[[T3]], %[[T4]], %[[T5]] : tensor<4xi64>
// CHECK: %[[T11:.*]] = mhlo.compute_reshape_shape %[[T10]], %[[FROM_ELEMENTS]] : (index, tensor<4xi64>) -> tensor<4xi64> // CHECK: %[[T11:.*]] = stablehlo.compute_reshape_shape %[[T10]], %[[FROM_ELEMENTS]] : (index, tensor<4xi64>) -> tensor<4xi64>
// CHECK: %[[T12:.*]] = mhlo.dynamic_reshape %[[T0]], %[[T11]] : (tensor<?x?x?x?x?xf32>, tensor<4xi64>) -> tensor<?x120x4x64xf32> // CHECK: %[[T12:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[T11]] : (tensor<?x?x?x?x?xf32>, tensor<4xi64>) -> tensor<?x120x4x64xf32>
// CHECK: %[[T13:.*]] = torch_c.from_builtin_tensor %[[T12]] : tensor<?x120x4x64xf32> -> !torch.vtensor<[?,120,4,64],f32> // CHECK: %[[T13:.*]] = torch_c.from_builtin_tensor %[[T12]] : tensor<?x120x4x64xf32> -> !torch.vtensor<[?,120,4,64],f32>
// CHECK: return %[[T13]] : !torch.vtensor<[?,120,4,64],f32> // CHECK: return %[[T13]] : !torch.vtensor<[?,120,4,64],f32>
func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.vtensor<[?,120,4,64],f32> { func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.vtensor<[?,120,4,64],f32> {
@ -367,7 +367,7 @@ func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[],f32> -> tensor<f32> // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list<int> // CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[T2:.*]] = mhlo.reshape %[[T0]] : (tensor<f32>) -> tensor<1xf32> // CHECK: %[[T2:.*]] = stablehlo.reshape %[[T0]] : (tensor<f32>) -> tensor<1xf32>
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<1xf32> -> !torch.vtensor<[1],f32> // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
// CHECK: return %[[T3]] : !torch.vtensor<[1],f32> // CHECK: return %[[T3]] : !torch.vtensor<[1],f32>
func.func @torch.aten.view$to_rank1(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { func.func @torch.aten.view$to_rank1(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
@ -383,7 +383,7 @@ func.func @torch.aten.view$to_rank1(%arg0: !torch.vtensor<[],f32>) -> !torch.vte
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1],f32>) -> !torch.vtensor<[],f32> { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1],f32>) -> !torch.vtensor<[],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1],f32> -> tensor<1xf32> // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1],f32> -> tensor<1xf32>
// CHECK: %[[T1:.*]] = torch.prim.ListConstruct : () -> !torch.list<int> // CHECK: %[[T1:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[T2:.*]] = mhlo.reshape %[[T0]] : (tensor<1xf32>) -> tensor<f32> // CHECK: %[[T2:.*]] = stablehlo.reshape %[[T0]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<f32> -> !torch.vtensor<[],f32> // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<f32> -> !torch.vtensor<[],f32>
// CHECK: return %[[T3]] : !torch.vtensor<[],f32> // CHECK: return %[[T3]] : !torch.vtensor<[],f32>
func.func @torch.aten.view$to_rank0(%arg0: !torch.vtensor<[1],f32>) -> !torch.vtensor<[],f32> { func.func @torch.aten.view$to_rank0(%arg0: !torch.vtensor<[1],f32>) -> !torch.vtensor<[],f32> {
@ -425,7 +425,7 @@ func.func @torch.aten.squeeze.dim$0$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32
// CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<?x1x?x1x?xf32> // CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<?x1x?x1x?xf32>
// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64 // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<4xi64> // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<4xi64>
// CHECK: %[[T5:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x1x?x1x?xf32>, tensor<4xi64>) -> tensor<?x?x1x?xf32> // CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x1x?x1x?xf32>, tensor<4xi64>) -> tensor<?x?x1x?xf32>
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?x1x?xf32> -> !torch.vtensor<[?,?,1,?],f32> // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?x1x?xf32> -> !torch.vtensor<[?,?,1,?],f32>
// CHECK: return %[[T6]] : !torch.vtensor<[?,?,1,?],f32> // CHECK: return %[[T6]] : !torch.vtensor<[?,?,1,?],f32>
func.func @torch.aten.squeeze.dim$1(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,?,1,?],f32> { func.func @torch.aten.squeeze.dim$1(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,?,1,?],f32> {
@ -453,7 +453,7 @@ func.func @torch.aten.squeeze.dim$1(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !
// CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<?x1x?x1x?xf32> // CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<?x1x?x1x?xf32>
// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64 // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<4xi64> // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<4xi64>
// CHECK: %[[T5:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x1x?x1x?xf32>, tensor<4xi64>) -> tensor<?x1x?x?xf32> // CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x1x?x1x?xf32>, tensor<4xi64>) -> tensor<?x1x?x?xf32>
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x1x?x?xf32> -> !torch.vtensor<[?,1,?,?],f32> // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x1x?x?xf32> -> !torch.vtensor<[?,1,?,?],f32>
// CHECK: return %[[T6]] : !torch.vtensor<[?,1,?,?],f32> // CHECK: return %[[T6]] : !torch.vtensor<[?,1,?,?],f32>
func.func @torch.aten.squeeze.dim$from_end(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,1,?,?],f32> { func.func @torch.aten.squeeze.dim$from_end(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,1,?,?],f32> {
@ -477,7 +477,7 @@ func.func @torch.aten.squeeze.dim$from_end(%arg0: !torch.vtensor<[?,1,?,1,?],f32
// CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<2x1x2x1x2xf32> // CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<2x1x2x1x2xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[DIM_1]] : index to i64 // CHECK: %[[T3:.*]] = arith.index_cast %[[DIM_1]] : index to i64
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]] : tensor<3xi64>
// CHECK: %[[T4:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<2x1x2x1x2xf32>, tensor<3xi64>) -> tensor<2x2x2xf32> // CHECK: %[[T4:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<2x1x2x1x2xf32>, tensor<3xi64>) -> tensor<2x2x2xf32>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<2x2x2xf32> -> !torch.vtensor<[2,2,2],f32> // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<2x2x2xf32> -> !torch.vtensor<[2,2,2],f32>
// CHECK: return %[[T5]] : !torch.vtensor<[2,2,2],f32> // CHECK: return %[[T5]] : !torch.vtensor<[2,2,2],f32>
func.func @torch.aten.squeeze$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,2,2],f32> { func.func @torch.aten.squeeze$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,2,2],f32> {
@ -505,7 +505,7 @@ func.func @torch.aten.squeeze$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) ->
// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64 // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64
// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 // CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<5xi64> // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<5xi64>
// CHECK: %[[T5:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<1x?x?x?x?xf32> // CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<1x?x?x?x?xf32>
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<1x?x?x?x?xf32> -> !torch.vtensor<[1,?,?,?,?],f32> // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<1x?x?x?x?xf32> -> !torch.vtensor<[1,?,?,?,?],f32>
// CHECK: return %[[T6]] : !torch.vtensor<[1,?,?,?,?],f32> // CHECK: return %[[T6]] : !torch.vtensor<[1,?,?,?,?],f32>
func.func @torch.aten.unsqueeze$dim$0(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1,?,?,?,?],f32> { func.func @torch.aten.unsqueeze$dim$0(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1,?,?,?,?],f32> {
@ -534,7 +534,7 @@ func.func @torch.aten.unsqueeze$dim$0(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !
// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64 // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64
// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 // CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[C1_I64]], %[[T2]], %[[T3]], %[[T4]] : tensor<5xi64> // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[C1_I64]], %[[T2]], %[[T3]], %[[T4]] : tensor<5xi64>
// CHECK: %[[T5:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<?x1x?x?x?xf32> // CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<?x1x?x?x?xf32>
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x1x?x?x?xf32> -> !torch.vtensor<[?,1,?,?,?],f32> // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x1x?x?x?xf32> -> !torch.vtensor<[?,1,?,?,?],f32>
// CHECK: return %[[T6]] : !torch.vtensor<[?,1,?,?,?],f32> // CHECK: return %[[T6]] : !torch.vtensor<[?,1,?,?,?],f32>
func.func @torch.aten.unsqueeze$dim$1(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,1,?,?,?],f32> { func.func @torch.aten.unsqueeze$dim$1(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,1,?,?,?],f32> {
@ -563,7 +563,7 @@ func.func @torch.aten.unsqueeze$dim$1(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !
// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64 // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64
// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 // CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]], %[[C1_I64]], %[[T4]] : tensor<5xi64> // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]], %[[C1_I64]], %[[T4]] : tensor<5xi64>
// CHECK: %[[T5:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<?x?x?x1x?xf32> // CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<?x?x?x1x?xf32>
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?x?x1x?xf32> -> !torch.vtensor<[?,?,?,1,?],f32> // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?x?x1x?xf32> -> !torch.vtensor<[?,?,?,1,?],f32>
// CHECK: return %[[T6]] : !torch.vtensor<[?,?,?,1,?],f32> // CHECK: return %[[T6]] : !torch.vtensor<[?,?,?,1,?],f32>
func.func @torch.aten.unsqueeze$from_end(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,1,?],f32> { func.func @torch.aten.unsqueeze$from_end(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,1,?],f32> {

View File

@ -17,7 +17,7 @@ config.llvm_exe_ext = "@EXEEXT@"
config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@" config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@"
config.python_executable = "@Python3_EXECUTABLE@" config.python_executable = "@Python3_EXECUTABLE@"
config.enable_jit_ir_importer = @TORCH_MLIR_ENABLE_JIT_IR_IMPORTER@ config.enable_jit_ir_importer = @TORCH_MLIR_ENABLE_JIT_IR_IMPORTER@
config.enable_mhlo = @TORCH_MLIR_ENABLE_MHLO@ config.enable_stablehlo = @TORCH_MLIR_ENABLE_STABLEHLO@
import lit.llvm import lit.llvm
lit.llvm.initialize(lit_config, config) lit.llvm.initialize(lit_config, config)

View File

@ -268,7 +268,7 @@ gentbl_cc_library(
( (
[ [
"-gen-pass-decls", "-gen-pass-decls",
"-DTORCH_MLIR_ENABLE_MHLO", "-DTORCH_MLIR_ENABLE_STABLEHLO",
], ],
"include/torch-mlir/Conversion/Passes.h.inc", "include/torch-mlir/Conversion/Passes.h.inc",
), ),
@ -434,13 +434,13 @@ cc_library(
) )
cc_library( cc_library(
name = "TorchMLIRTorchToMhlo", name = "TorchMLIRTorchToStablehlo",
srcs = glob([ srcs = glob([
"lib/Conversion/*.h", "lib/Conversion/*.h",
"lib/Conversion/TorchToMhlo/*.h", "lib/Conversion/TorchToStablehlo/*.h",
"lib/Conversion/TorchToMhlo/*.cpp", "lib/Conversion/TorchToStablehlo/*.cpp",
]), ]),
hdrs = glob(["include/torch-mlir/Conversion/TorchToMhlo/*.h"]), hdrs = glob(["include/torch-mlir/Conversion/TorchToStablehlo/*.h"]),
strip_include_prefix = "include", strip_include_prefix = "include",
deps = [ deps = [
":TorchMLIRConversionPassesIncGen", ":TorchMLIRConversionPassesIncGen",
@ -465,8 +465,8 @@ cc_library(
":TorchMLIRTorchConversionToMLProgram", ":TorchMLIRTorchConversionToMLProgram",
":TorchMLIRTorchToArith", ":TorchMLIRTorchToArith",
":TorchMLIRTorchToLinalg", ":TorchMLIRTorchToLinalg",
":TorchMLIRTorchToMhlo",
":TorchMLIRTorchToSCF", ":TorchMLIRTorchToSCF",
":TorchMLIRTorchToStablehlo",
":TorchMLIRTorchToTMTensor", ":TorchMLIRTorchToTMTensor",
":TorchMLIRTorchToTosa", ":TorchMLIRTorchToTosa",
], ],
@ -489,8 +489,8 @@ cc_library(
":TorchMLIRTorchPasses", ":TorchMLIRTorchPasses",
":TorchMLIRTorchToArith", ":TorchMLIRTorchToArith",
":TorchMLIRTorchToLinalg", ":TorchMLIRTorchToLinalg",
":TorchMLIRTorchToMhlo",
":TorchMLIRTorchToSCF", ":TorchMLIRTorchToSCF",
":TorchMLIRTorchToStablehlo",
":TorchMLIRTorchToTMTensor", ":TorchMLIRTorchToTMTensor",
":TorchMLIRTorchToTosa", ":TorchMLIRTorchToTosa",
"@llvm-project//mlir:ConversionPasses", "@llvm-project//mlir:ConversionPasses",

View File

@ -23,7 +23,7 @@ expand_template(
# All disabled, but required to substituted because they are not in quotes. # All disabled, but required to substituted because they are not in quotes.
"@MLIR_ENABLE_BINDINGS_PYTHON@": "0", "@MLIR_ENABLE_BINDINGS_PYTHON@": "0",
"@TORCH_MLIR_ENABLE_JIT_IR_IMPORTER@": "0", "@TORCH_MLIR_ENABLE_JIT_IR_IMPORTER@": "0",
"@TORCH_MLIR_ENABLE_MHLO@": "0", "@TORCH_MLIR_ENABLE_STABLEHLO@": "0",
}, },
template = "lit.site.cfg.py.in", template = "lit.site.cfg.py.in",
) )