mhlo: migrate conversion to stablehlo (#1840)

This patch replaces all MHLO operations with their StableHLO
counterparts and adds a validation pass to ensure that no MHLO operations
remain before translating all Stablehlo operations to the MHLO dialect
for further lowering to the Linalg dialect.

This patch also updates all lit tests so that they refer to the
`convert-torch-to-stablehlo` pass and so that they check for StableHLO
operations.
pull/1851/head
Ashay Rane 2023-02-02 07:29:47 -06:00 committed by GitHub
parent ed9d8d1fb7
commit 711646d095
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
55 changed files with 1190 additions and 1136 deletions

View File

@ -113,7 +113,7 @@ jobs:
-DLLVM_USE_HOST_TOOLS=ON \
-DLLVM_ENABLE_ZSTD=OFF \
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
-DTORCH_MLIR_ENABLE_MHLO=OFF \
-DTORCH_MLIR_ENABLE_STABLEHLO=OFF \
-DTORCH_MLIR_ENABLE_LTC=OFF \
-DTORCH_MLIR_USE_INSTALLED_PYTORCH="${{ matrix.torch-binary }}" \
-DMACOSX_DEPLOYMENT_TARGET=12.0 \

View File

@ -36,9 +36,9 @@ macro(torch_mlir_add_llvm_external_project name identifier location)
set(LLVM_EXTERNAL_PROJECTS ${LLVM_EXTERNAL_PROJECTS} CACHE STRING "" FORCE)
endmacro()
option(TORCH_MLIR_ENABLE_MHLO "Add mhlo dialect" ON)
if(TORCH_MLIR_ENABLE_MHLO)
add_definitions(-DTORCH_MLIR_ENABLE_MHLO)
option(TORCH_MLIR_ENABLE_STABLEHLO "Add stablehlo dialect" ON)
if(TORCH_MLIR_ENABLE_STABLEHLO)
add_definitions(-DTORCH_MLIR_ENABLE_STABLEHLO)
endif()
option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF)
@ -128,8 +128,8 @@ else()
set(MLIR_INCLUDE_DIRS "${MLIR_INCLUDE_DIR};${MLIR_GENERATED_INCLUDE_DIR}")
endif()
if (TORCH_MLIR_ENABLE_MHLO)
set(MHLO_BUILD_EMBEDDED ON)
if (TORCH_MLIR_ENABLE_STABLEHLO)
set(STABLEHLO_BUILD_EMBEDDED ON)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo
${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo
EXCLUDE_FROM_ALL)

View File

@ -267,8 +267,8 @@ function test_in_tree() {
echo ":::: Run Linalg e2e integration tests"
python -m e2e_testing.main --config=linalg -v
echo ":::: Run MHLO e2e integration tests"
python -m e2e_testing.main --config=mhlo -v
echo ":::: Run StableHLO e2e integration tests"
python -m e2e_testing.main --config=stablehlo -v
echo ":::: Run TOSA e2e integration tests"
python -m e2e_testing.main --config=tosa -v

View File

@ -30,14 +30,14 @@ it to various target dialects of interest to the MLIR ecosystem (various
- Linalg-on-Tensors (+ `arith`, `tensor`, etc.)
- [TOSA](https://mlir.llvm.org/docs/Dialects/TOSA/)
- [MHLO](https://github.com/tensorflow/mlir-hlo)
- [StableHLO](https://github.com/openxla/stablehlo)
The terms "frontend" and "backend" are highly overloaded in any compiler
project, but frequently in Torch-MLIR this is the meaning that they have.
Sometimes "frontend" can mean something even further up the stack, such as
something in PyTorch itself. When there is ambiguity we will refer to this as
"at the PyTorch level". Similarly, "backend" can sometimes refer to something
sitting below Linalg-on-Tensors, TOSA, or MHLO.
sitting below Linalg-on-Tensors, TOSA, or StableHLO.
## The `torch` dialect
@ -118,8 +118,8 @@ See [satisfiesBackendContract](https://github.com/llvm/torch-mlir/blob/114f48e96
The backend contract is a normalized form of the `torch` dialect with a set of
properties that make it easy to lower into various forms such as
Linalg-on-Tensors, TOSA, MHLO, or other forms that we don't provide out of the
box. The primary guarantees that we provide Torch-MLIR's backends are:
Linalg-on-Tensors, TOSA, StableHLO, or other forms that we don't provide out of
the box. The primary guarantees that we provide Torch-MLIR's backends are:
- All tensors have been converted to value semantics.
- All tensors have at least a known number of dimensions (i.e. rank), and
@ -270,7 +270,7 @@ lower it to the requirements of each backend. The 3 backends are:
- [`linalg`](https://mlir.llvm.org/docs/Dialects/Linalg/) on tensors (+ `arith`,
`tensor`, etc.)
- [TOSA](https://mlir.llvm.org/docs/Dialects/TOSA/)
- [MHLO](https://github.com/tensorflow/mlir-hlo)
- [StableHLO](https://github.com/openxla/stablehlo)
### The Linalg Backend (Linalg-on-Tensors)
@ -297,15 +297,15 @@ many users (especially "hardware" or "hardware-adjacent" folks). Some of its cha
- It is extremely solid with static shapes (and many of its users only care
about static shapes, so that's fine).
### The MHLO Backend
### The StableHLO Backend
Code: https://github.com/llvm/torch-mlir/tree/main/lib/Conversion/TorchToMhlo
Code: https://github.com/llvm/torch-mlir/tree/main/lib/Conversion/TorchToStablehlo
The MHLO backend was the third backend that we added, and it offers a reasonable
blend of the benefits of the other two.
The StableHLO backend was the third backend that we added, and it offers a
reasonable blend of the benefits of the other two.
- It is a coarse-grained named-op approach.
- It has a pretty clear spec for most of the ops (with a bit of mental
translation and hoping that MHLO is the same as HLO):
translation and hoping that StableHLO is the same as HLO):
https://www.tensorflow.org/xla/operation_semantics
- It functionally supports dynamic shapes (though not as coherent and consistent
as Linalg-on-Tensors, and the dynamic shape support falls outside the
@ -317,7 +317,7 @@ blend of the benefits of the other two.
example, TOSA limits (for highly considered reasons) the number of dimensions
that certain operators can handle to 1D-4D, when from a purely algebraic
perspective there isn't a good reason to not be more general. Similarly, more
general forms of reduction and scatter also fall into MHLO nicely while
general forms of reduction and scatter also fall into StableHLO nicely while
TOSA's principles tend to bias it away from that.
### Backend Implementation
@ -433,8 +433,9 @@ filling in some corners missing upstream and
to pull together upstream functionality into a working system.
The RefBackend accepts Linalg-on-Tensors as input. It mainly just bufferizes the
ops and lowers them to loops. Note that TOSA and MHLO support lowering to
Linalg-on-Tensors, so all our end-to-end testing bottoms out on RefBackend.
ops and lowers them to loops. Note that TOSA and StableHLO (via MHLO) support
lowering to Linalg-on-Tensors, so all our end-to-end testing bottoms out on
RefBackend.
The RefBackend is absolutely not suitable for any production use case. It leaks
memory, doesn't support any error handling, performs no optimizations, and

View File

@ -34,7 +34,7 @@ and Clang's
- Eric Kunze (@eric-k256)
- Suraj Sudhir (@sjarus)
### TorchToMHLO
### TorchToStablehlo
- Tianyo Kwok (@tanyokwok)
- Ziheng Jiang (@ZihengJiang)

View File

@ -139,7 +139,7 @@ Ex:
module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
```
Currently, `output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `MHLO`.
Currently, `output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `STABLEHLO`.
## Jupyter

View File

@ -46,7 +46,7 @@ the ecosystem are:
- The frontend work required to lower TorchScript to the backend contract.
- The irregular support surface area of the large number of PyTorch ops across
the Linalg, TOSA, and MHLO backends.
the Linalg, TOSA, and StableHLO backends.
Most of this document describes long-term ecosystem changes that will address
these, drastically improving Torch-MLIR's ability to meet its goals.
@ -108,7 +108,7 @@ more advanced).
### Refactoring the backend
Today in Torch-MLIR, we support 3 backends out of the box: Linalg-on-Tensors,
TOSA, and MHLO. These backends take IR in the backend contract form (see
TOSA, and StableHLO. These backends take IR in the backend contract form (see
[architecture.md](architecture.md)) and lowers them to the respective dialects.
Today, each backend is implemented completely independently. This leads to
duplication and irregularity across the backends.
@ -120,12 +120,10 @@ lowering of so many ops across backends. Additionally, there are 3
forward-looking efforts that intersect with this effort:
- [StableHLO](https://github.com/openxla/stablehlo) - this is a dialect
initially forked from MHLO which intends to create a stable support surface
area for what today is our "at head" dependency on MHLO. MHLO is a fairly
complete op set, so it is very attractive to have "almost all" models
bottleneck through a stable interface like StableHLO. StableHLO is currently
under relatively early development, but already delivers on many of the goals
of stability.
initially forked from MHLO. MHLO is a fairly complete op set, so it is very
attractive to have "almost all" models bottleneck through a stable interface
like StableHLO. StableHLO is currently under relatively early development,
but already delivers on many of the goals of stability.
- [TCP](https://github.com/llvm/torch-mlir/issues/1366) - this is a dialect
which could serve a role very similar to MHLO, while providing community
ownership. TCP is still in early planning phases, but there is strong

View File

@ -16,7 +16,7 @@ from torch_mlir_e2e_test.registry import GLOBAL_TEST_REGISTRY
from torch_mlir_e2e_test.configs import (
LazyTensorCoreTestConfig,
LinalgOnTensorsBackendTestConfig,
MhloBackendTestConfig,
StablehloBackendTestConfig,
NativeTorchTestConfig,
TorchScriptTestConfig,
TosaBackendTestConfig,
@ -24,17 +24,17 @@ from torch_mlir_e2e_test.configs import (
)
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
from torch_mlir_e2e_test.mhlo_backends.linalg_on_tensors import LinalgOnTensorsMhloBackend
from torch_mlir_e2e_test.stablehlo_backends.linalg_on_tensors import LinalgOnTensorsStablehloBackend
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend
from .xfail_sets import LINALG_XFAIL_SET, MHLO_PASS_SET, TOSA_PASS_SET, LTC_XFAIL_SET, TORCHDYNAMO_XFAIL_SET
from .xfail_sets import LINALG_XFAIL_SET, STABLEHLO_PASS_SET, TOSA_PASS_SET, LTC_XFAIL_SET, TORCHDYNAMO_XFAIL_SET
# Import tests to register them in the global registry.
from torch_mlir_e2e_test.test_suite import register_all_tests
register_all_tests()
def _get_argparse():
config_choices = ["native_torch", "torchscript", "linalg", "mhlo", "tosa", "lazy_tensor_core", "torchdynamo"]
config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "tosa", "lazy_tensor_core", "torchdynamo"]
parser = argparse.ArgumentParser(description="Run torchscript e2e tests.")
parser.add_argument("-c", "--config",
choices=config_choices,
@ -42,7 +42,7 @@ def _get_argparse():
help=f"""
Meaning of options:
"linalg": run through torch-mlir"s default Linalg-on-Tensors backend.
"mhlo": run through torch-mlir"s default MHLO backend.
"stablehlo": run through torch-mlir"s default StableHLO backend.
"tosa": run through torch-mlir"s default TOSA backend.
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
@ -80,9 +80,9 @@ def main():
if args.config == "tosa":
config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend())
xfail_set = all_test_unique_names - TOSA_PASS_SET
if args.config == "mhlo":
config = MhloBackendTestConfig(LinalgOnTensorsMhloBackend())
xfail_set = all_test_unique_names - MHLO_PASS_SET
if args.config == "stablehlo":
config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend())
xfail_set = all_test_unique_names - STABLEHLO_PASS_SET
elif args.config == "native_torch":
config = NativeTorchTestConfig()
xfail_set = {}

View File

@ -87,8 +87,10 @@ TORCHDYNAMO_XFAIL_SET = {
"StdCorrectionKeepDimModule_basic",
}
MHLO_PASS_SET = {
STABLEHLO_PASS_SET = {
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AddSizeIntModule_basic",
"AddSizeIntNegDimModule_basic",
"ArangeDtypeFloatModule_basic",
"ArangeDtypeIntModule_basic",
"ArangeFalsePinMemoryModule_basic",
@ -103,6 +105,7 @@ MHLO_PASS_SET = {
"ArangeStartStepFloatModule_basic",
"ArangeStartStepIntModule_basic",
"ArangeZeroElementOutputModule_basic",
"BatchMlpLayerModule_basic",
"BmmModule_basic",
"BroadcastToModule_basic",
"BroadcastToSameRankStaticModule_basic",
@ -124,12 +127,15 @@ MHLO_PASS_SET = {
"ElementwiseClampMinModule_basic",
"ElementwiseClampMaxModule_basic",
"ElementwiseExpModule_basic",
"ElementwiseFlattenBroadcastModule_basic",
"ElementwiseLeakyReluModule_basic",
"ElementwiseLogModule_basic",
"ElementwiseNegModule_basic",
"ElementwiseRsqrtModule_basic",
"ElementwiseSigmoidModule_basic",
"ElementwiseSqrtModule_basic",
"ElementwiseUnaryModule_basic",
"ElementwiseUnsqueezeBroadcastModule_basic",
"ElementwiseUnsqueezeNegDimsModule_basic",
"ElementwiseToDtypeF32ToI64Module_basic",
"ElementwiseAddModule_basic",
@ -198,6 +204,8 @@ MHLO_PASS_SET = {
"Gather2DInputModdule_basic",
"GatherRandomIndexModule_basic",
"GeluBackwardModule_basic",
"HardswishModule_basic",
"HardswishRandomModule_basic",
"HardTanhIntModule_basic",
"HardTanhModule_basic",
"HardsigmoidModule_basic",
@ -220,6 +228,8 @@ MHLO_PASS_SET = {
"MeanDynamicSizesModule_basic",
"MeanLargeInputModule_basic",
"MeanModule_basic",
"Mlp1LayerModule_basic",
"Mlp2LayerModule_basic",
"MmTanhModule_basic",
"Mv_basic",
"NativeLayerNormModule4D_basic",
@ -251,6 +261,8 @@ MHLO_PASS_SET = {
"LiftFreshCopyModule_basic",
"Mlp2LayerModuleNoBias_basic",
"NumelModule_basic",
"SiluModule_basic",
"SquareModule_basic",
"SqueezeModule_allUnitDim",
"SqueezeDimModule_unitDim",
"ViewCollapseOnesMiddleModule_basic",
@ -420,6 +432,7 @@ MHLO_PASS_SET = {
"UnsafeViewDynamicExpandModule_basic",
"AtenRoundIntModule_basic",
"TestF16Return_basic",
"_LogSoftmaxModuleStable_basic",
}
# Write the TOSA set as a "passing" set as it is very early in development

View File

@ -1,14 +0,0 @@
import torch
import torchvision.models as models
import torch_mlir
model = models.resnet18(pretrained=True)
model.eval()
data = torch.randn(2,3,200,200)
out_mhlo_mlir_path = "./resnet18_mhlo.mlir"
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=False)
with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf:
outf.write(str(module))
print(f"MHLO IR of resent18 successfully written into {out_mhlo_mlir_path}")

View File

@ -0,0 +1,14 @@
import torch
import torchvision.models as models
import torch_mlir
model = models.resnet18(pretrained=True)
model.eval()
data = torch.randn(2,3,200,200)
out_stablehlo_mlir_path = "./resnet18_stablehlo.mlir"
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=False)
with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
outf.write(str(module))
print(f"StableHLO IR of resent18 successfully written into {out_stablehlo_mlir_path}")

View File

@ -15,10 +15,10 @@ class BertTinyWrapper(torch.nn.Module):
model = BertTinyWrapper()
model.eval()
data = torch.randint(30522, (2, 128))
out_mhlo_mlir_path = "./bert_tiny_mhlo.mlir"
out_stablehlo_mlir_path = "./bert_tiny_stablehlo.mlir"
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=True)
with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf:
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=True)
with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
outf.write(str(module))
print(f"MHLO IR of tiny bert successfully written into {out_mhlo_mlir_path}")
print(f"StableHLO IR of tiny bert successfully written into {out_stablehlo_mlir_path}")

View File

@ -1,6 +1,6 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
if(TORCH_MLIR_ENABLE_MHLO)
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_MHLO)
if(TORCH_MLIR_ENABLE_STABLEHLO)
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO)
else()
mlir_tablegen(Passes.h.inc -gen-pass-decls)
endif()

View File

@ -133,13 +133,13 @@ def ConvertTorchConversionToMLProgram : Pass<"convert-torch-conversion-to-mlprog
let constructor = "mlir::torch::createConvertTorchConversionToMLProgramPass()";
}
#ifdef TORCH_MLIR_ENABLE_MHLO
def ConvertTorchToMhlo : Pass<"convert-torch-to-mhlo", "func::FuncOp"> {
let summary = "Convert Torch ops to MHLO ops";
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
def ConvertTorchToStablehlo : Pass<"convert-torch-to-stablehlo", "func::FuncOp"> {
let summary = "Convert Torch ops to Stablehlo ops";
let description = [{
Convert Torch ops to mhlo ops.
Convert Torch ops to Stablehlo ops.
}];
let constructor = "mlir::torch::createConvertTorchToMhloPass()";
let constructor = "mlir::torch::createConvertTorchToStablehloPass()";
// Specify any options.
let options = [

View File

@ -7,8 +7,8 @@
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H
#define TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H
#ifndef TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H
#define TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
@ -16,10 +16,11 @@
namespace mlir {
namespace torch {
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToMhloPass();
std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchToMhloPass(bool enableStaticShape, bool enableI32Index);
createConvertTorchToStablehloPass();
std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchToStablehloPass(bool enableStaticShape, bool enableI32Index);
} // namespace torch
} // namespace mlir
#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H
#endif // TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H

View File

@ -1,6 +1,6 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
if(TORCH_MLIR_ENABLE_MHLO)
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_MHLO)
if(TORCH_MLIR_ENABLE_STABLEHLO)
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO)
else()
mlir_tablegen(Passes.h.inc -gen-pass-decls)
endif()

View File

@ -30,10 +30,10 @@ void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm);
/// TOSA backend contract.
void createTorchBackendToTosaBackendPipeline(OpPassManager &pm);
// Do not register the torch-to-mhlo pipeline if mhlo target is disabled
#ifdef TORCH_MLIR_ENABLE_MHLO
struct MhloBackendPipelineOptions
: public PassPipelineOptions<MhloBackendPipelineOptions> {
// Do not register the stablehlo options if the stablehlo target is disabled
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
struct StablehloBackendPipelineOptions
: public PassPipelineOptions<StablehloBackendPipelineOptions> {
Option<bool> enableStaticShape{
*this, "enable-static-shape",
llvm::cl::desc("Enable static shape conversion."), llvm::cl::init(false)};
@ -46,9 +46,10 @@ struct MhloBackendPipelineOptions
llvm::cl::init(false)};
};
void createTorchBackendToMhloBackendPipeline(
OpPassManager &pm, const MhloBackendPipelineOptions &options);
std::unique_ptr<OperationPass<ModuleOp>> createVerifyMhloBackendContractPass();
void createTorchBackendToStablehloBackendPipeline(
OpPassManager &pm, const StablehloBackendPipelineOptions &options);
std::unique_ptr<OperationPass<ModuleOp>>
createVerifyStablehloBackendContractPass();
#endif
std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass();

View File

@ -42,10 +42,10 @@ def VerifyTosaBackendContract : Pass<"torch-verify-tosa-backend-contract", "Modu
let constructor = "mlir::torch::TorchConversion::createVerifyTosaBackendContractPass()";
}
#ifdef TORCH_MLIR_ENABLE_MHLO
def VerifyMhloBackendContract : Pass<"torch-verify-mhlo-backend-contract", "ModuleOp"> {
let summary = "Verifies conformity to the mhlo backend contract";
let constructor = "mlir::torch::TorchConversion::createVerifyMhloBackendContractPass()";
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contract", "ModuleOp"> {
let summary = "Verifies conformity to the stablehlo backend contract";
let constructor = "mlir::torch::TorchConversion::createVerifyStablehloBackendContractPass()";
}
#endif // TORCH_MLIR_ENABLE_MHLO
#endif // TORCH_MLIR_ENABLE_STABLEHLO
#endif // TORCHMLIR_TORCHCONVERSION_PASSES

View File

@ -3,13 +3,7 @@ add_subdirectory(Conversion)
add_subdirectory(Dialect)
add_subdirectory(RefBackend)
add_mlir_library(TorchMLIRInitAll
InitAll.cpp
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
set(LinkedLibs
MLIRFuncDialect
MLIRIR
MLIRSupport
@ -27,4 +21,22 @@ add_mlir_library(TorchMLIRInitAll
TorchMLIRRefBackend
)
if(TORCH_MLIR_ENABLE_STABLEHLO)
list(APPEND LinkedLibs
MhloPasses
MhloToLinalg
StablehloToMhlo
)
endif()
add_mlir_library(TorchMLIRInitAll
InitAll.cpp
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
${LinkedLibs}
)
torch_mlir_target_includes(TorchMLIRInitAll)

View File

@ -2,8 +2,8 @@ add_subdirectory(TorchToLinalg)
add_subdirectory(TorchToSCF)
add_subdirectory(TorchToArith)
add_subdirectory(TorchToTosa)
if(TORCH_MLIR_ENABLE_MHLO)
add_subdirectory(TorchToMhlo)
if(TORCH_MLIR_ENABLE_STABLEHLO)
add_subdirectory(TorchToStablehlo)
endif()
add_subdirectory(TorchToTMTensor)
add_subdirectory(TorchConversionToMLProgram)
@ -17,10 +17,8 @@ set(linked_libs TorchMLIRTorchToLinalg
TorchMLIRTorchToTMTensor
TorchMLIRTorchConversionToMLProgram
TorchMLIRConversionUtils)
if(TORCH_MLIR_ENABLE_MHLO)
list(APPEND linked_libs
MhloPasses
TorchMLIRTorchToMhlo)
if(TORCH_MLIR_ENABLE_STABLEHLO)
list(APPEND linked_libs TorchMLIRTorchToStablehlo)
endif()
add_mlir_library(TorchMLIRConversionPasses

View File

@ -9,15 +9,15 @@
#include "torch-mlir/Conversion/Passes.h"
#ifdef TORCH_MLIR_ENABLE_MHLO
#include "mhlo/transforms/passes.h"
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "transforms/passes.h"
#endif // TORCH_MLIR_ENABLE_MHLO
#endif // TORCH_MLIR_ENABLE_STABLEHLO
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
@ -32,12 +32,4 @@ namespace {
void mlir::torch::registerConversionPasses() {
::registerPasses();
#ifdef TORCH_MLIR_ENABLE_MHLO
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
return mlir::mhlo::createLegalizeHloToLinalgPass();
});
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
return mlir::mhlo::createSymbolicShapeOptimizationPass();
});
#endif // TORCH_MLIR_ENABLE_MHLO
}

View File

@ -1,35 +0,0 @@
add_mlir_conversion_library(TorchMLIRTorchToMhlo
TorchToMhlo.cpp
MhloLegalizeUtils.cpp
Basic.cpp
Gather.cpp
Linear.cpp
ViewLike.cpp
Reduction.cpp
Pooling.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToMhlo
DEPENDS
MhloDialect
MhloToLinalg
MLIRMhloPassIncGen
LMHLOTransformsPassIncGen
TorchMLIRConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MhloDialect
MhloToLinalg
MLIRBufferTransforms
StablehloOps
TorchMLIRTorchDialect
TorchMLIRConversionUtils
)
torch_mlir_target_includes(TorchMLIRTorchToMhlo)

View File

@ -1,74 +0,0 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H
#define TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace torch {
namespace torch_to_mhlo {
struct TorchToMhloOptions {
bool enableStaticShape = false;
size_t dimSizeIndexBits = 64;
};
template <typename AtenOpT>
class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
public:
using OpAdaptor = typename AtenOpT::Adaptor;
ConvertAtenOp(TypeConverter &typeConverter, MLIRContext *context,
const TorchToMhloOptions &options)
: OpConversionPattern<AtenOpT>(typeConverter, context) {
this->options = options;
}
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
return rewriter.notifyMatchFailure(op, "haven't been implemented");
}
const TorchToMhloOptions &getOptions() const { return options; }
private:
TorchToMhloOptions options;
};
void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target,
const TorchToMhloOptions &options);
void populateViewLikeOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target,
const TorchToMhloOptions &options);
void populateGatherOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target,
const TorchToMhloOptions &options);
void populateReductionOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target,
const TorchToMhloOptions &options);
void populateLinearOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target,
const TorchToMhloOptions &options);
void populatePoolingOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target,
const TorchToMhloOptions &options);
} // namespace torch_to_mhlo
} // namespace torch
} // namespace mlir
#endif // TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H

View File

@ -7,15 +7,16 @@
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "../PassDetail.h"
#include "./MhloLegalizeUtils.h"
#include "./PopulatePatterns.h"
#include "mhlo/IR/hlo_ops.h"
#include "PopulatePatterns.h"
#include "StablehloLegalizeUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@ -29,7 +30,7 @@
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_mhlo;
using namespace mlir::torch::torch_to_stablehlo;
LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
mlir::Value &self, mlir::Value &other,
@ -43,16 +44,16 @@ LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
if (selfRank > otherRank) {
auto unsqueezeDims =
llvm::to_vector<4>(llvm::seq<int64_t>(0, selfRank - otherRank));
auto unsqueezeInfo = mhlo::unsqueezeTensor(rewriter, op, other,
unsqueezeDims, dimSizeIndexBits);
auto unsqueezeInfo = hlo::unsqueezeTensor(rewriter, op, other,
unsqueezeDims, dimSizeIndexBits);
if (failed(unsqueezeInfo))
return failure();
other = *unsqueezeInfo;
} else if (otherRank > selfRank) {
auto unsqueezeDims =
llvm::to_vector<4>(llvm::seq<int64_t>(0, otherRank - selfRank));
auto unsqueezeInfo = mhlo::unsqueezeTensor(rewriter, op, self,
unsqueezeDims, dimSizeIndexBits);
auto unsqueezeInfo = hlo::unsqueezeTensor(rewriter, op, self, unsqueezeDims,
dimSizeIndexBits);
if (failed(unsqueezeInfo))
return failure();
self = *unsqueezeInfo;
@ -78,7 +79,8 @@ static FailureOr<Value> getMaxValueOfDtype(Operation *op, Type elementType,
constType,
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
/*negative=*/false));
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
return rewriter
.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
.getResult();
}
if (elementType.isa<mlir::IntegerType>()) {
@ -91,7 +93,8 @@ static FailureOr<Value> getMaxValueOfDtype(Operation *op, Type elementType,
constAttr = SplatElementsAttr::get(
constType, APInt::getSignedMaxValue(integerType.getWidth()));
}
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
return rewriter
.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
.getResult();
}
return failure();
@ -105,7 +108,8 @@ static FailureOr<Value> getMinValueOfDtype(Operation *op, Type elementType,
constType,
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
/*negative=*/true));
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
return rewriter
.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
.getResult();
}
if (elementType.isa<mlir::IntegerType>()) {
@ -118,7 +122,8 @@ static FailureOr<Value> getMinValueOfDtype(Operation *op, Type elementType,
constAttr = SplatElementsAttr::get(
constType, APInt::getSignedMinValue(integerType.getWidth()));
}
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
return rewriter
.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
.getResult();
}
return failure();
@ -126,7 +131,7 @@ static FailureOr<Value> getMinValueOfDtype(Operation *op, Type elementType,
// These legalizations are for unary ops.
namespace {
template <typename AtenOpT, typename MhloOpT>
template <typename AtenOpT, typename StablehloOpT>
class ConvertAtenUnaryOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
@ -137,13 +142,13 @@ public:
Value self = adaptor.getSelf();
auto selfType = self.getType().cast<TensorType>();
if (!selfType) {
return op.emitError("only Tensor types supported in MHLO");
return op.emitError("only Tensor types supported in StableHLO");
}
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<TensorType>();
self = mhlo::promoteType(rewriter, self, outType);
rewriter.replaceOpWithNewOp<MhloOpT>(op, outType, self);
self = hlo::promoteType(rewriter, self, outType);
rewriter.replaceOpWithNewOp<StablehloOpT>(op, outType, self);
return success();
}
};
@ -152,7 +157,7 @@ public:
// These legalizations are for unary ops with only for floating point datatypes.
// There is no supported quantized integer mode for these.
namespace {
template <typename AtenOpT, typename MhloOpT>
template <typename AtenOpT, typename StablehloOpT>
class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
@ -164,10 +169,10 @@ public:
auto selfTy = self.getType().cast<TensorType>();
if (!selfTy)
return op.emitError("only Tensor types supported in MHLO");
return op.emitError("only Tensor types supported in StableHLO");
if (selfTy.getElementType().isa<mlir::FloatType>()) {
rewriter.replaceOpWithNewOp<MhloOpT>(
rewriter.replaceOpWithNewOp<StablehloOpT>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
@ -198,7 +203,7 @@ public:
.template dyn_cast<TensorType>();
if (!outType)
return op.emitError("only Tensor types supported in MHLO");
return op.emitError("only Tensor types supported in StableHLO");
Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat())
@ -216,9 +221,9 @@ public:
SmallVector<int32_t> values(size, fillVal);
auto constOp =
mhlo::getConstTensor<int32_t>(rewriter, op, values, shape).value();
hlo::getConstTensor<int32_t>(rewriter, op, values, shape).value();
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(op, outType, constOp);
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(op, outType, constOp);
return success();
}
};
@ -247,8 +252,8 @@ public:
->convertType(op.getType())
.template cast<TensorType>();
lhs = mhlo::promoteType(rewriter, lhs, outTy);
rhs = mhlo::promoteType(rewriter, rhs, outTy);
lhs = hlo::promoteType(rewriter, lhs, outTy);
rhs = hlo::promoteType(rewriter, rhs, outTy);
rewriter.replaceOpWithNewOp<ChloOpT>(op, outTy, lhs, rhs,
/*broadcast_attr*/ nullptr);
@ -274,7 +279,7 @@ public:
RankedTensorType rhsType = rhs.getType().dyn_cast<RankedTensorType>();
if (!lhsType)
return op.emitError("only Tensor types supported in MHLO");
return op.emitError("only Tensor types supported in StableHLO");
TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
@ -287,18 +292,19 @@ public:
}
if (!rhsType) {
rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getOther(), outElemTy);
rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
outElemTy);
if (isa<AtenRsubScalarOp>(op)) {
std::swap(lhs, rhs);
}
}
lhs = mhlo::promoteType(rewriter, lhs, outType);
rhs = mhlo::promoteType(rewriter, rhs, outType);
lhs = hlo::promoteType(rewriter, lhs, outType);
rhs = hlo::promoteType(rewriter, rhs, outType);
if (!skipMultiplyAlpha(op.getAlpha())) {
Value alpha =
mhlo::scalarToMhloTensor(rewriter, op, adaptor.getAlpha(), outElemTy);
Value alpha = hlo::scalarToStablehloTensor(rewriter, op,
adaptor.getAlpha(), outElemTy);
DenseIntElementsAttr bcastDimensions;
rhs = rewriter.create<chlo::BroadcastMulOp>(op->getLoc(), rhs, alpha,
bcastDimensions);
@ -328,7 +334,7 @@ public:
TensorType rhsType = rhs.getType().dyn_cast<TensorType>();
if (!lhsType)
return op.emitError("only Tensor types supported in MHLO");
return op.emitError("only Tensor types supported in StableHLO");
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
@ -343,11 +349,12 @@ public:
if (std::is_same<AtenOpT, AtenSquareOp>()) {
rhs = lhs;
} else if (!rhsType) {
rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getOther(), outElemTy);
rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
outElemTy);
}
DenseIntElementsAttr bcastDimensions;
lhs = mhlo::promoteType(rewriter, lhs, outType);
rhs = mhlo::promoteType(rewriter, rhs, outType);
lhs = hlo::promoteType(rewriter, lhs, outType);
rhs = hlo::promoteType(rewriter, rhs, outType);
auto loc = op.getLoc();
Value result =
rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);
@ -368,15 +375,15 @@ public:
if (roundingMode == "trunc") {
// "trunc" - rounds the results of the division towards zero. Equivalent
// to C-style integer division.
auto sign = rewriter.create<mhlo::SignOp>(loc, result);
auto abs = rewriter.create<mhlo::AbsOp>(loc, result);
auto floor = rewriter.create<mhlo::FloorOp>(loc, abs);
result = rewriter.create<mhlo::MulOp>(loc, sign, floor).getResult();
auto sign = rewriter.create<stablehlo::SignOp>(loc, result);
auto abs = rewriter.create<stablehlo::AbsOp>(loc, result);
auto floor = rewriter.create<stablehlo::FloorOp>(loc, abs);
result = rewriter.create<stablehlo::MulOp>(loc, sign, floor).getResult();
}
if (roundingMode == "floor") {
// "floor" - rounds the results of the division down. Equivalent to
// floor division in Python (the // operator)
result = rewriter.create<mhlo::FloorOp>(loc, result).getResult();
result = rewriter.create<stablehlo::FloorOp>(loc, result).getResult();
}
rewriter.replaceOp(op, result);
return success();
@ -401,7 +408,7 @@ public:
RankedTensorType rhsTy = rhs.getType().dyn_cast<RankedTensorType>();
if (!lhsTy)
return op.emitError("only Tensor types supported in MHLO");
return op.emitError("only Tensor types supported in StableHLO");
RankedTensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
@ -414,11 +421,12 @@ public:
}
if (!rhsTy) {
rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getOther(), lhsElemTy);
rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
lhsElemTy);
}
// TODO: what is the PyTorch default type promotion?
rhs = mhlo::promoteType(rewriter, rhs, lhsTy);
rhs = hlo::promoteType(rewriter, rhs, lhsTy);
chlo::ComparisonTypeAttr compareTypeAttr;
chlo::ComparisonDirectionAttr compareDirectionAttr;
@ -485,8 +493,8 @@ public:
TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<TensorType>();
Value lhs = mhlo::promoteType(rewriter, adaptor.getSelf(), outType);
Value rhs = mhlo::promoteType(rewriter, adaptor.getOther(), outType);
Value lhs = hlo::promoteType(rewriter, adaptor.getSelf(), outType);
Value rhs = hlo::promoteType(rewriter, adaptor.getOther(), outType);
DenseIntElementsAttr bcastDimensions;
rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
@ -537,8 +545,8 @@ public:
RankedTensorType::get({static_cast<long int>(permValues.size())},
rewriter.getI64Type()),
permValues);
rewriter.replaceOpWithNewOp<mhlo::TransposeOp>(op, outType, self,
permutation);
rewriter.replaceOpWithNewOp<stablehlo::TransposeOp>(op, outType, self,
permutation);
return success();
}
};
@ -552,7 +560,7 @@ LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
Value self = adaptor.getSelf();
auto outType =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(op, outType, self);
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(op, outType, self);
return success();
}
@ -573,7 +581,8 @@ LogicalResult ConvertAtenOp<AtenSizeIntOp>::matchAndRewrite(
} else {
Value inputRank = rewriter.create<arith::ConstantOp>(
op.getLoc(), rewriter.getI64IntegerAttr(selfType.getRank()));
dim = toPositiveDimDynamic(rewriter, op.getLoc(), adaptor.getDim(), inputRank);
dim = toPositiveDimDynamic(rewriter, op.getLoc(), adaptor.getDim(),
inputRank);
dim = rewriter.create<arith::IndexCastOp>(op.getLoc(),
rewriter.getIndexType(), dim);
}
@ -589,9 +598,8 @@ LogicalResult ConvertAtenOp<AtenSizeIntOp>::matchAndRewrite(
template <>
LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
AtenWhereSelfOp op,
OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const {
AtenWhereSelfOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value self = adaptor.getSelf();
Value cond = adaptor.getCondition();
Value other = adaptor.getOther();
@ -605,8 +613,7 @@ LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
return op.emitError("failed broadcast other and condition ranks");
rewriter.replaceOpWithNewOp<chlo::BroadcastSelectOp>(
op,
getTypeConverter()->convertType(op.getType()),
op, getTypeConverter()->convertType(op.getType()),
ArrayRef<Value>{cond, self, other});
return success();
}
@ -623,7 +630,7 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
.cast<RankedTensorType>();
if (options.enableStaticShape && selfTy.hasStaticShape()) {
Value bcastOp = mhlo::promoteAndBroadcast(rewriter, self, outType);
Value bcastOp = hlo::promoteAndBroadcast(rewriter, self, outType);
rewriter.replaceOp(op, bcastOp);
return success();
}
@ -670,7 +677,7 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
op->getLoc(), ValueRange{bcastShapeVec});
auto dimensionNumbers =
llvm::to_vector<4>(llvm::seq<int64_t>(leadingRank, totalRank));
rewriter.replaceOpWithNewOp<mhlo::DynamicBroadcastInDimOp>(
rewriter.replaceOpWithNewOp<stablehlo::DynamicBroadcastInDimOp>(
op, outType, self, bcastShapeTensor,
rewriter.getI64TensorAttr(dimensionNumbers));
}
@ -708,8 +715,8 @@ LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
RankedTensorType::get({static_cast<long int>(permValues.size())},
rewriter.getI64Type()),
permValues);
rewriter.replaceOpWithNewOp<mhlo::TransposeOp>(op, outType, self,
permutation);
rewriter.replaceOpWithNewOp<stablehlo::TransposeOp>(op, outType, self,
permutation);
return success();
}
@ -721,7 +728,7 @@ LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
Value self = adaptor.getSelf();
auto selfTy = self.getType().cast<TensorType>();
if (selfTy && selfTy.getElementType().isa<mlir::FloatType>()) {
rewriter.replaceOpWithNewOp<mhlo::TanhOp>(
rewriter.replaceOpWithNewOp<stablehlo::TanhOp>(
op, getTypeConverter()->convertType(op.getType()), self);
return success();
} else {
@ -751,16 +758,16 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
elements.mapValues(builtinTensorElemTy, [&](const APInt &v) {
return APInt(bitWidth, v.getSExtValue());
});
rewriter.replaceOpWithNewOp<mhlo::ConstantOp>(op, resultType, valueAttr);
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(op, resultType,
valueAttr);
return success();
}
rewriter.replaceOpWithNewOp<mhlo::ConstantOp>(op, resultType,
adaptor.getValue());
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(op, resultType,
adaptor.getValue());
return success();
}
// AtenReciprocalOp
// Reciprocal(x) = Div(1, x)
template <>
@ -777,7 +784,7 @@ LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(
}
Value oneTensor = chlo::getConstantLike(rewriter, op->getLoc(), 1, input);
rewriter.replaceOpWithNewOp<mhlo::DivOp>(op, outTy, oneTensor, input);
rewriter.replaceOpWithNewOp<stablehlo::DivOp>(op, outTy, oneTensor, input);
return success();
}
@ -790,9 +797,9 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
auto outputElemType = outputType.getElementType();
Value mhloTensor =
mhlo::scalarToMhloTensor(rewriter, op, adaptor.getA(), outputElemType);
rewriter.replaceOp(op, mhloTensor);
Value stablehloTensor = hlo::scalarToStablehloTensor(
rewriter, op, adaptor.getA(), outputElemType);
rewriter.replaceOp(op, stablehloTensor);
return success();
}
@ -815,7 +822,6 @@ LogicalResult ConvertAtenOp<AtenContiguousOp>::matchAndRewrite(
return success();
}
// AtenReluOp
// Relu(x) = Max(0, x)
template <>
@ -836,11 +842,10 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
APFloat::getZero(lhsElemTy.cast<mlir::FloatType>().getFloatSemantics(),
false),
lhs);
rewriter.replaceOpWithNewOp<mhlo::MaxOp>(op, lhs, zeroTensor);
rewriter.replaceOpWithNewOp<stablehlo::MaxOp>(op, lhs, zeroTensor);
return success();
}
// Convert a Aten::GELU to HLO
// Gelu(x) = x * 1/2 * [1 + erf(x/(sqrt(2)))]
template <>
@ -857,12 +862,12 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
Value one = chlo::getConstantLike(rewriter, loc, 1.0, input);
Value two = chlo::getConstantLike(rewriter, loc, 2.0, input);
Value half = chlo::getConstantLike(rewriter, loc, 0.5, input);
auto rsqrtTwo = rewriter.create<mlir::mhlo::RsqrtOp>(loc, two);
auto erfElement = rewriter.create<mhlo::MulOp>(loc, input, rsqrtTwo);
auto rsqrtTwo = rewriter.create<mlir::stablehlo::RsqrtOp>(loc, two);
auto erfElement = rewriter.create<stablehlo::MulOp>(loc, input, rsqrtTwo);
auto erf = rewriter.create<mlir::chlo::ErfOp>(loc, erfElement);
auto erfAdd = rewriter.create<mhlo::AddOp>(loc, erf, one);
auto halfMul = rewriter.create<mhlo::MulOp>(loc, erfAdd, half);
rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, input, halfMul);
auto erfAdd = rewriter.create<stablehlo::AddOp>(loc, erf, one);
auto halfMul = rewriter.create<stablehlo::MulOp>(loc, erfAdd, half);
rewriter.replaceOpWithNewOp<stablehlo::MulOp>(op, input, halfMul);
return success();
}
@ -881,7 +886,6 @@ LogicalResult ConvertAtenOp<AtenErfOp>::matchAndRewrite(
return success();
}
// AtenBatchNormOp
template <>
LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
@ -919,28 +923,28 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
Value channelShape = rewriter.create<tensor::FromElementsOp>(
op->getLoc(), ValueRange{channelDim});
if (failed(checkNotNone(rewriter, op, weight))) {
weight = mhlo::getConstantOfShape(
weight = hlo::getConstantOfShape(
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1),
channelShape,
RankedTensorType::get({inputTy.getShape()[1]},
inputTy.getElementType()));
}
if (failed(checkNotNone(rewriter, op, bias))) {
bias = mhlo::getConstantOfShape(
bias = hlo::getConstantOfShape(
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0),
channelShape,
RankedTensorType::get({inputTy.getShape()[1]},
inputTy.getElementType()));
}
if (failed(checkNotNone(rewriter, op, runningVar))) {
runningVar = mhlo::getConstantOfShape(
runningVar = hlo::getConstantOfShape(
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1),
channelShape,
RankedTensorType::get({inputTy.getShape()[1]},
inputTy.getElementType()));
}
if (failed(checkNotNone(rewriter, op, runningMean))) {
runningMean = mhlo::getConstantOfShape(
runningMean = hlo::getConstantOfShape(
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0),
channelShape,
RankedTensorType::get({inputTy.getShape()[1]},
@ -983,10 +987,11 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
Type outputTy = getTypeConverter()->convertType(op.getType());
Type batchMeanOrVarTy =
RankedTensorType::get(weightTy.getShape(), inputTy.getElementType());
auto batchNormTrainingResult = rewriter.create<mhlo::BatchNormTrainingOp>(
op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input,
weight, bias, rewriter.getF32FloatAttr(eps),
rewriter.getI64IntegerAttr(1));
auto batchNormTrainingResult =
rewriter.create<stablehlo::BatchNormTrainingOp>(
op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input,
weight, bias, rewriter.getF32FloatAttr(eps),
rewriter.getI64IntegerAttr(1));
rewriter.replaceOp(op, batchNormTrainingResult.getResult(0));
return success();
} else {
@ -995,10 +1000,11 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
inputTy.getShape().end()};
castShape[1] = weightTy.getShape()[0];
auto castTy = RankedTensorType::get(castShape, inputTy.getElementType());
// Feature counts must match among operands of mhlo::BatchNormInferenceOp.
// Feature counts must match among operands of
// stablehlo::BatchNormInferenceOp.
Value inputCasted =
rewriter.create<tensor::CastOp>(op.getLoc(), castTy, input);
Value output = rewriter.create<mhlo::BatchNormInferenceOp>(
Value output = rewriter.create<stablehlo::BatchNormInferenceOp>(
op.getLoc(), inputCasted.getType(), inputCasted, weight, bias,
runningMean, runningVar,
// 'epsilon' must satisfy constraint: 32-bit float attribute.
@ -1008,7 +1014,6 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
}
}
// AtenNativeLayerNormOp
template <>
LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
@ -1076,21 +1081,21 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
}
SmallVector<int64_t> inputFlattenShape{1, numFeatureDimSize,
numEmbeddingDimSize};
SmallVector<int64_t> meanOrVarMhloOutShape{numFeatureDimSize};
SmallVector<int64_t> meanOrVarStablehloOutShape{numFeatureDimSize};
auto mhloBatchNormOutTy =
auto stablehloBatchNormOutTy =
RankedTensorType::get(inputFlattenShape, inputTy.getElementType());
auto mhloBathNormOutMeanOrVarTy =
RankedTensorType::get(meanOrVarMhloOutShape, inputTy.getElementType());
auto stablehloBathNormOutMeanOrVarTy = RankedTensorType::get(
meanOrVarStablehloOutShape, inputTy.getElementType());
// Reshape input
auto mhloInput = rewriter.create<mhlo::DynamicReshapeOp>(
op->getLoc(), mhloBatchNormOutTy, input,
mhlo::getConstTensor(rewriter, op, llvm::ArrayRef(inputFlattenShape),
{static_cast<int64_t>(inputFlattenShape.size())})
auto stablehloInput = rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), stablehloBatchNormOutTy, input,
hlo::getConstTensor(rewriter, op, llvm::ArrayRef(inputFlattenShape),
{static_cast<int64_t>(inputFlattenShape.size())})
.value());
// Generate "scale" and "offset" Value for mhlo.BatchNormTrainingOp.
// Generate "scale" and "offset" Value for stablehlo.BatchNormTrainingOp.
SmallVector<APFloat> zeroConstVec(
numFeatureDimSize, APFloat::getZero(inputTy.getElementType()
.cast<mlir::FloatType>()
@ -1103,16 +1108,18 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
auto oneOrZeroConstType =
RankedTensorType::get({numFeatureDimSize}, inputTy.getElementType());
Value scale = rewriter.create<mhlo::ConstantOp>(
Value scale = rewriter.create<stablehlo::ConstantOp>(
op->getLoc(), oneOrZeroConstType,
DenseElementsAttr::get(oneOrZeroConstType, oneConstVec));
Value offset = rewriter.create<mhlo::ConstantOp>(
Value offset = rewriter.create<stablehlo::ConstantOp>(
op->getLoc(), oneOrZeroConstType,
DenseElementsAttr::get(oneOrZeroConstType, zeroConstVec));
auto batchNormTrainingResult = rewriter.create<mhlo::BatchNormTrainingOp>(
op->getLoc(), mhloBatchNormOutTy, mhloBathNormOutMeanOrVarTy,
mhloBathNormOutMeanOrVarTy, mhloInput, scale, offset,
rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(1));
auto batchNormTrainingResult =
rewriter.create<stablehlo::BatchNormTrainingOp>(
op->getLoc(), stablehloBatchNormOutTy,
stablehloBathNormOutMeanOrVarTy, stablehloBathNormOutMeanOrVarTy,
stablehloInput, scale, offset, rewriter.getF32FloatAttr(eps),
rewriter.getI64IntegerAttr(1));
// Reshape back
auto outputTy =
@ -1120,36 +1127,35 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
auto outputMeanOrVarTy =
getTypeConverter()->convertType(op.getType(1)).cast<RankedTensorType>();
auto output = rewriter.create<mhlo::DynamicReshapeOp>(
auto output = rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), outputTy, batchNormTrainingResult.getResult(0),
mhlo::getConstTensor(rewriter, op, outputTy.getShape(),
{static_cast<int64_t>(outputTy.getShape().size())})
hlo::getConstTensor(rewriter, op, outputTy.getShape(),
{static_cast<int64_t>(outputTy.getShape().size())})
.value());
auto mean = rewriter.create<mhlo::DynamicReshapeOp>(
auto mean = rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(1),
mhlo::getConstTensor(
hlo::getConstTensor(
rewriter, op, outputMeanOrVarTy.getShape(),
{static_cast<int64_t>(outputMeanOrVarTy.getShape().size())})
.value());
auto var = rewriter.create<mhlo::DynamicReshapeOp>(
auto var = rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(2),
mhlo::getConstTensor(
hlo::getConstTensor(
rewriter, op, outputMeanOrVarTy.getShape(),
{static_cast<int64_t>(outputMeanOrVarTy.getShape().size())})
.value());
// Apply affine transform: output x weight + bias [element-wise]
auto bcastedWeight = mhlo::promoteAndBroadcast(rewriter, weight, outputTy);
auto bcastedBias = mhlo::promoteAndBroadcast(rewriter, bias, outputTy);
auto bcastedWeight = hlo::promoteAndBroadcast(rewriter, weight, outputTy);
auto bcastedBias = hlo::promoteAndBroadcast(rewriter, bias, outputTy);
auto outputMulWeight =
rewriter.create<mhlo::MulOp>(op->getLoc(), output, bcastedWeight);
auto finalOuput =
rewriter.create<mhlo::AddOp>(op->getLoc(), outputMulWeight, bcastedBias);
rewriter.create<stablehlo::MulOp>(op->getLoc(), output, bcastedWeight);
auto finalOuput = rewriter.create<stablehlo::AddOp>(
op->getLoc(), outputMulWeight, bcastedBias);
rewriter.replaceOp(op, {finalOuput, mean, var});
return success();
}
// AtenCatOp
template <>
LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
@ -1173,11 +1179,11 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
// Promote type
for (auto &v : builtinTensors) {
v = mhlo::promoteType(rewriter, v, outType);
v = hlo::promoteType(rewriter, v, outType);
}
size_t posDim = toPositiveDim(dim, outType.getRank());
rewriter.replaceOpWithNewOp<mhlo::ConcatenateOp>(
rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(
op, outType, ValueRange(builtinTensors), posDim);
return success();
}
@ -1225,7 +1231,8 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(
op, "this op should be folded as its `min` and `max` both are none");
} else if (failed(checkNotNone(rewriter, op, minValue))) {
maxValue = mhlo::scalarToMhloTensor(rewriter, op, maxValue, inputElemType);
maxValue =
hlo::scalarToStablehloTensor(rewriter, op, maxValue, inputElemType);
auto minInfo = getMinValueOfDtype(op, inputElemType, rewriter);
if (failed(minInfo)) {
return rewriter.notifyMatchFailure(
@ -1233,7 +1240,8 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
}
minValue = *minInfo;
} else if (failed(checkNotNone(rewriter, op, maxValue))) {
minValue = mhlo::scalarToMhloTensor(rewriter, op, minValue, inputElemType);
minValue =
hlo::scalarToStablehloTensor(rewriter, op, minValue, inputElemType);
auto maxInfo = getMaxValueOfDtype(op, inputElemType, rewriter);
if (failed(maxInfo)) {
return rewriter.notifyMatchFailure(
@ -1241,10 +1249,13 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
}
maxValue = *maxInfo;
} else {
minValue = mhlo::scalarToMhloTensor(rewriter, op, minValue, inputElemType);
maxValue = mhlo::scalarToMhloTensor(rewriter, op, maxValue, inputElemType);
minValue =
hlo::scalarToStablehloTensor(rewriter, op, minValue, inputElemType);
maxValue =
hlo::scalarToStablehloTensor(rewriter, op, maxValue, inputElemType);
}
rewriter.replaceOpWithNewOp<mhlo::ClampOp>(op, minValue, input, maxValue);
rewriter.replaceOpWithNewOp<stablehlo::ClampOp>(op, minValue, input,
maxValue);
return success();
}
@ -1266,24 +1277,27 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
op, "unimplemented: only int or float dtype supported");
}
Value start = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getStart(), dtype);
Value end = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getEnd(), dtype);
Value step = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getStep(), dtype);
Value start =
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getStart(), dtype);
Value end =
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getEnd(), dtype);
Value step =
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getStep(), dtype);
// Get length of the 1-d output tensor
Value subOut = rewriter.create<mhlo::SubtractOp>(loc, end, start);
Value divOut = rewriter.create<mhlo::DivOp>(loc, subOut, step);
Value subOut = rewriter.create<stablehlo::SubtractOp>(loc, end, start);
Value divOut = rewriter.create<stablehlo::DivOp>(loc, subOut, step);
Value resultLength = rewriter.create<mhlo::ReshapeOp>(
Value resultLength = rewriter.create<stablehlo::ReshapeOp>(
loc, RankedTensorType::get({1}, dtype), divOut);
if (dtype.isa<mlir::FloatType>()) {
resultLength = rewriter.create<mhlo::CeilOp>(loc, resultLength);
resultLength = rewriter.create<mhlo::ConvertOp>(
resultLength = rewriter.create<stablehlo::CeilOp>(loc, resultLength);
resultLength = rewriter.create<stablehlo::ConvertOp>(
loc, RankedTensorType::get({1}, rewriter.getI64Type()), resultLength);
}
Value window =
rewriter.create<mhlo::DynamicIotaOp>(loc, outType, resultLength, 0);
rewriter.create<stablehlo::DynamicIotaOp>(loc, outType, resultLength, 0);
DenseIntElementsAttr broadcastDimensions;
Value mulOut = rewriter.create<chlo::BroadcastMulOp>(loc, window, step,
broadcastDimensions);
@ -1298,9 +1312,8 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
Value input = adaptor.getSelf();
auto outType = this->getTypeConverter()
->convertType(op.getType())
.cast<TensorType>();
auto outType =
this->getTypeConverter()->convertType(op.getType()).cast<TensorType>();
if (!outType) {
return op.emitError("only tensor type is supported");
}
@ -1320,26 +1333,27 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
Value negHalf = chlo::getConstantLike(rewriter, loc, -0.5, input);
// Compute
Value kBeta0 = rewriter.create<mhlo::MulOp>(loc, outType, kAlpha, cstAlpha0);
Value kBeta = rewriter.create<mhlo::MulOp>(loc, outType, kBeta0, half);
Value erfArg =
rewriter.create<mhlo::MulOp>(loc, outType, kAlpha, adaptor.getSelf());
Value kBeta0 =
rewriter.create<stablehlo::MulOp>(loc, outType, kAlpha, cstAlpha0);
Value kBeta = rewriter.create<stablehlo::MulOp>(loc, outType, kBeta0, half);
Value erfArg = rewriter.create<stablehlo::MulOp>(loc, outType, kAlpha,
adaptor.getSelf());
Value erf = rewriter.create<mlir::chlo::ErfOp>(loc, outType, erfArg);
Value erfAdd = rewriter.create<mhlo::AddOp>(loc, outType, erf, one);
Value cdf = rewriter.create<mhlo::MulOp>(loc, outType, erfAdd, half);
Value inputSquared = rewriter.create<mhlo::MulOp>(
Value erfAdd = rewriter.create<stablehlo::AddOp>(loc, outType, erf, one);
Value cdf = rewriter.create<stablehlo::MulOp>(loc, outType, erfAdd, half);
Value inputSquared = rewriter.create<stablehlo::MulOp>(
loc, outType, adaptor.getSelf(), adaptor.getSelf());
Value negHalfInputSquared =
rewriter.create<mhlo::MulOp>(loc, outType, inputSquared, negHalf);
rewriter.create<stablehlo::MulOp>(loc, outType, inputSquared, negHalf);
Value expRes =
rewriter.create<mhlo::ExpOp>(loc, outType, negHalfInputSquared);
Value pdf = rewriter.create<mhlo::MulOp>(loc, outType, kBeta, expRes);
rewriter.create<stablehlo::ExpOp>(loc, outType, negHalfInputSquared);
Value pdf = rewriter.create<stablehlo::MulOp>(loc, outType, kBeta, expRes);
Value pdfTimesInput =
rewriter.create<mhlo::MulOp>(loc, outType, pdf, adaptor.getSelf());
rewriter.create<stablehlo::MulOp>(loc, outType, pdf, adaptor.getSelf());
Value pdfTimesInputAddCdf =
rewriter.create<mhlo::AddOp>(loc, outType, pdfTimesInput, cdf);
rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, outType, adaptor.getGradOutput(),
pdfTimesInputAddCdf);
rewriter.create<stablehlo::AddOp>(loc, outType, pdfTimesInput, cdf);
rewriter.replaceOpWithNewOp<stablehlo::MulOp>(
op, outType, adaptor.getGradOutput(), pdfTimesInputAddCdf);
return success();
}
@ -1366,9 +1380,9 @@ public:
};
} // namespace
void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToMhloOptions &options) {
ConversionTarget &target, const TorchToStablehloOptions &options) {
MLIRContext *context = patterns.getContext();
target.addIllegalOp<AtenTransposeIntOp>();
@ -1376,23 +1390,24 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
target.addIllegalOp<RuntimeAssertOp>();
patterns.add<ConvertRuntimeAssertOp>(typeConverter, context);
#define INSERT_UNARY_PATTERN(AtenOp, MhloOp) \
#define INSERT_UNARY_PATTERN(AtenOp, StablehloOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenUnaryOp<AtenOp, MhloOp>>(typeConverter, context)
INSERT_UNARY_PATTERN(AtenCloneOp, mhlo::CopyOp);
INSERT_UNARY_PATTERN(AtenNegOp, mhlo::NegOp);
INSERT_UNARY_PATTERN(AtenLogicalNotOp, mhlo::NotOp);
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, mhlo::NotOp);
patterns.add<ConvertAtenUnaryOp<AtenOp, StablehloOp>>(typeConverter, context)
INSERT_UNARY_PATTERN(AtenCloneOp, stablehlo::ConvertOp);
INSERT_UNARY_PATTERN(AtenNegOp, stablehlo::NegOp);
INSERT_UNARY_PATTERN(AtenLogicalNotOp, stablehlo::NotOp);
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, stablehlo::NotOp);
#undef INSERT_UNARY_PATTERN
#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, MhloOp) \
#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, StablehloOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, MhloOp>>(typeConverter, context)
INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, mhlo::LogOp);
INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, mhlo::ExpOp);
INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, mhlo::SqrtOp);
INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, mhlo::RsqrtOp);
INSERT_UNARY_FPONLY_PATTERN(AtenSigmoidOp, mhlo::LogisticOp);
patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, StablehloOp>>(typeConverter, \
context)
INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, stablehlo::LogOp);
INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, stablehlo::ExpOp);
INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, stablehlo::SqrtOp);
INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, stablehlo::RsqrtOp);
INSERT_UNARY_FPONLY_PATTERN(AtenSigmoidOp, stablehlo::LogisticOp);
#undef INSERT_UNARY_FPONLY_PATTERN
#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
@ -1482,10 +1497,10 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
#undef INSERT_ATENOP_PATTERN
#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, MhloOp) \
#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenBinaryBroadcastOp<AtenOp, MhloOp>>(typeConverter, \
context)
patterns.add<ConvertAtenBinaryBroadcastOp<AtenOp, StablehloOp>>( \
typeConverter, context)
INSERT_BINARY_BROADCAST_PATTERN(AtenMaximumOp, chlo::BroadcastMaxOp);
INSERT_BINARY_BROADCAST_PATTERN(AtenMinimumOp, chlo::BroadcastMinOp);
INSERT_BINARY_BROADCAST_PATTERN(Aten__And__TensorOp, chlo::BroadcastAndOp);

View File

@ -0,0 +1,29 @@
add_mlir_conversion_library(TorchMLIRTorchToStablehlo
TorchToStablehlo.cpp
StablehloLegalizeUtils.cpp
Basic.cpp
Gather.cpp
Linear.cpp
ViewLike.cpp
Reduction.cpp
Pooling.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToStablehlo
DEPENDS
TorchMLIRConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRBufferTransforms
StablehloOps
TorchMLIRTorchDialect
TorchMLIRConversionUtils
)
torch_mlir_target_includes(TorchMLIRTorchToStablehlo)

View File

@ -7,14 +7,15 @@
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "../PassDetail.h"
#include "./MhloLegalizeUtils.h"
#include "./PopulatePatterns.h"
#include "mhlo/IR/hlo_ops.h"
#include "PopulatePatterns.h"
#include "StablehloLegalizeUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@ -24,7 +25,7 @@
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_mhlo;
using namespace mlir::torch::torch_to_stablehlo;
namespace {
Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
@ -69,7 +70,7 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
SmallVector<int64_t, 4> startIndexMap(1, axis);
// indexVecDim
int64_t indexVecDim = indicesRank;
auto dimsAttr = mhlo::GatherDimensionNumbersAttr::get(
auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get(
rewriter.getContext(),
/*offsetDims=*/offsetDims,
/*collapsedSliceDims=*/collapsedSliceDims,
@ -91,17 +92,18 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
auto outputTy =
RankedTensorType::get(outputShape, inputRankTy.getElementType());
return rewriter
.create<mhlo::DynamicGatherOp>(loc, outputTy, input, indices,
sliceSizesTensor, dimsAttr)
.create<stablehlo::DynamicGatherOp>(loc, outputTy, input, indices,
sliceSizesTensor, dimsAttr)
.getResult();
}
} // namespace
// Ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html
// Ref:
// https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html
// padding_idx (int, optional)
// If specified, the entries at padding_idx do not contribute to the gradient;
// therefore, the embedding vector at padding_idx is not updated during training,
// i.e. it remains as a fixed “pad”.
// If specified, the entries at padding_idx do not contribute to the
// gradient; therefore, the embedding vector at padding_idx is not updated
// during training, i.e. it remains as a fixed “pad”.
// scale_grad_by_freq (boolean, optional)
// If given, this will scale gradients by the inverse of frequency of the
// words in the mini-batch. Default False.
@ -139,7 +141,7 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
Value output = gatherTensorAlongSingleAxis(
rewriter, op, weight, adaptor.getIndices(), 0, options.dimSizeIndexBits);
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(
op, getTypeConverter()->convertType(op.getType()), output);
return success();
@ -161,7 +163,7 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
Value output = gatherTensorAlongSingleAxis(
rewriter, op, self, adaptor.getIndex(), dim, options.dimSizeIndexBits);
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(
op, getTypeConverter()->convertType(op.getType()), output);
return success();
@ -200,7 +202,7 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
auto options = getOptions();
auto indexShapeInfo =
mhlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits);
hlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits);
if (failed(indexShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dim sizes of `index` param");
@ -223,15 +225,15 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
SmallVector<Value> toConcat;
for (int64_t i = 0; i < inputType.getRank(); ++i) {
if (i == dim) {
toConcat.push_back(rewriter.create<mhlo::DynamicReshapeOp>(
toConcat.push_back(rewriter.create<stablehlo::DynamicReshapeOp>(
loc, toConcatIndexType, index, toConcatIndexShape));
} else {
toConcat.push_back(rewriter.create<mhlo::DynamicIotaOp>(
toConcat.push_back(rewriter.create<stablehlo::DynamicIotaOp>(
loc, toConcatIndexType, toConcatIndexShape,
rewriter.getI64IntegerAttr(i)));
}
}
auto gatherIndicies = rewriter.create<mhlo::ConcatenateOp>(
auto gatherIndicies = rewriter.create<stablehlo::ConcatenateOp>(
loc, toConcat, static_cast<uint64_t>(inputType.getRank()));
SmallVector<int64_t> sliceSizes(inputType.getRank(), 1);
@ -243,22 +245,22 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
startIndexMap.push_back(i);
}
auto dimsAttr = mhlo::GatherDimensionNumbersAttr::get(
auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get(
rewriter.getContext(),
/*offsetDims=*/{},
/*collapsedSliceDims=*/collapsedDims,
/*startIndexMap=*/startIndexMap,
/*indexVecDim=*/indexVecDim);
rewriter.replaceOpWithNewOp<mhlo::GatherOp>(
rewriter.replaceOpWithNewOp<stablehlo::GatherOp>(
op, input, gatherIndicies, dimsAttr,
rewriter.getI64TensorAttr(sliceSizes));
return success();
}
void mlir::torch::torch_to_mhlo::populateGatherOpPatternsAndLegality(
void mlir::torch::torch_to_stablehlo::populateGatherOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToMhloOptions &options) {
ConversionTarget &target, const TorchToStablehloOptions &options) {
MLIRContext *context = patterns.getContext();
#define INSERT_ATENOP_PATTERN(AtenOp) \

View File

@ -7,15 +7,16 @@
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "../PassDetail.h"
#include "./MhloLegalizeUtils.h"
#include "./PopulatePatterns.h"
#include "mhlo/IR/hlo_ops.h"
#include "PopulatePatterns.h"
#include "StablehloLegalizeUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@ -25,7 +26,7 @@
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_mhlo;
using namespace mlir::torch::torch_to_stablehlo;
namespace {
Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
@ -33,7 +34,7 @@ Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
ArrayRef<int64_t> broadcastDims) {
auto tensorTy = tensor.getType().dyn_cast<RankedTensorType>();
auto loc = op->getLoc();
Value mhloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
Value stablehloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
RankedTensorType outTy =
RankedTensorType::get(shape, tensorTy.getElementType());
@ -43,8 +44,8 @@ Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
rewriter.getIntegerType(64));
auto broadcastAttr = DenseIntElementsAttr::get(attrTy, broadcastDims);
auto broadcast = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
loc, outTy, tensor, mhloShape, broadcastAttr);
auto broadcast = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
loc, outTy, tensor, stablehloShape, broadcastAttr);
return broadcast;
}
@ -52,7 +53,7 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input,
ArrayRef<int64_t> inpTransDims) {
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
auto rank = inputTy.getRank();
auto transDims = mhlo::toPositiveDims(inpTransDims, rank);
auto transDims = hlo::toPositiveDims(inpTransDims, rank);
auto inpShape = inputTy.getShape();
std::vector<int64_t> newShape;
newShape.reserve(rank);
@ -66,8 +67,8 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input,
auto permuteAttr = DenseIntElementsAttr::get(attrTy, transDims);
auto outTy = RankedTensorType::get(newShape, inputTy.getElementType());
auto result = rewriter.create<mhlo::TransposeOp>(op->getLoc(), outTy, input,
permuteAttr);
auto result = rewriter.create<stablehlo::TransposeOp>(op->getLoc(), outTy,
input, permuteAttr);
return result.getResult();
}
@ -119,10 +120,12 @@ RankedTensorType castContractingDim(PatternRewriter &rewriter, Operation *op,
}
// set result dimensions
if (lhsResultDim < static_cast<int64_t>(lhsShape.size()) && lhsResultDim >= 0) {
if (lhsResultDim < static_cast<int64_t>(lhsShape.size()) &&
lhsResultDim >= 0) {
outShape.push_back(lhsShape[lhsResultDim]);
}
if (rhsResultDim < static_cast<int64_t>(rhsShape.size()) && rhsResultDim >= 0) {
if (rhsResultDim < static_cast<int64_t>(rhsShape.size()) &&
rhsResultDim >= 0) {
outShape.push_back(rhsShape[rhsResultDim]);
}
return RankedTensorType::get(outShape, lhsTy.getElementType());
@ -151,10 +154,10 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
std::vector<int64_t> newShape(rhsShape.begin(),
rhsShape.begin() + leadingRank);
newShape.insert(newShape.end(), lhsShape.begin(), lhsShape.end());
auto newDimSizes = *mhlo::getDimSizesOfTensor(
rewriter, op, rhs, leadingDims, dimSizeIndexBits);
auto newDimSizes = *hlo::getDimSizesOfTensor(rewriter, op, rhs, leadingDims,
dimSizeIndexBits);
auto lhsDimSizes =
*mhlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits);
*hlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits);
newDimSizes.insert(newDimSizes.end(), lhsDimSizes.begin(),
lhsDimSizes.end());
lhs = getBroadcastTensor(rewriter, op, lhs, newShape, newDimSizes,
@ -163,10 +166,10 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
std::vector<int64_t> newShape(lhsShape.begin(),
lhsShape.begin() + leadingRank);
newShape.insert(newShape.end(), rhsShape.begin(), rhsShape.end());
auto newDimSizes = *mhlo::getDimSizesOfTensor(
rewriter, op, lhs, leadingDims, dimSizeIndexBits);
auto newDimSizes = *hlo::getDimSizesOfTensor(rewriter, op, lhs, leadingDims,
dimSizeIndexBits);
auto rhsDimSizes =
*mhlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits);
*hlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits);
newDimSizes.insert(newDimSizes.end(), rhsDimSizes.begin(),
rhsDimSizes.end());
rhs = getBroadcastTensor(rewriter, op, rhs, newShape, newDimSizes,
@ -218,8 +221,8 @@ public:
if (lhsRank <= 2 && rhsRank <= 2) {
auto tensorType =
ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType());
output = rewriter.create<mhlo::DotOp>(op->getLoc(), tensorType, lhs, rhs,
nullptr);
output = rewriter.create<stablehlo::DotOp>(op->getLoc(), tensorType, lhs,
rhs, nullptr);
return success();
}
@ -253,8 +256,8 @@ public:
lhsContractingDim = nBatchDims;
}
mhlo::DotDimensionNumbersAttr dotDimensionNumbers =
mhlo::DotDimensionNumbersAttr::get(
stablehlo::DotDimensionNumbersAttr dotDimensionNumbers =
stablehlo::DotDimensionNumbersAttr::get(
rewriter.getContext(),
/*lhsBatchingDimensions=*/batchDims,
/*rhsBatchingDimensions=*/batchDims,
@ -264,8 +267,8 @@ public:
castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim,
lhsContractingDim, rhsContractingDim);
output = rewriter
.create<mhlo::DotGeneralOp>(op->getLoc(), outTy, lhs, rhs,
dotDimensionNumbers, nullptr)
.create<stablehlo::DotGeneralOp>(op->getLoc(), outTy, lhs, rhs,
dotDimensionNumbers, nullptr)
.getResult();
return success();
}
@ -312,7 +315,7 @@ public:
if (!lhsTy || !rhsTy)
return op.emitError(
"only ranked tensor types are supported in MHLO matmul");
"only ranked tensor types are supported in StableHLO matmul");
return success();
}
@ -335,7 +338,7 @@ public:
if (!lhsTy || !rhsTy)
return op.emitError(
"only ranked tensor types are supported in MHLO matmul");
"only ranked tensor types are supported in StableHLO matmul");
auto lhsRank = lhsTy.getRank();
auto rhsRank = rhsTy.getRank();
@ -371,7 +374,7 @@ public:
if (!lhsTy || !rhsTy)
return op.emitError(
"only ranked tensor types are supported in MHLO matmul");
"only ranked tensor types are supported in StableHLO matmul");
auto lhsRank = lhsTy.getRank();
auto rhsRank = rhsTy.getRank();
@ -398,10 +401,10 @@ public:
auto bias = adaptor.getBias();
auto biasTy = bias.getType();
// MHLO does not mandate that elementwise op tensors need to be ranked.
// StableHLO does not mandate that elementwise op tensors need to be ranked.
if (!biasTy.template isa<Torch::NoneType>() &&
!biasTy.template isa<RankedTensorType>())
return op.emitError("only ranked tensor types are supported in MHLO "
return op.emitError("only ranked tensor types are supported in StableHLO "
"matmul for bias tensor");
// weight.T
@ -427,14 +430,14 @@ public:
auto outTy =
castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim,
lhsContractingDim, rhsContractingDim);
mhlo::DotDimensionNumbersAttr dotDimensionNumbers =
mhlo::DotDimensionNumbersAttr::get(
stablehlo::DotDimensionNumbersAttr dotDimensionNumbers =
stablehlo::DotDimensionNumbersAttr::get(
rewriter.getContext(),
/*lhsBatchingDimensions=*/batchDims,
/*rhsBatchingDimensions=*/batchDims,
/*lhsContractingDimensions=*/{lhsContractingDim},
/*rhsContractingDimensions=*/{rhsContractingDim});
Value matmulOutput = rewriter.create<mhlo::DotGeneralOp>(
Value matmulOutput = rewriter.create<stablehlo::DotGeneralOp>(
op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr);
Value matmulPlusBias = matmulOutput;
@ -464,7 +467,7 @@ public:
auto weightElemTy = weightTy.getElementType();
auto rank = weightTy.getRank();
const auto &options = getOptions();
SmallVector<Value> weightShapeVec = *mhlo::getDimSizesOfTensor(
SmallVector<Value> weightShapeVec = *hlo::getDimSizesOfTensor(
rewriter, op, weight, options.dimSizeIndexBits);
auto weightShape = weightTy.getShape();
SmallVector<int64_t> weightShapeInt(rank);
@ -488,7 +491,7 @@ public:
}
Value weightShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), weightShapeVec);
weight = rewriter.create<mhlo::DynamicReshapeOp>(
weight = rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), RankedTensorType::get(weightShapeInt, weightElemTy),
weight, weightShapeTensor);
@ -497,7 +500,7 @@ public:
for (int64_t i = 0; i <= rank; i++)
transposeDims[i] = i;
std::swap(transposeDims[1], transposeDims[0]);
weight = rewriter.create<mhlo::TransposeOp>(
weight = rewriter.create<stablehlo::TransposeOp>(
op->getLoc(), weight, rewriter.getI64TensorAttr(transposeDims));
// 3. [IC//G, G, OC, H, W, ...] => [IC//G, G*OC, H, W, ...]
@ -509,7 +512,7 @@ public:
weightShapeVec[1] = OCMulGValue;
weightShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), weightShapeVec);
weight = rewriter.create<mhlo::DynamicReshapeOp>(
weight = rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), RankedTensorType::get(weightShapeInt, weightElemTy),
weight, weightShapeTensor);
return weight;
@ -544,25 +547,27 @@ public:
}
// Prepare for transposed convolution
SmallVector<int64_t> mhloStrideVec(nSpatialDims, 1);
DenseIntElementsAttr mhloStride = rewriter.getI64TensorAttr(mhloStrideVec);
SmallVector<int64_t> mhloPaddingVec(nSpatialDims * 2, 0);
SmallVector<int64_t> stablehloStrideVec(nSpatialDims, 1);
DenseIntElementsAttr stablehloStride =
rewriter.getI64TensorAttr(stablehloStrideVec);
SmallVector<int64_t> stablehloPaddingVec(nSpatialDims * 2, 0);
for (int i = 0; i < nSpatialDims; ++i) {
int64_t padInt = dilation[i] * (weightShape[i + 2] - 1) - padding[i];
mhloPaddingVec[i * 2] = padInt;
mhloPaddingVec[i * 2 + 1] = padInt;
stablehloPaddingVec[i * 2] = padInt;
stablehloPaddingVec[i * 2 + 1] = padInt;
}
DenseIntElementsAttr mhloPadding = DenseIntElementsAttr::get(
DenseIntElementsAttr stablehloPadding = DenseIntElementsAttr::get(
RankedTensorType::get({nSpatialDims, 2}, rewriter.getI64Type()),
mhloPaddingVec);
SmallVector<int64_t> mhloLhsDilationVec(nSpatialDims);
std::copy(stride.begin(), stride.end(), mhloLhsDilationVec.begin());
DenseIntElementsAttr mhloLhsDilation =
rewriter.getI64TensorAttr(mhloLhsDilationVec);
SmallVector<int64_t> mhloRhsDilationVec(nSpatialDims);
std::copy(dilation.begin(), dilation.end(), mhloRhsDilationVec.begin());
DenseIntElementsAttr mhloRhsDilation =
rewriter.getI64TensorAttr(mhloRhsDilationVec);
stablehloPaddingVec);
SmallVector<int64_t> stablehloLhsDilationVec(nSpatialDims);
std::copy(stride.begin(), stride.end(), stablehloLhsDilationVec.begin());
DenseIntElementsAttr stablehloLhsDilation =
rewriter.getI64TensorAttr(stablehloLhsDilationVec);
SmallVector<int64_t> stablehloRhsDilationVec(nSpatialDims);
std::copy(dilation.begin(), dilation.end(),
stablehloRhsDilationVec.begin());
DenseIntElementsAttr stablehloRhsDilation =
rewriter.getI64TensorAttr(stablehloRhsDilationVec);
DenseElementsAttr windowReversal;
ArrayAttr precisionConfig;
@ -571,8 +576,8 @@ public:
for (int i = 0; i < nSpatialDims; ++i) {
spatialDims.push_back(i + 2);
}
mhlo::ConvDimensionNumbersAttr dimensionNumbers =
mhlo::ConvDimensionNumbersAttr::get(
stablehlo::ConvDimensionNumbersAttr dimensionNumbers =
stablehlo::ConvDimensionNumbersAttr::get(
/*context=*/rewriter.getContext(), /*inputBatchDimension=*/0,
/*inputFeatureDimension=*/1,
/*inputSpatialDimensions=*/spatialDims,
@ -583,17 +588,18 @@ public:
/*outputSpatialDimensions=*/spatialDims);
// Reverse and transpose weight
weight = rewriter.create<mhlo::ReverseOp>(
weight = rewriter.create<stablehlo::ReverseOp>(
op->getLoc(), weight, rewriter.getI64TensorAttr(spatialDims));
if (groups != 1) {
weight = reshapeConvWeight(rewriter, op, weight, groups);
}
// Create transposed convolution
auto transposedConvOp = rewriter.create<mhlo::ConvolutionOp>(
op->getLoc(), convOutTy, input, weight, mhloStride, mhloPadding,
mhloLhsDilation, mhloRhsDilation, windowReversal, dimensionNumbers,
static_cast<uint64_t>(groups), 1, precisionConfig);
auto transposedConvOp = rewriter.create<stablehlo::ConvolutionOp>(
op->getLoc(), convOutTy, input, weight, stablehloStride,
stablehloPadding, stablehloLhsDilation, stablehloRhsDilation,
windowReversal, dimensionNumbers, static_cast<uint64_t>(groups), 1,
precisionConfig);
// Handle output padding
if (!needHandleOutputPadding) {
@ -605,8 +611,8 @@ public:
std::copy(outputPadding.begin(), outputPadding.end(),
edgePaddingHighVec.begin() + 2);
Value paddingValue =
mhlo::getConstTensor<float>(rewriter, op, {0.0}, {}).value();
paddingValue = mhlo::promoteType(rewriter, paddingValue, inputTy);
hlo::getConstTensor<float>(rewriter, op, {0.0}, {}).value();
paddingValue = hlo::promoteType(rewriter, paddingValue, inputTy);
mlir::DenseIntElementsAttr edgePaddingLow =
rewriter.getI64VectorAttr(edgePaddingLowVec);
mlir::DenseIntElementsAttr edgePaddingHigh =
@ -614,7 +620,7 @@ public:
mlir::DenseIntElementsAttr interiorPadding =
rewriter.getI64VectorAttr(interiorPaddingVec);
auto paddedOutput = rewriter.create<mhlo::PadOp>(
auto paddedOutput = rewriter.create<stablehlo::PadOp>(
op->getLoc(), outType, transposedConvOp, paddingValue, edgePaddingLow,
edgePaddingHigh, interiorPadding);
@ -628,22 +634,22 @@ public:
ArrayRef<int64_t> dilation, int64_t groups) const {
int64_t nDims = outType.getRank();
// Get mhlo::ConvolutionOp attributes
DenseIntElementsAttr mhloWindowStride = DenseIntElementsAttr::get(
// Get stablehlo::ConvolutionOp attributes
DenseIntElementsAttr stablehloWindowStride = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<long int>(stride.size())},
rewriter.getI64Type()),
stride);
std::vector<int64_t> mhloPaddingVec;
std::vector<int64_t> stablehloPaddingVec;
for (size_t i = 0; i < padding.size(); i++) {
mhloPaddingVec.emplace_back(padding[i]);
mhloPaddingVec.emplace_back(padding[i]);
stablehloPaddingVec.emplace_back(padding[i]);
stablehloPaddingVec.emplace_back(padding[i]);
}
DenseIntElementsAttr mhloPadding = DenseIntElementsAttr::get(
DenseIntElementsAttr stablehloPadding = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<long int>(padding.size()), static_cast<long int>(2)},
rewriter.getI64Type()),
mhloPaddingVec);
DenseIntElementsAttr mhloRhsDilation = DenseIntElementsAttr::get(
stablehloPaddingVec);
DenseIntElementsAttr stablehloRhsDilation = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<long int>(dilation.size())},
rewriter.getI64Type()),
dilation);
@ -651,8 +657,8 @@ public:
for (int64_t i = 2; i < nDims; i++) {
spatialDimensions.emplace_back(i);
}
mhlo::ConvDimensionNumbersAttr dimensionNumbers =
mhlo::ConvDimensionNumbersAttr::get(
stablehlo::ConvDimensionNumbersAttr dimensionNumbers =
stablehlo::ConvDimensionNumbersAttr::get(
/*context=*/rewriter.getContext(), /*inputBatchDimension=*/0,
/*inputFeatureDimension=*/1,
/*inputSpatialDimensions=*/spatialDimensions,
@ -662,17 +668,18 @@ public:
/*outputBatchDimension=*/0, /*outputFeatureDimension=*/1,
/*outputSpatialDimensions=*/spatialDimensions);
// mhlo::ConvolutionOp's optional attributes, leave them as default
DenseIntElementsAttr mhloLhsDilation;
// stablehlo::ConvolutionOp's optional attributes, leave them as default
DenseIntElementsAttr stablehloLhsDilation;
DenseElementsAttr windowReversal;
ArrayAttr precisionConfig;
auto mhloConvOp = rewriter.create<mhlo::ConvolutionOp>(
op->getLoc(), outType, input, weight, mhloWindowStride, mhloPadding,
mhloLhsDilation, mhloRhsDilation, windowReversal, dimensionNumbers,
static_cast<uint64_t>(groups), 1, precisionConfig);
auto stablehloConvOp = rewriter.create<stablehlo::ConvolutionOp>(
op->getLoc(), outType, input, weight, stablehloWindowStride,
stablehloPadding, stablehloLhsDilation, stablehloRhsDilation,
windowReversal, dimensionNumbers, static_cast<uint64_t>(groups), 1,
precisionConfig);
return mhloConvOp.getResult();
return stablehloConvOp.getResult();
}
LogicalResult
@ -754,21 +761,22 @@ public:
}
}
Value mhloConvResult;
Value stablehloConvResult;
if (transposed) {
mhloConvResult = convertTransposedConv(
stablehloConvResult = convertTransposedConv(
op, rewriter, outTy, input, weight, stride, padding, dilation,
outputPadding, groups, needHandleOutputPadding);
} else {
mhloConvResult = convertNormalConv(op, rewriter, outTy, input, weight,
stride, padding, dilation, groups);
stablehloConvResult =
convertNormalConv(op, rewriter, outTy, input, weight, stride, padding,
dilation, groups);
}
auto bias = adaptor.getBias();
// No bias provided
if (failed(checkNotNone(rewriter, op, op.getBias()))) {
rewriter.replaceOp(op, mhloConvResult);
rewriter.replaceOp(op, stablehloConvResult);
return success();
}
@ -790,21 +798,21 @@ public:
llvm::to_vector<4>(llvm::seq<int64_t>(-nSpatialDims, 0));
const auto &options = getOptions();
bias = *mhlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims,
options.dimSizeIndexBits);
bias = mhlo::promoteType(rewriter, bias, outTy);
bias = *hlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims,
options.dimSizeIndexBits);
bias = hlo::promoteType(rewriter, bias, outTy);
DenseIntElementsAttr bcastDimensions;
rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(op, outTy, mhloConvResult,
bias, bcastDimensions);
rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
op, outTy, stablehloConvResult, bias, bcastDimensions);
return success();
}
};
} // namespace
void mlir::torch::torch_to_mhlo::populateLinearOpPatternsAndLegality(
void mlir::torch::torch_to_stablehlo::populateLinearOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToMhloOptions &options) {
ConversionTarget &target, const TorchToStablehloOptions &options) {
MLIRContext *context = patterns.getContext();
#define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \

View File

@ -7,15 +7,16 @@
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "../PassDetail.h"
#include "./MhloLegalizeUtils.h"
#include "./PopulatePatterns.h"
#include "mhlo/IR/hlo_ops.h"
#include "PopulatePatterns.h"
#include "StablehloLegalizeUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@ -28,7 +29,7 @@
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_mhlo;
using namespace mlir::torch::torch_to_stablehlo;
static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
PatternRewriter &rewriter) {
@ -40,14 +41,14 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
constType, {APFloat::getZero(
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
/*negative=*/false)});
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
constAttr);
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
} else if (elementTy.isa<mlir::IntegerType>() &&
elementTy.getIntOrFloatBitWidth() != 8) {
auto constAttr = DenseElementsAttr::get(
constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
constAttr);
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
}
}
@ -58,15 +59,15 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
constType, {APFloat::getLargest(
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
/*negative=*/true)});
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
constAttr);
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
} else if (elementTy.isa<mlir::IntegerType>() &&
elementTy.getIntOrFloatBitWidth() != 8) {
auto constAttr = DenseElementsAttr::get(
constType,
{APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())});
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
constAttr);
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
}
}
op->emitError("unimplemented lowering in AtenPoolingOp");
@ -116,42 +117,43 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dOp>::matchAndRewrite(
// prepend 1 to kernelSize, stride, dilation until they are of same rank as
// input
SmallVector<int64_t> mhloStride(inputRank, 1);
SmallVector<int64_t> mhloDilation(inputRank, 1);
SmallVector<int64_t> mhloKernelSize(inputRank, 1);
SmallVector<int64_t> mhloPadding(inputRank * 2, 0);
SmallVector<int64_t> stablehloStride(inputRank, 1);
SmallVector<int64_t> stablehloDilation(inputRank, 1);
SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
std::copy(dilation.begin(), dilation.end(),
mhloDilation.begin() + inputRank - 2);
std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2);
stablehloDilation.begin() + inputRank - 2);
std::copy(stride.begin(), stride.end(),
stablehloStride.begin() + inputRank - 2);
std::copy(kernelSize.begin(), kernelSize.end(),
mhloKernelSize.begin() + inputRank - 2);
stablehloKernelSize.begin() + inputRank - 2);
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
mhloPadding[mhloPadding.size() - 4] = padding[0];
mhloPadding[mhloPadding.size() - 3] = padding[0];
mhloPadding[mhloPadding.size() - 2] = padding[1];
mhloPadding[mhloPadding.size() - 1] = padding[1];
stablehloPadding[stablehloPadding.size() - 4] = padding[0];
stablehloPadding[stablehloPadding.size() - 3] = padding[0];
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())},
RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
rewriter.getI64Type()),
mhloKernelSize);
stablehloKernelSize);
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(mhloStride.size())},
RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
rewriter.getI64Type()),
mhloStride);
stablehloStride);
DenseIntElementsAttr baseDilations;
DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())},
RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
rewriter.getI64Type()),
mhloDilation);
stablehloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
rewriter.getI64Type()),
mhloPadding);
auto reduceWindowOp = rewriter.create<mhlo::ReduceWindowOp>(
stablehloPadding);
auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
baseDilations, windowDilations, pad);
@ -168,8 +170,8 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dOp>::matchAndRewrite(
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value result =
rewriter.create<mhlo::MaxOp>(op->getLoc(), *firstArg, *secondArg);
rewriter.create<mhlo::ReturnOp>(op->getLoc(), result);
rewriter.create<stablehlo::MaxOp>(op->getLoc(), *firstArg, *secondArg);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), result);
}
rewriter.replaceOp(op, reduceWindowOp.getResults());
@ -221,45 +223,46 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
// prepend 1 to kernelSize, stride, dilation until they are of same rank as
// input
SmallVector<int64_t> mhloStride(inputRank, 1);
SmallVector<int64_t> mhloDilation(inputRank, 1);
SmallVector<int64_t> mhloKernelSize(inputRank, 1);
SmallVector<int64_t> mhloPadding(inputRank * 2, 0);
SmallVector<int64_t> stablehloStride(inputRank, 1);
SmallVector<int64_t> stablehloDilation(inputRank, 1);
SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
std::copy(dilation.begin(), dilation.end(),
mhloDilation.begin() + inputRank - 2);
std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2);
stablehloDilation.begin() + inputRank - 2);
std::copy(stride.begin(), stride.end(),
stablehloStride.begin() + inputRank - 2);
std::copy(kernelSize.begin(), kernelSize.end(),
mhloKernelSize.begin() + inputRank - 2);
stablehloKernelSize.begin() + inputRank - 2);
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
mhloPadding[mhloPadding.size() - 4] = padding[0];
mhloPadding[mhloPadding.size() - 3] = padding[0];
mhloPadding[mhloPadding.size() - 2] = padding[1];
mhloPadding[mhloPadding.size() - 1] = padding[1];
stablehloPadding[stablehloPadding.size() - 4] = padding[0];
stablehloPadding[stablehloPadding.size() - 3] = padding[0];
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())},
RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
rewriter.getI64Type()),
mhloKernelSize);
stablehloKernelSize);
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(mhloStride.size())},
RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
rewriter.getI64Type()),
mhloStride);
stablehloStride);
DenseIntElementsAttr baseDilations;
DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())},
RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
rewriter.getI64Type()),
mhloDilation);
stablehloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
rewriter.getI64Type()),
mhloPadding);
stablehloPadding);
const auto &options = getOptions();
auto inputShapeInfo =
mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
if (failed(inputShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
@ -289,7 +292,7 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
auto initIndexTensor =
rewriter
.create<mhlo::DynamicIotaOp>(
.create<stablehlo::DynamicIotaOp>(
op->getLoc(),
RankedTensorType::get(initIndexShapeForType,
rewriter.getI64Type()),
@ -298,15 +301,15 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
auto indexTensor =
rewriter
.create<mhlo::DynamicReshapeOp>(
.create<stablehlo::DynamicReshapeOp>(
op->getLoc(),
RankedTensorType::get(inputShape, rewriter.getI64Type()),
initIndexTensor, inputShapeTensor)
.getResult();
Value initIdx = mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
Value initIdx = hlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
auto reduceWindowOp = rewriter.create<mhlo::ReduceWindowOp>(
auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
op->getLoc(), mlir::TypeRange{outValTy, outIdxTy},
mlir::ValueRange{input, indexTensor}, mlir::ValueRange{initVal, initIdx},
windowDimensions, windowStrides, baseDilations, windowDilations, pad);
@ -326,43 +329,43 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
auto *secondValArg = std::next(firstIdxArg);
auto *secondIdxArg = std::next(secondValArg);
mhlo::ComparisonTypeAttr compareTypeAttr;
stablehlo::ComparisonTypeAttr compareTypeAttr;
if (inputTy.getElementType().isa<mlir::FloatType>()) {
compareTypeAttr = mhlo::ComparisonTypeAttr::get(
rewriter.getContext(), mhlo::ComparisonType::FLOAT);
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
rewriter.getContext(), stablehlo::ComparisonType::FLOAT);
} else if (inputTy.getElementType().isa<mlir::IntegerType>()) {
compareTypeAttr = mhlo::ComparisonTypeAttr::get(
rewriter.getContext(), mhlo::ComparisonType::SIGNED);
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
}
mhlo::ComparisonDirectionAttr compareGeDirectionAttr =
mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
mhlo::ComparisonDirection::GE);
mhlo::ComparisonDirectionAttr compareEqDirectionAttr =
mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
mhlo::ComparisonDirection::EQ);
stablehlo::ComparisonDirectionAttr compareGeDirectionAttr =
stablehlo::ComparisonDirectionAttr::get(
rewriter.getContext(), stablehlo::ComparisonDirection::GE);
stablehlo::ComparisonDirectionAttr compareEqDirectionAttr =
stablehlo::ComparisonDirectionAttr::get(
rewriter.getContext(), stablehlo::ComparisonDirection::EQ);
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value compareGeResult = rewriter.create<mhlo::CompareOp>(
Value compareGeResult = rewriter.create<stablehlo::CompareOp>(
op->getLoc(), compareResultType, *firstValArg, *secondValArg,
compareGeDirectionAttr, compareTypeAttr);
Value retValResult = rewriter.create<mhlo::SelectOp>(
Value retValResult = rewriter.create<stablehlo::SelectOp>(
op->getLoc(), compareGeResult, *firstValArg, *secondValArg);
// Get smaller index if compared values are equal.
Value compareEqResult = rewriter.create<mhlo::CompareOp>(
Value compareEqResult = rewriter.create<stablehlo::CompareOp>(
op->getLoc(), compareResultType, *firstValArg, *secondValArg,
compareEqDirectionAttr, compareTypeAttr);
Value minIdx =
rewriter.create<mhlo::MinOp>(op->getLoc(), *firstIdxArg, *secondIdxArg);
Value idxWithGeVal = rewriter.create<mhlo::SelectOp>(
Value minIdx = rewriter.create<stablehlo::MinOp>(op->getLoc(), *firstIdxArg,
*secondIdxArg);
Value idxWithGeVal = rewriter.create<stablehlo::SelectOp>(
op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg);
Value retIdxResult = rewriter.create<mhlo::SelectOp>(
Value retIdxResult = rewriter.create<stablehlo::SelectOp>(
op->getLoc(), compareEqResult, minIdx, idxWithGeVal);
rewriter.create<mhlo::ReturnOp>(
rewriter.create<stablehlo::ReturnOp>(
op->getLoc(), mlir::ValueRange{retValResult, retIdxResult});
}
@ -419,41 +422,42 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
// prepend 1 to kernelSize, stride, dilation until they are of same rank as
// input
SmallVector<int64_t> mhloStride(inputRank, 1);
SmallVector<int64_t> mhloDilation(inputRank, 1);
SmallVector<int64_t> mhloKernelSize(inputRank, 1);
SmallVector<int64_t> mhloPadding(inputRank * 2, 0);
SmallVector<int64_t> stablehloStride(inputRank, 1);
SmallVector<int64_t> stablehloDilation(inputRank, 1);
SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2);
std::copy(stride.begin(), stride.end(),
stablehloStride.begin() + inputRank - 2);
std::copy(kernelSize.begin(), kernelSize.end(),
mhloKernelSize.begin() + inputRank - 2);
mhloPadding[mhloPadding.size() - 4] = padding[0];
mhloPadding[mhloPadding.size() - 3] = padding[0];
mhloPadding[mhloPadding.size() - 2] = padding[1];
mhloPadding[mhloPadding.size() - 1] = padding[1];
stablehloKernelSize.begin() + inputRank - 2);
stablehloPadding[stablehloPadding.size() - 4] = padding[0];
stablehloPadding[stablehloPadding.size() - 3] = padding[0];
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())},
RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
rewriter.getI64Type()),
mhloKernelSize);
stablehloKernelSize);
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(mhloStride.size())},
RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
rewriter.getI64Type()),
mhloStride);
stablehloStride);
DenseIntElementsAttr baseDilations;
DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())},
RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
rewriter.getI64Type()),
mhloDilation);
stablehloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
rewriter.getI64Type()),
mhloPadding);
stablehloPadding);
auto reduceWindowSum = rewriter.create<mhlo::ReduceWindowOp>(
auto reduceWindowSum = rewriter.create<stablehlo::ReduceWindowOp>(
op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
baseDilations, windowDilations, pad);
@ -471,39 +475,39 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
rewriter.setInsertionPointToStart(&sumBlock);
Value sumResult =
rewriter.create<mhlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
rewriter.create<mhlo::ReturnOp>(op->getLoc(), sumResult);
rewriter.create<stablehlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult);
}
// Use kernel size as the divisor
if (countIncludePad) {
Value divisor = mhlo::getConstTensor<int64_t>(
Value divisor = hlo::getConstTensor<int64_t>(
rewriter, op, {kernelSize[0] * kernelSize[1]}, {})
.value();
divisor = mhlo::promoteType(rewriter, divisor, outTy);
divisor = hlo::promoteType(rewriter, divisor, outTy);
DenseIntElementsAttr bcastDimensions;
rewriter.replaceOpWithNewOp<mlir::chlo::BroadcastDivOp>(
op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions);
return success();
}
// Use another mhlo.ReduceWindowOp to get the divisor
// Use another stablehlo.ReduceWindowOp to get the divisor
Value windowSizeConst =
mhlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
windowSizeConst = mhlo::promoteType(rewriter, windowSizeConst, outTy);
hlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
windowSizeConst = hlo::promoteType(rewriter, windowSizeConst, outTy);
const auto &options = getOptions();
auto inputShapeVec =
*mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
*hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), inputShapeVec);
windowSizeConst = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
windowSizeConst = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
op->getLoc(),
RankedTensorType::get(inputTy.getShape(), outTy.getElementType()),
windowSizeConst, inputShapeTensor, rewriter.getI64TensorAttr({}));
Value zero = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
auto reduceWindowSize = rewriter.create<mhlo::ReduceWindowOp>(
auto reduceWindowSize = rewriter.create<stablehlo::ReduceWindowOp>(
op->getLoc(), RankedTensorType::get(outShape, inputElemTy),
windowSizeConst, zero, windowDimensions, windowStrides, baseDilations,
windowDilations, pad);
@ -522,11 +526,11 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
rewriter.setInsertionPointToStart(&sizeBlock);
Value sumResult =
rewriter.create<mhlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
rewriter.create<mhlo::ReturnOp>(op->getLoc(), sumResult);
rewriter.create<stablehlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult);
}
rewriter.replaceOpWithNewOp<mhlo::DivOp>(
rewriter.replaceOpWithNewOp<stablehlo::DivOp>(
op, outTy, reduceWindowSum.getResult(0), reduceWindowSize.getResult(0));
return success();
}
@ -560,33 +564,33 @@ LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite(
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
SmallVector<int64_t> mhloKernelSize(inputRank, 1);
mhloKernelSize[dim] = inputShape[dim];
SmallVector<int64_t> mhloStride(inputRank, 1);
SmallVector<int64_t> mhloDilation(inputRank, 1);
SmallVector<int64_t> mhloPadding(inputRank * 2, 0);
mhloPadding[dim * 2] = inputShape[dim] - 1;
SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
stablehloKernelSize[dim] = inputShape[dim];
SmallVector<int64_t> stablehloStride(inputRank, 1);
SmallVector<int64_t> stablehloDilation(inputRank, 1);
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
stablehloPadding[dim * 2] = inputShape[dim] - 1;
DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())},
RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
rewriter.getI64Type()),
mhloKernelSize);
stablehloKernelSize);
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(mhloStride.size())},
RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
rewriter.getI64Type()),
mhloStride);
stablehloStride);
DenseIntElementsAttr baseDilations;
DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())},
RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
rewriter.getI64Type()),
mhloDilation);
stablehloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
rewriter.getI64Type()),
mhloPadding);
stablehloPadding);
auto reduceWindowSum = rewriter.create<mhlo::ReduceWindowOp>(
auto reduceWindowSum = rewriter.create<stablehlo::ReduceWindowOp>(
op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
baseDilations, windowDilations, pad);
@ -604,17 +608,17 @@ LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite(
rewriter.setInsertionPointToStart(&sumBlock);
Value sumResult =
rewriter.create<mhlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
rewriter.create<mhlo::ReturnOp>(op->getLoc(), sumResult);
rewriter.create<stablehlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult);
}
rewriter.replaceOp(op, reduceWindowSum.getResults());
return success();
}
void mlir::torch::torch_to_mhlo::populatePoolingOpPatternsAndLegality(
void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToMhloOptions &options) {
ConversionTarget &target, const TorchToStablehloOptions &options) {
MLIRContext *context = patterns.getContext();
target.addIllegalOp<AtenMaxPool2dOp>();
patterns.add<ConvertAtenOp<AtenMaxPool2dOp>>(typeConverter, context, options);

View File

@ -0,0 +1,69 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_LIB_CONVERSION_TORCHTOSTABLEHLO_POPULATEPATTERNS_H
#define TORCHMLIR_LIB_CONVERSION_TORCHTOSTABLEHLO_POPULATEPATTERNS_H
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace torch {
namespace torch_to_stablehlo {
struct TorchToStablehloOptions {
bool enableStaticShape = false;
size_t dimSizeIndexBits = 64;
};
template <typename AtenOpT>
class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
public:
using OpAdaptor = typename AtenOpT::Adaptor;
ConvertAtenOp(TypeConverter &typeConverter, MLIRContext *context,
const TorchToStablehloOptions &options)
: OpConversionPattern<AtenOpT>(typeConverter, context) {
this->options = options;
}
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
return rewriter.notifyMatchFailure(op, "haven't been implemented");
}
const TorchToStablehloOptions &getOptions() const { return options; }
private:
TorchToStablehloOptions options;
};
void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target,
const TorchToStablehloOptions &options);
void populateViewLikeOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options);
void populateGatherOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options);
void populateReductionOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options);
void populateLinearOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options);
void populatePoolingOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options);
} // namespace torch_to_stablehlo
} // namespace torch
} // namespace mlir
#endif // TORCHMLIR_LIB_CONVERSION_TORCHTOSTABLEHLO_POPULATEPATTERNS_H

View File

@ -7,14 +7,15 @@
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "../PassDetail.h"
#include "./MhloLegalizeUtils.h"
#include "./PopulatePatterns.h"
#include "mhlo/IR/hlo_ops.h"
#include "PopulatePatterns.h"
#include "StablehloLegalizeUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@ -25,7 +26,7 @@
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_mhlo;
using namespace mlir::torch::torch_to_stablehlo;
static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
PatternRewriter &rewriter) {
@ -36,14 +37,14 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
constType, {APFloat::getZero(
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
/*negative=*/false)});
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
constAttr);
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
} else if (elementTy.isa<mlir::IntegerType>() &&
elementTy.getIntOrFloatBitWidth() != 8) {
auto constAttr = DenseElementsAttr::get(
constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
constAttr);
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
}
}
@ -53,15 +54,15 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
constType, {APFloat::getLargest(
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
/*negative=*/true)});
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
constAttr);
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
} else if (elementTy.isa<mlir::IntegerType>() &&
elementTy.getIntOrFloatBitWidth() != 8) {
auto constAttr = DenseElementsAttr::get(
constType,
{APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())});
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
constAttr);
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
}
}
@ -90,9 +91,9 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
return std::nullopt;
Value initIndex;
if (dimSizeIndexBits == 32) {
initIndex = mhlo::getConstTensor<int32_t>(rewriter, op, {0}, {}).value();
initIndex = hlo::getConstTensor<int32_t>(rewriter, op, {0}, {}).value();
} else {
initIndex = mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
initIndex = hlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
}
DenseIntElementsAttr dimensions = DenseIntElementsAttr::get(
@ -100,13 +101,13 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), inputShapeVec);
auto indexTensor = rewriter.create<mhlo::DynamicIotaOp>(
auto indexTensor = rewriter.create<stablehlo::DynamicIotaOp>(
op->getLoc(),
RankedTensorType::get(inputShape,
rewriter.getIntegerType(dimSizeIndexBits)),
inputShapeTensor, static_cast<uint64_t>(dim));
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op->getLoc(), ValueRange{input, indexTensor},
ValueRange{
initValue,
@ -114,7 +115,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
},
dimensions);
Block &block = mhloReduceOp.getBody().emplaceBlock();
Block &block = stablehloReduceOp.getBody().emplaceBlock();
// Add block arguments
auto blockValArgumentType =
@ -133,46 +134,46 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
auto *secondValArg = std::next(firstIdxArg);
auto *secondIdxArg = std::next(secondValArg);
mhlo::ComparisonTypeAttr compareTypeAttr;
stablehlo::ComparisonTypeAttr compareTypeAttr;
if (inputTy.getElementType().isa<mlir::FloatType>()) {
compareTypeAttr = mhlo::ComparisonTypeAttr::get(
rewriter.getContext(), mhlo::ComparisonType::FLOAT);
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
rewriter.getContext(), stablehlo::ComparisonType::FLOAT);
} else if (inputTy.getElementType().isa<mlir::IntegerType>()) {
compareTypeAttr = mhlo::ComparisonTypeAttr::get(
rewriter.getContext(), mhlo::ComparisonType::SIGNED);
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
}
mhlo::ComparisonDirectionAttr compareGeDirectionAttr =
mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
mhlo::ComparisonDirection::GE);
mhlo::ComparisonDirectionAttr compareEqDirectionAttr =
mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
mhlo::ComparisonDirection::EQ);
stablehlo::ComparisonDirectionAttr compareGeDirectionAttr =
stablehlo::ComparisonDirectionAttr::get(
rewriter.getContext(), stablehlo::ComparisonDirection::GE);
stablehlo::ComparisonDirectionAttr compareEqDirectionAttr =
stablehlo::ComparisonDirectionAttr::get(
rewriter.getContext(), stablehlo::ComparisonDirection::EQ);
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value compareGeResult = rewriter.create<mhlo::CompareOp>(
Value compareGeResult = rewriter.create<stablehlo::CompareOp>(
op->getLoc(), compareResultType, *firstValArg, *secondValArg,
compareGeDirectionAttr, compareTypeAttr);
Value retValResult = rewriter.create<mhlo::SelectOp>(
Value retValResult = rewriter.create<stablehlo::SelectOp>(
op->getLoc(), compareGeResult, *firstValArg, *secondValArg);
// get smaller index value if compared nums are equal.
Value compareEqResult = rewriter.create<mhlo::CompareOp>(
Value compareEqResult = rewriter.create<stablehlo::CompareOp>(
op->getLoc(), compareResultType, *firstValArg, *secondValArg,
compareEqDirectionAttr, compareTypeAttr);
Value minIdx =
rewriter.create<mhlo::MinOp>(op->getLoc(), *firstIdxArg, *secondIdxArg);
Value idxWithGeVal = rewriter.create<mhlo::SelectOp>(
Value minIdx = rewriter.create<stablehlo::MinOp>(op->getLoc(), *firstIdxArg,
*secondIdxArg);
Value idxWithGeVal = rewriter.create<stablehlo::SelectOp>(
op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg);
Value retIdxResult = rewriter.create<mhlo::SelectOp>(
Value retIdxResult = rewriter.create<stablehlo::SelectOp>(
op->getLoc(), compareEqResult, minIdx, idxWithGeVal);
rewriter.create<mhlo::ReturnOp>(
rewriter.create<stablehlo::ReturnOp>(
op->getLoc(), mlir::ValueRange{retValResult, retIdxResult});
}
return mhloReduceOp.getResults();
return stablehloReduceOp.getResults();
}
namespace {
@ -196,7 +197,8 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
Value input = adaptor.getSelf();
auto inputTy = input.getType().template cast<RankedTensorType>();
if (!inputTy) {
return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
}
auto inputElemTy = inputTy.getElementType();
@ -209,7 +211,7 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenArgmaxOp to MHLO");
"AtenArgmaxOp to StableHLO");
}
int64_t dim;
@ -228,15 +230,15 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
const auto &options = getOptions();
auto inputShapeInfo =
mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
if (failed(inputShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
}
auto inputShapeVec = *inputShapeInfo;
auto mhloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, dim,
options.dimSizeIndexBits)
.value();
auto stablehloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec,
dim, options.dimSizeIndexBits)
.value();
if (keepDim) {
auto outShapeVec = inputShapeVec;
@ -247,13 +249,13 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec);
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
op, typeConverter->convertType(op.getType()), mhloReduceResults[1],
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
op, typeConverter->convertType(op.getType()), stablehloReduceResults[1],
outShapeTensor);
return success();
}
rewriter.replaceOp(op, mhloReduceResults[1]);
rewriter.replaceOp(op, stablehloReduceResults[1]);
return success();
}
} // namespace
@ -267,7 +269,8 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
Value input = adaptor.getSelf();
auto inputTy = input.getType().template dyn_cast<RankedTensorType>();
if (!inputTy) {
return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
}
auto inputElemTy = inputTy.getElementType();
if (!inputElemTy.isIntOrFloat()) {
@ -279,7 +282,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenMaxDimOp to MHLO");
"AtenMaxDimOp to StableHLO");
}
RankedTensorType valResultType = getTypeConverter()
@ -308,15 +311,15 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
const auto &options = getOptions();
auto inputShapeInfo =
mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
if (failed(inputShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
}
auto inputShapeVec = *inputShapeInfo;
auto mhloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, dim,
options.dimSizeIndexBits)
.value();
auto stablehloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec,
dim, options.dimSizeIndexBits)
.value();
if (keepDim) {
auto outShapeVec = inputShapeVec;
@ -327,15 +330,21 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec);
auto mhloReduceValueResult = rewriter.create<mhlo::DynamicReshapeOp>(
op->getLoc(), valResultType, mhloReduceResults[0], outShapeTensor);
auto mhloReduceIndexResult = rewriter.create<mhlo::DynamicReshapeOp>(
op->getLoc(), idxResultType, mhloReduceResults[1], outShapeTensor);
rewriter.replaceOp(op, {mhloReduceValueResult, mhloReduceIndexResult});
auto stablehloReduceValueResult =
rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), valResultType, stablehloReduceResults[0],
outShapeTensor);
auto stablehloReduceIndexResult =
rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), idxResultType, stablehloReduceResults[1],
outShapeTensor);
rewriter.replaceOp(
op, {stablehloReduceValueResult, stablehloReduceIndexResult});
return success();
}
rewriter.replaceOp(op, {mhloReduceResults[0], mhloReduceResults[1]});
rewriter.replaceOp(op,
{stablehloReduceResults[0], stablehloReduceResults[1]});
return success();
}
} // namespace
@ -352,12 +361,14 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
->convertType(op.getType())
.template dyn_cast<RankedTensorType>();
if (!inputTy) {
return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
}
if (inputTy.getElementType() != outTy.getElementType()) {
// Use output element type as computation type.
auto dstElemTy = outTy.getElementType();
input = rewriter.create<mhlo::ConvertOp>(op->getLoc(), input, dstElemTy);
input =
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
inputTy = input.getType().dyn_cast<RankedTensorType>();
}
auto inputElemTy = inputTy.getElementType();
@ -370,7 +381,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenSumOp to MHLO");
"AtenSumOp to StableHLO");
}
SmallVector<int64_t> dims;
@ -379,13 +390,14 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
}
Value initValue =
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue) return failure();
if (!initValue)
return failure();
llvm::sort(dims.begin(), dims.end());
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
Block &block = mhloReduceOp.getBody().emplaceBlock();
Block &block = stablehloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
block.addArgument(blockArgumentTy, op->getLoc());
@ -397,13 +409,13 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value addResult = rewriter.create<mhlo::AddOp>(
Value addResult = rewriter.create<stablehlo::AddOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
rewriter.create<mhlo::ReturnOp>(op->getLoc(), addResult);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult);
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy,
mhloReduceOp.getResults());
stablehloReduceOp.getResults());
return success();
}
} // namespace
@ -417,7 +429,8 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
Value input = adaptor.getSelf();
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
if (!inputTy) {
return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
}
auto inputElemTy = inputTy.getElementType();
if (!inputElemTy.isIntOrFloat()) {
@ -429,7 +442,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenMaxOp to MHLO");
"AtenMaxOp to StableHLO");
}
SmallVector<int64_t> dims;
@ -439,12 +452,13 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
Value initValue =
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue) return failure();
if (!initValue)
return failure();
llvm::sort(dims.begin(), dims.end());
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
Block &block = mhloReduceOp.getBody().emplaceBlock();
Block &block = stablehloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
block.addArgument(blockArgumentTy, op->getLoc());
@ -456,14 +470,14 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value maxResult = rewriter.create<mhlo::MaxOp>(
Value maxResult = rewriter.create<stablehlo::MaxOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
rewriter.create<mhlo::ReturnOp>(op->getLoc(), maxResult);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), maxResult);
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, getTypeConverter()->convertType(op.getType()),
mhloReduceOp.getResults());
stablehloReduceOp.getResults());
return success();
}
} // namespace
@ -480,12 +494,14 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
->convertType(op.getType())
.template dyn_cast<RankedTensorType>();
if (!inputTy) {
return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
}
if (inputTy.getElementType() != outTy.getElementType()) {
// Use output element type as computation type.
auto dstElemTy = outTy.getElementType();
input = rewriter.create<mhlo::ConvertOp>(op->getLoc(), input, dstElemTy);
input =
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
inputTy = input.getType().dyn_cast<RankedTensorType>();
}
auto inputElemTy = inputTy.getElementType();
@ -499,7 +515,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenSumDimIntListOp to MHLO");
"AtenSumDimIntListOp to StableHLO");
}
SmallVector<int64_t> inputDims;
@ -525,13 +541,14 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
}
Value initValue =
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue) return failure();
if (!initValue)
return failure();
llvm::sort(dims.begin(), dims.end());
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
Region &region = mhloReduceOp.getBody();
Region &region = stablehloReduceOp.getBody();
Block &block = region.emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
@ -544,15 +561,15 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value addResult = rewriter.create<mhlo::AddOp>(
Value addResult = rewriter.create<stablehlo::AddOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
rewriter.create<mhlo::ReturnOp>(op->getLoc(), addResult);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult);
}
if (keepDim) {
const auto &options = getOptions();
auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input,
options.dimSizeIndexBits);
auto outShapeInfo =
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
if (failed(outShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
@ -567,26 +584,27 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
}
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec);
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
op, getTypeConverter()->convertType(op.getType()),
mhloReduceOp.getResult(0), outShapeTensor);
stablehloReduceOp.getResult(0), outShapeTensor);
return success();
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy,
mhloReduceOp.getResults());
stablehloReduceOp.getResults());
return success();
}
} // namespace
// AtenFrobeniusNormDimOp
// aten.frobenius_norm.dim => mhlo.reduce(calculate square sum along given dims)
// + mhlo.sqrt
// aten.frobenius_norm.dim => stablehlo.reduce(calculate square sum along given
// dims)
// + stablehlo.sqrt
namespace {
template <>
LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
AtenFrobeniusNormDimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
const TorchToMhloOptions &options = getOptions();
const TorchToStablehloOptions &options = getOptions();
Value input = adaptor.getSelf();
auto inputType = input.getType().dyn_cast<RankedTensorType>();
@ -614,7 +632,7 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
}
}
// Sort the dims in ascending order, making the conversion
// Sort the dims in ascending order, making the conversion
// stable with unordered dims.
std::sort(dims.begin(), dims.end());
@ -624,14 +642,14 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
op, "non-const bool `keepdim` is not supported");
}
auto squareOp = rewriter.create<mhlo::MulOp>(op->getLoc(), input, input);
auto squareOp = rewriter.create<stablehlo::MulOp>(op->getLoc(), input, input);
auto initValue = createInitialValueForReduceOp(op, inputElemType, rewriter);
if (!initValue) {
return failure();
}
auto reduceOp = rewriter.create<mhlo::ReduceOp>(
auto reduceOp = rewriter.create<stablehlo::ReduceOp>(
op->getLoc(), squareOp.getResult(), initValue,
rewriter.getI64TensorAttr(dims));
@ -649,30 +667,32 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
auto addResult = rewriter.create<mhlo::AddOp>(op->getLoc(), firstArgument,
secondArgument);
rewriter.create<mhlo::ReturnOp>(op->getLoc(), addResult.getResult());
auto addResult = rewriter.create<stablehlo::AddOp>(
op->getLoc(), firstArgument, secondArgument);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult.getResult());
}
auto output =
rewriter.create<mhlo::SqrtOp>(op->getLoc(), reduceOp.getResult(0));
rewriter.create<stablehlo::SqrtOp>(op->getLoc(), reduceOp.getResult(0));
if (keepDim) {
auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
auto outShapeInfo =
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
if (failed(outShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
}
auto outShapeVec = *outShapeInfo;
auto one = rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(), rewriter.getIntegerAttr(
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
op->getLoc(),
rewriter.getIntegerAttr(
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
for (int64_t i : dims) {
outShapeVec[i] = one;
}
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec);
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), output,
outShapeTensor);
return success();
@ -682,9 +702,9 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
}
} // namespace
void mlir::torch::torch_to_mhlo::populateReductionOpPatternsAndLegality(
void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToMhloOptions &options) {
ConversionTarget &target, const TorchToStablehloOptions &options) {
MLIRContext *context = patterns.getContext();
#define INSERT_ATEN_REDUCTION_OP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \

View File

@ -7,11 +7,12 @@
//
//===----------------------------------------------------------------------===//
#include "./MhloLegalizeUtils.h"
#include "mhlo/IR/hlo_ops.h"
#include "StablehloLegalizeUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include <numeric>
@ -21,27 +22,27 @@ using namespace mlir::torch;
using namespace mlir::torch::Torch;
namespace mlir {
namespace mhlo {
namespace hlo {
// Create a 32-bit float constant operator from a float
Value getMhloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
float val) {
Value getStablehloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
float val) {
auto const_type = RankedTensorType::get({}, rewriter.getF32Type());
auto const_attr = DenseElementsAttr::get(const_type, val);
auto const_op =
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
auto const_op = rewriter.create<stablehlo::ConstantOp>(
op->getLoc(), const_type, const_attr);
return const_op.getResult();
}
// Create a 64-bit float constant operator from a double
Value getMhloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op,
double val) {
Value getStablehloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op,
double val) {
auto const_type = RankedTensorType::get({}, rewriter.getF64Type());
auto const_attr = DenseElementsAttr::get(const_type, val);
auto const_op =
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
auto const_op = rewriter.create<stablehlo::ConstantOp>(
op->getLoc(), const_type, const_attr);
return const_op.getResult();
}
@ -65,8 +66,8 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
RankedTensorType::get(shape, rewriter.getIntegerType(sizeof(T) * 8));
auto const_attr = DenseElementsAttr::get(const_type, vec);
auto const_op =
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
auto const_op = rewriter.create<stablehlo::ConstantOp>(
op->getLoc(), const_type, const_attr);
return const_op.getResult();
}
@ -88,8 +89,8 @@ std::optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,
shape, rewriter.getIntegerType(vec[0].getBitWidth()));
auto const_attr = DenseElementsAttr::get(const_type, vec);
auto const_op =
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
auto const_op = rewriter.create<stablehlo::ConstantOp>(
op->getLoc(), const_type, const_attr);
return const_op.getResult();
}
@ -111,8 +112,8 @@ std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
auto const_type = RankedTensorType::get(shape, rewriter.getF32Type());
auto const_attr = DenseElementsAttr::get(const_type, vec);
auto const_op =
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
auto const_op = rewriter.create<stablehlo::ConstantOp>(
op->getLoc(), const_type, const_attr);
return const_op.getResult();
}
@ -133,8 +134,8 @@ std::optional<Value> getConstTensor<double>(PatternRewriter &rewriter,
auto const_type = RankedTensorType::get(shape, rewriter.getF64Type());
auto const_attr = DenseElementsAttr::get(const_type, vec);
auto const_op =
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
auto const_op = rewriter.create<stablehlo::ConstantOp>(
op->getLoc(), const_type, const_attr);
return const_op.getResult();
}
@ -169,18 +170,18 @@ Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
T val, Type dtype, llvm::ArrayRef<int64_t> dshape) {
auto const_type = RankedTensorType::get(dshape, dtype);
auto const_attr = SplatElementsAttr::get(const_type, val);
auto const_op =
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
auto const_op = rewriter.create<stablehlo::ConstantOp>(
op->getLoc(), const_type, const_attr);
return const_op.getResult();
}
Value scalarToMhloTensor(ConversionPatternRewriter &rewriter, Operation *op,
Value scalarValue, Type dtype) {
Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
Operation *op, Value scalarValue, Type dtype) {
auto tensor = rewriter.create<tensor::FromElementsOp>(
op->getLoc(), ArrayRef<Value>{scalarValue});
auto dtype_tensor =
rewriter.create<mhlo::ConvertOp>(op->getLoc(), tensor, dtype);
return rewriter.create<mhlo::ReshapeOp>(
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), tensor, dtype);
return rewriter.create<stablehlo::ReshapeOp>(
op->getLoc(), RankedTensorType::get(mlir::ArrayRef<int64_t>{}, dtype),
dtype_tensor);
}
@ -192,7 +193,8 @@ Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) {
if (in_type.getElementType() != outType.getElementType()) {
TensorType promotedType =
in_type.cloneWith(in_type.getShape(), outType.getElementType());
return rewriter.create<mhlo::ConvertOp>(op->getLoc(), promotedType, input);
return rewriter.create<stablehlo::ConvertOp>(op->getLoc(), promotedType,
input);
}
return input;
}
@ -210,8 +212,8 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
if (in_type.getElementType() != outType.getElementType()) {
TensorType promoted_type =
in_type.cloneWith(in_type.getShape(), outType.getElementType());
input =
rewriter.create<mhlo::ConvertOp>(op->getLoc(), promoted_type, input);
input = rewriter.create<stablehlo::ConvertOp>(op->getLoc(), promoted_type,
input);
}
ArrayRef<int64_t> inShape = in_type.getShape();
@ -245,8 +247,8 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
RankedTensorType::get({static_cast<long int>(bcastDims.size())},
rewriter.getI64Type()),
bcastDims);
auto bcast_op = rewriter.create<mhlo::BroadcastInDimOp>(op->getLoc(), outType,
input, bcast_attr);
auto bcast_op = rewriter.create<stablehlo::BroadcastInDimOp>(
op->getLoc(), outType, input, bcast_attr);
return bcast_op.getResult();
}
@ -348,8 +350,8 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
}
auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
auto mhloShape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
return rewriter.create<mhlo::DynamicReshapeOp>(loc, outTy, tensor, mhloShape)
auto shape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
return rewriter.create<stablehlo::DynamicReshapeOp>(loc, outTy, tensor, shape)
.getResult();
}
@ -357,11 +359,11 @@ Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
const APFloat &constant, Value shape,
TensorType outType) {
auto constAttr = rewriter.getFloatAttr(outType.getElementType(), constant);
auto constTensor = rewriter.create<mhlo::ConstantOp>(loc, constAttr);
auto constTensor = rewriter.create<stablehlo::ConstantOp>(loc, constAttr);
return rewriter
.create<mhlo::DynamicBroadcastInDimOp>(loc, outType, constTensor, shape,
rewriter.getI64TensorAttr({}))
.create<stablehlo::DynamicBroadcastInDimOp>(
loc, outType, constTensor, shape, rewriter.getI64TensorAttr({}))
.getResult();
}
} // namespace mhlo
} // namespace hlo
} // namespace mlir

View File

@ -7,8 +7,8 @@
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H
#define TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H
#ifndef TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_STABLEHLOLEGALIZEUTILS_H
#define TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_STABLEHLOLEGALIZEUTILS_H
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
@ -18,22 +18,22 @@
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace mhlo {
namespace hlo {
using mlir::ConversionPatternRewriter;
// Create a 32-bit float constant operator from a float
Value getMhloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
float val);
Value getStablehloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
float val);
// Create a 64-bit float constant operator from a double
Value getMhloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op,
double val);
Value getStablehloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op,
double val);
// Templated function to create a constant op for given type and shape.
// T: storage C type.
// Default template creates a constant tensor in T.
// To create INT48 MHLO constant, need to pass in llvm::APInt instead.
// To create INT48 StableHLO constant, need to pass in llvm::APInt instead.
template <typename T>
std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
ArrayRef<T> vec, ArrayRef<int64_t> shape);
@ -42,8 +42,8 @@ template <typename T>
Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
T val, Type dtype, llvm::ArrayRef<int64_t> dshape);
Value scalarToMhloTensor(ConversionPatternRewriter &rewriter, Operation *op,
Value scalarValue, Type dtype);
Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
Operation *op, Value scalarValue, Type dtype);
Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType);
@ -71,7 +71,7 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
const APFloat &constant, Value shape,
TensorType outType);
} // namespace mhlo
} // namespace hlo
} // namespace mlir
#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H
#endif // TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_STABLEHLOLEGALIZEUTILS_H

View File

@ -7,17 +7,18 @@
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "../PassDetail.h"
#include "./PopulatePatterns.h"
#include "mhlo/IR/hlo_ops.h"
#include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
@ -30,17 +31,18 @@ using namespace mlir::torch::Torch;
namespace {
class ConvertTorchToMhlo : public ConvertTorchToMhloBase<ConvertTorchToMhlo> {
class ConvertTorchToStablehlo
: public ConvertTorchToStablehloBase<ConvertTorchToStablehlo> {
public:
ConvertTorchToMhlo() = default;
ConvertTorchToMhlo(bool enableStaticShape, bool enableI32Index) {
ConvertTorchToStablehlo() = default;
ConvertTorchToStablehlo(bool enableStaticShape, bool enableI32Index) {
this->enableStaticShape = enableStaticShape;
this->enableI32Index = enableI32Index;
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<chlo::ChloDialect>();
registry.insert<mhlo::MhloDialect>();
registry.insert<stablehlo::StablehloDialect>();
registry.insert<tensor::TensorDialect>();
registry.insert<arith::ArithDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry);
@ -48,7 +50,7 @@ public:
void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<chlo::ChloDialect, mhlo::MhloDialect,
target.addLegalDialect<chlo::ChloDialect, stablehlo::StablehloDialect,
tensor::TensorDialect, arith::ArithDialect>();
TypeConverter typeConverter;
@ -57,20 +59,20 @@ public:
RewritePatternSet patterns(context);
torch_to_mhlo::TorchToMhloOptions options{enableStaticShape,
enableI32Index ? 32u : 64u};
torch_to_mhlo::populateBasicOpPatternsAndLegality(typeConverter, patterns,
target, options);
torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
torch_to_stablehlo::TorchToStablehloOptions options{
enableStaticShape, enableI32Index ? 32u : 64u};
torch_to_stablehlo::populateBasicOpPatternsAndLegality(
typeConverter, patterns, target, options);
torch_to_mhlo::populateGatherOpPatternsAndLegality(typeConverter, patterns,
target, options);
torch_to_mhlo::populateReductionOpPatternsAndLegality(
torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
typeConverter, patterns, target, options);
torch_to_stablehlo::populateGatherOpPatternsAndLegality(
typeConverter, patterns, target, options);
torch_to_stablehlo::populateReductionOpPatternsAndLegality(
typeConverter, patterns, target, options);
torch_to_stablehlo::populateLinearOpPatternsAndLegality(
typeConverter, patterns, target, options);
torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
typeConverter, patterns, target, options);
torch_to_mhlo::populateLinearOpPatternsAndLegality(typeConverter, patterns,
target, options);
torch_to_mhlo::populatePoolingOpPatternsAndLegality(typeConverter, patterns,
target, options);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
@ -82,13 +84,13 @@ public:
} // namespace
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToMhloPass() {
return std::make_unique<ConvertTorchToMhlo>(false, false);
mlir::torch::createConvertTorchToStablehloPass() {
return std::make_unique<ConvertTorchToStablehlo>(false, false);
}
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToMhloPass(bool enableStaticShape,
bool enableI32Index) {
return std::make_unique<ConvertTorchToMhlo>(enableStaticShape,
enableI32Index);
mlir::torch::createConvertTorchToStablehloPass(bool enableStaticShape,
bool enableI32Index) {
return std::make_unique<ConvertTorchToStablehlo>(enableStaticShape,
enableI32Index);
}

View File

@ -7,14 +7,15 @@
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "../PassDetail.h"
#include "./MhloLegalizeUtils.h"
#include "./PopulatePatterns.h"
#include "mhlo/IR/hlo_ops.h"
#include "PopulatePatterns.h"
#include "StablehloLegalizeUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@ -28,7 +29,7 @@ using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::TorchConversion;
using namespace mlir::torch::torch_to_mhlo;
using namespace mlir::torch::torch_to_stablehlo;
namespace {
// A dimension index from torch.dialect might outside the range [0, dimSize].
@ -100,7 +101,7 @@ Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op,
auto stridesTensor =
rewriter.create<tensor::FromElementsOp>(loc, strides).getResult();
return rewriter.create<mhlo::RealDynamicSliceOp>(
return rewriter.create<stablehlo::RealDynamicSliceOp>(
loc, outTy, input, startTensor, endTensor, stridesTensor);
}
@ -144,7 +145,7 @@ FailureOr<Value> getDynamicSlice(PatternRewriter &rewriter, Operation *op,
step = rewriter.create<arith::TruncIOp>(loc, intType, step);
}
FailureOr<SmallVector<Value, 4>> dimSizesInfo =
mhlo::getDimSizesOfTensor(rewriter, op, input, dimSizeIndexBits);
hlo::getDimSizesOfTensor(rewriter, op, input, dimSizeIndexBits);
if (failed(dimSizesInfo))
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
@ -179,7 +180,7 @@ public:
auto loc = op.getLoc();
auto newRank = dimSizes.size();
if (newRank == 0 || rankType.getRank() == 0) {
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
@ -214,17 +215,18 @@ public:
numel);
if (dimSizes.size() == 0) {
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
adaptor.getSelf());
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
adaptor.getSelf());
return success();
}
Value mhloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
Value computedShape = rewriter.create<mhlo::ComputeReshapeShapeOp>(
loc, mhloShape.getType(), numel, mhloShape);
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
Value stablehloShape =
rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
Value computedShape = rewriter.create<stablehlo::ComputeReshapeShapeOp>(
loc, stablehloShape.getType(), numel, stablehloShape);
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
@ -315,21 +317,21 @@ LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
dims.push_back(r);
}
if (dims.size() == 0) {
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), self);
return success();
}
auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims,
options.dimSizeIndexBits);
auto newDimSizesInfo = hlo::getDimSizesOfTensor(rewriter, op, self, dims,
options.dimSizeIndexBits);
if (failed(newDimSizesInfo))
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
auto newDimSizes = *newDimSizesInfo;
auto mhloShape =
auto stablehloShape =
rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), self, mhloShape);
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), self, stablehloShape);
return success();
}
@ -365,20 +367,20 @@ LogicalResult ConvertAtenOp<AtenSqueezeDimOp>::matchAndRewrite(
std::iota(dims.begin(), dims.end(), 0);
dims.erase(dims.begin() + dim);
if (dims.size() == 0) {
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), self);
return success();
}
auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims,
options.dimSizeIndexBits);
auto newDimSizesInfo = hlo::getDimSizesOfTensor(rewriter, op, self, dims,
options.dimSizeIndexBits);
if (failed(newDimSizesInfo))
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
auto newDimSizes = *newDimSizesInfo;
auto mhloShape =
auto stablehloShape =
rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), self, mhloShape);
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), self, stablehloShape);
return success();
}
@ -395,8 +397,8 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return op->emitError("dim must be a Scalar constant");
auto unsqzTensorInfo = mhlo::unsqueezeTensor(rewriter, op, adaptor.getSelf(),
{dim}, options.dimSizeIndexBits);
auto unsqzTensorInfo = hlo::unsqueezeTensor(rewriter, op, adaptor.getSelf(),
{dim}, options.dimSizeIndexBits);
if (failed(unsqzTensorInfo))
return rewriter.notifyMatchFailure(op,
"failed to create unsqueezed tensor");
@ -405,9 +407,9 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
return success();
}
void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToMhloOptions &options) {
ConversionTarget &target, const TorchToStablehloOptions &options) {
MLIRContext *context = patterns.getContext();
#define INSERT_ATENOP_PATTERN(AtenOp) \

View File

@ -11,7 +11,7 @@ set(LinkedLibs MLIRIR
TorchMLIRTorchConversionToMLProgram
MLIRMemRefTransforms)
if(TORCH_MLIR_ENABLE_MHLO)
if(TORCH_MLIR_ENABLE_STABLEHLO)
list(APPEND LinkedLibs ChloPasses)
endif()
@ -21,7 +21,7 @@ add_mlir_library(TorchMLIRTorchConversionPasses
Passes.cpp
VerifyLinalgOnTensorsBackendContract.cpp
VerifyTosaBackendContract.cpp
VerifyMhloBackendContract.cpp
VerifyStablehloBackendContract.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/TorchConversion/Transforms

View File

@ -21,9 +21,8 @@
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
#ifdef TORCH_MLIR_ENABLE_MHLO
#include "mhlo/transforms/passes.h"
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#endif
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
@ -53,12 +52,13 @@ void mlir::torch::registerTorchConversionPasses() {
"Pipeline lowering torch backend contract to TOSA backend "
"contract.",
TorchConversion::createTorchBackendToTosaBackendPipeline);
#ifdef TORCH_MLIR_ENABLE_MHLO
mlir::PassPipelineRegistration<TorchConversion::MhloBackendPipelineOptions>(
"torch-backend-to-mhlo-backend-pipeline",
"Pipeline lowering torch backend contract to MHLO backend "
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
mlir::PassPipelineRegistration<
TorchConversion::StablehloBackendPipelineOptions>(
"torch-backend-to-stablehlo-backend-pipeline",
"Pipeline lowering torch backend contract to StableHLO backend "
"contract.",
TorchConversion::createTorchBackendToMhloBackendPipeline);
TorchConversion::createTorchBackendToStablehloBackendPipeline);
#endif
}
@ -121,11 +121,12 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
pm.addPass(TorchConversion::createVerifyTosaBackendContractPass());
}
#ifdef TORCH_MLIR_ENABLE_MHLO
void TorchConversion::createTorchBackendToMhloBackendPipeline(
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
void TorchConversion::createTorchBackendToStablehloBackendPipeline(
OpPassManager &pm,
const TorchConversion::MhloBackendPipelineOptions &options) {
pm.addNestedPass<func::FuncOp>(createConvertTorchToMhloPass(
const TorchConversion::StablehloBackendPipelineOptions &options) {
// Generate Stablehlo ops.
pm.addNestedPass<func::FuncOp>(createConvertTorchToStablehloPass(
options.enableStaticShape, options.enableI32Index));
// Clean up any non-canonical code introduced above..
@ -133,21 +134,13 @@ void TorchConversion::createTorchBackendToMhloBackendPipeline(
// The resolution of `dim` ops tends to create identical ops. CSE them.
pm.addNestedPass<func::FuncOp>(createCSEPass());
// Convert CHLO ops to MHLO ops
pm.addNestedPass<func::FuncOp>(mhlo::createChloLegalizeToHloPass());
// Clean up any non-canonical code introduced above..
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// The resolution of `dim` ops tends to create identical ops. CSE them.
pm.addNestedPass<func::FuncOp>(createCSEPass());
// Finish the type conversion from `torch` types to the types of the
// MHLO backend contract.
// StableHLO backend contract.
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
pm.addNestedPass<func::FuncOp>(
TorchConversion::createFinalizingBackendTypeConversionPass());
// Verify that we have lowered to the form that MHLO backends
// expect. This fails compilation (signalPassFailure) if the IR is not in the
// correct form.
pm.addPass(TorchConversion::createVerifyMhloBackendContractPass());
// Verify that we have lowered to Stablehlo and Chlo ops.
pm.addPass(TorchConversion::createVerifyStablehloBackendContractPass());
}
#endif

View File

@ -6,10 +6,9 @@
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifdef TORCH_MLIR_ENABLE_MHLO
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
#include "PassDetail.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
@ -18,6 +17,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
using namespace mlir;
@ -25,17 +25,15 @@ using namespace mlir::torch;
using namespace mlir::torch::TorchConversion;
namespace {
class VerifyMhloBackendContractPass
: public VerifyMhloBackendContractBase<VerifyMhloBackendContractPass> {
class VerifyStablehloBackendContractPass
: public VerifyStablehloBackendContractBase<
VerifyStablehloBackendContractPass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
auto module = getOperation();
TypeConverter converter;
converter.addConversion([](Type type) -> Type {
auto elemTy = type;
if (isa<TensorType>(type)) {
if (isa<TensorType>(type))
elemTy = type.cast<TensorType>().getElementType();
}
if (BaseMemRefType::isValidElementType(elemTy))
return type;
return nullptr;
@ -43,6 +41,7 @@ class VerifyMhloBackendContractPass
auto opHasLegalTypes = [&](Operation *op) { return converter.isLegal(op); };
MLIRContext *context = &getContext();
ConversionTarget target(*context);
// Structural operations.
@ -50,26 +49,16 @@ class VerifyMhloBackendContractPass
// Shape operations.
target.addDynamicallyLegalOp<shape::ShapeOfOp>(opHasLegalTypes);
target.addLegalDialect<mhlo::MhloDialect>();
target.addLegalDialect<chlo::ChloDialect>();
target.addLegalDialect<stablehlo::StablehloDialect>();
target.addLegalDialect<tensor::TensorDialect>();
target.addLegalDialect<arith::ArithDialect>();
RewritePatternSet patterns(context);
if (failed(applyFullConversion(module, target, std::move(patterns)))) {
// We avoid `module.emitError()` so that mlir-print-op-on-diagnostics
// doesn't unnecessarily spew out the entire module.
emitError(module.getLoc())
<< "Module does not conform to the MHLO backend contract. "
"See dialect conversion legality information above.";
return signalPassFailure();
}
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::TorchConversion::createVerifyMhloBackendContractPass() {
return std::make_unique<VerifyMhloBackendContractPass>();
mlir::torch::TorchConversion::createVerifyStablehloBackendContractPass() {
return std::make_unique<VerifyStablehloBackendContractPass>();
}
#endif // TORCH_MLIR_ENABLE_MHLO
#endif // TORCH_MLIR_ENABLE_STABLEHLO

View File

@ -20,6 +20,10 @@
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
#include "torch-mlir/RefBackend/Passes.h"
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
#include "mhlo/transforms/passes.h"
#endif
void mlir::torch::registerAllDialects(mlir::DialectRegistry &registry) {
registry.insert<mlir::func::FuncDialect>();
registry.insert<mlir::torch::Torch::TorchDialect>();
@ -34,4 +38,11 @@ void mlir::torch::registerAllPasses() {
mlir::torch::registerConversionPasses();
mlir::torch::RefBackend::registerRefBackendPasses();
mlir::torch::TMTensor::registerPasses();
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
mlir::mhlo::registerSymbolicShapeOptimizationPass();
mlir::mhlo::registerStablehloLegalizeToHloPass();
mlir::mhlo::registerChloLegalizeToHloPass();
mlir::mhlo::registerHloLegalizeToLinalgPass();
#endif // TORCH_MLIR_ENABLE_STABLEHLO
}

View File

@ -44,9 +44,9 @@ class OutputType(Enum):
# as taking the `TORCH` output type and lowering it to TOSA.
TOSA = "tosa"
# This output type consists of `mhlo` dialect ops. It can be thought of
# as taking the `TORCH` output type and lowering it to MHLO.
MHLO = "mhlo"
# This output type consists of `stablehlo` dialect ops. It can be thought of
# as taking the `TORCH` output type and lowering it to StableHLO.
STABLEHLO = "stablehlo"
# Raw output of the JIT IR importer. This is not expected to be useful
# for end-users, but can be convenient for development or reporting bugs.
@ -242,7 +242,7 @@ class ExampleArgs:
BACKEND_LEGAL_OPS = {
OutputType.TOSA: ['torch.aten.flatten.using_ints', 'torch.aten.native_layer_norm', 'torch.aten.linear'],
OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints', ],
OutputType.MHLO: [],
OutputType.STABLEHLO: [],
}
@ -290,7 +290,7 @@ def compile(model: torch.nn.Module,
# We only allow `backend_legal_ops` to be specified for the `"torch"`
# output type because the other output types actually invoke their
# respective backends (Linalg, TOSA, or MHLO), and those backends have
# respective backends (Linalg, TOSA, or STABLEHLO), and those backends have
# very specific requirements about the ops which are legal.
# See `BACKEND_LEGAL_OPS` for more details.
if backend_legal_ops is not None:
@ -404,14 +404,14 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
print(mb.module)
return mb.module
elif output_type == OutputType.MHLO:
elif output_type == OutputType.STABLEHLO:
run_pipeline_with_repro_report(
mb.module,
"builtin.module(torch-backend-to-mhlo-backend-pipeline)",
"Lowering Torch Backend IR -> MHLO Backend IR")
"builtin.module(torch-backend-to-stablehlo-backend-pipeline)",
"Lowering Torch Backend IR -> StableHLO Backend IR")
if verbose:
print("\n====================")
print("MHLO Backend IR")
print("StableHLO Backend IR")
print(mb.module)
return mb.module
raise Exception(f"Unknown OutputType: {output_type}")

View File

@ -7,6 +7,6 @@ from .lazy_tensor_core import LazyTensorCoreTestConfig
from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig
from .native_torch import NativeTorchTestConfig
from .torchscript import TorchScriptTestConfig
from .mhlo_backend import MhloBackendTestConfig
from .stablehlo_backend import StablehloBackendTestConfig
from .tosa_backend import TosaBackendTestConfig
from .torchdynamo import TorchDynamoTestConfig

View File

@ -8,12 +8,8 @@ from typing import Any
import torch
import torch_mlir
from torch_mlir_e2e_test.mhlo_backends.abc import MhloBackend
from torch_mlir_e2e_test.framework import (
TestConfig,
Trace,
TraceItem
)
from torch_mlir_e2e_test.stablehlo_backends.abc import StablehloBackend
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
from torch_mlir_e2e_test.utils import convert_annotations_to_placeholders
from .utils import (
recursively_convert_to_numpy,
@ -21,20 +17,20 @@ from .utils import (
)
class MhloBackendTestConfig(TestConfig):
class StablehloBackendTestConfig(TestConfig):
"""Base class for TestConfig's that are implemented with linalg-on-tensors.
This class handles all the common lowering that torch-mlir does before
reaching the linalg-on-tensors abstraction level.
"""
def __init__(self, backend: MhloBackend):
def __init__(self, backend: StablehloBackend):
super().__init__()
self.backend = backend
def compile(self, program: torch.nn.Module) -> Any:
example_args = convert_annotations_to_placeholders(program.forward)
module = torch_mlir.compile(
program, example_args, output_type="mhlo")
module = torch_mlir.compile(program, example_args, output_type="stablehlo")
return self.backend.compile(module)
@ -46,7 +42,6 @@ class MhloBackendTestConfig(TestConfig):
outputs = getattr(backend_module, item.symbol)(*numpy_inputs)
output = recursively_convert_from_numpy(outputs)
result.append(
TraceItem(symbol=item.symbol,
inputs=item.inputs,
output=output))
TraceItem(symbol=item.symbol, inputs=item.inputs, output=output)
)
return result

View File

@ -10,29 +10,30 @@ import torch
from torch_mlir.ir import Module
# A type shared between the result of `MhloBackend.compile` and the
# input to `MhloBackend.load`. Each backend will likely have a
# A type shared between the result of `StablehloBackend.compile` and the
# input to `StablehloBackend.load`. Each backend will likely have a
# different definition of this type.
CompiledArtifact = TypeVar('CompiledArtifact')
CompiledArtifact = TypeVar("CompiledArtifact")
# A wrapper around a backend-specific loaded program representation
# that uniformly translates the `x.method(...)` interface expected of
# Torch modules into appropriate lower-level operations.
Invoker = TypeVar('Invoker')
Invoker = TypeVar("Invoker")
class MhloBackend(abc.ABC):
"""The interface to an MHLO backend.
class StablehloBackend(abc.ABC):
"""The interface to an StableHLO backend.
Backends are recommended to raise meaningful exceptions in case of error,
ideally with easy reproduction instructions.
"""
@abc.abstractmethod
def compile(self, module: Module) -> CompiledArtifact:
"""Compile the provided MLIR module into a compiled artifact.
The module adheres to the MHLO backend contract
(see the VerifyMhloBackendContract pass).
The module adheres to the StableHLO backend contract
(see the VerifyStablehloBackendContract pass).
The compiled artifact can be any type, but must be correctly
interpreted by the `load` method.

View File

@ -7,28 +7,32 @@ from torch_mlir.ir import *
from torch_mlir.passmanager import *
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
RefBackendLinalgOnTensorsBackend,
)
from .abc import MhloBackend
from .abc import StablehloBackend
__all__ = [
"LinalgOnTensorsMhloBackend",
"LinalgOnTensorsStablehloBackend",
]
class LinalgOnTensorsMhloBackend(MhloBackend):
"""Main entry-point for the linalg-on-tensors based MHLO backend.
class LinalgOnTensorsStablehloBackend(StablehloBackend):
"""Main entry-point for the linalg-on-tensors based StableHLO backend.
This currently uses the linalg-on-tensors RefBackend for actual execution.
"""
def __init__(self):
super().__init__()
self.refbackend = RefBackendLinalgOnTensorsBackend()
def compile(self, imported_module: Module):
"""Compiles an imported module that satisfied the MHLO backend contract.
"""Compiles an imported module that satisfied the StableHLO backend contract.
Args:
imported_module: The MLIR module consisting of funcs in the MHLO
imported_module: The MLIR module consisting of funcs in the StableHLO
dialect.
Returns:
An opaque, backend specific compiled artifact object that can be
@ -36,8 +40,9 @@ class LinalgOnTensorsMhloBackend(MhloBackend):
"""
run_pipeline_with_repro_report(
imported_module,
"builtin.module(func.func(symbolic-shape-optimization),func.func(hlo-legalize-to-linalg),func.func(canonicalize))",
"Lowering MLIR-HLO to Linalg-on-Tensors")
"builtin.module(func.func(chlo-legalize-to-hlo),stablehlo-legalize-to-hlo,func.func(canonicalize,cse,symbolic-shape-optimization,hlo-legalize-to-linalg,canonicalize))",
"Lowering StableHLO to Linalg-on-Tensors",
)
return self.refbackend.compile(imported_module)
def load(self, module):

View File

@ -1,7 +1,7 @@
llvm_canonicalize_cmake_booleans(
MLIR_ENABLE_BINDINGS_PYTHON
TORCH_MLIR_ENABLE_JIT_IR_IMPORTER
TORCH_MLIR_ENABLE_MHLO
TORCH_MLIR_ENABLE_STABLEHLO
)
configure_lit_site_cfg(

View File

@ -1,4 +1,4 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s
// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s
// -----
@ -7,7 +7,7 @@
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[T1:.*]] = mhlo.copy %[[T0]] : tensor<?x?xf32>
// CHECK: %[[T1:.*]] = stablehlo.convert %[[T0]] : tensor<?x?xf32>
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.clone$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@ -19,7 +19,7 @@ func.func @torch.aten.clone$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vt
// -----
// CHECK-LABEL: func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> {
// CHECK: %[[VAL_0:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VAL_0:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<f32> -> !torch.vtensor<[],f32>
// CHECK: return %[[VAL_1]] : !torch.vtensor<[],f32>
func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> {
@ -30,7 +30,7 @@ func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> {
// -----
// CHECK-LABEL: func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> {
// CHECK: %[[VAL_0:.*]] = mhlo.constant dense<1> : tensor<2xi64>
// CHECK: %[[VAL_0:.*]] = stablehlo.constant dense<1> : tensor<2xi64>
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<2xi64> -> !torch.vtensor<[2],si64>
// CHECK: return %[[VAL_1]] : !torch.vtensor<[2],si64>
func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> {
@ -45,8 +45,8 @@ func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[T0:.*]] = torch_c.to_i64 %[[INT1]]
// CHECK: %[[T1:.*]] = tensor.from_elements %[[T0]] : tensor<1xi64>
// CHECK: %[[T2:.*]] = mhlo.convert %[[T1]] : tensor<1xi64>
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xi64>) -> tensor<i64>
// CHECK: %[[T2:.*]] = stablehlo.convert %[[T1]] : tensor<1xi64>
// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xi64>) -> tensor<i64>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<i64> -> !torch.vtensor<[],si64>
// CHECK: return %[[T4]] : !torch.vtensor<[],si64>
func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[], si64> {
@ -75,7 +75,7 @@ func.func @torch.aten.contiguous(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vt
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
// CHECK: %[[VAL_2:.*]] = "chlo.constant_like"(%[[VAL_1]]) {value = 1.000000e+00 : f32} : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: %[[VAL_3:.*]] = mhlo.divide %[[VAL_2]], %[[VAL_1]] : tensor<?x?x?xf32>
// CHECK: %[[VAL_3:.*]] = stablehlo.divide %[[VAL_2]], %[[VAL_1]] : tensor<?x?x?xf32>
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?,?],f32>
func.func @torch.aten.reciprocal(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
@ -91,7 +91,7 @@ func.func @torch.aten.reciprocal(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.v
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],f32> -> tensor<4x3xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
// CHECK: %[[VAL_4:.*]] = "mhlo.transpose"(%[[VAL_1]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x3xf32>) -> tensor<3x4xf32>
// CHECK: %[[VAL_4:.*]] = stablehlo.transpose %[[VAL_1]], dims = [1, 0] : (tensor<4x3xf32>) -> tensor<3x4xf32>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[3,4],f32>
func.func @torch.aten.transpose$basic(%arg0: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,4],f32> {
@ -118,7 +118,7 @@ func.func @torch.aten.transpose$basic(%arg0: !torch.vtensor<[4,3],f32>) -> !torc
// CHECK: %[[VAL_7:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1:.*]], %[[VAL_7]] : tensor<?x?xf32>
// CHECK: %[[VAL_9:.*]] = tensor.from_elements %[[VAL_4]], %[[VAL_6]], %[[VAL_8]] : tensor<3xindex>
// CHECK: %[[VAL_10:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_1]], %[[VAL_9]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<3xindex>) -> tensor<8x4x?xf32>
// CHECK: %[[VAL_10:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_1]], %[[VAL_9]], dims = [1, 2] : (tensor<?x?xf32>, tensor<3xindex>) -> tensor<8x4x?xf32>
// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<8x4x?xf32> -> !torch.vtensor<[8,4,?],f32>
// CHECK: return %[[VAL_11]] : !torch.vtensor<[8,4,?],f32>
func.func @torch.aten.broadcast_to$dynamic_implicit(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[8,4,?],f32> {
@ -135,15 +135,15 @@ func.func @torch.aten.broadcast_to$dynamic_implicit(%arg0: !torch.vtensor<[?,?],
// CHECK-LABEL: func.func @torch.aten.batch_norm$training(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor<?x3x?x?xf32>
// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32>
// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32>
// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<3xf32>
// CHECK: %[[VAL_3:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<3xf32>
// CHECK: %true = torch.constant.bool true
// CHECK: %float1.000000e-01 = torch.constant.float 1.000000e-01
// CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05
// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x3x?x?xf32>
// CHECK: %[[VAL_6:.*]] = tensor.from_elements %[[VAL_5]] : tensor<1xindex>
// CHECK: %[[VAL_7:.*]], %[[VAL_8:.*]], %[[VAL_9:.*]] = "mhlo.batch_norm_training"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>)
// CHECK: %[[VAL_7:.*]], %[[VAL_8:.*]], %[[VAL_9:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>)
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,3,?,?],f32>
func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
@ -161,8 +161,8 @@ func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>)
// CHECK-LABEL: func.func @torch.aten.batch_norm$inference(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor<?x3x?x?xf32>
// CHECK: %[[T1:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32>
// CHECK: %[[T2:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32>
// CHECK: %[[T1:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<3xf32>
// CHECK: %[[T2:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<3xf32>
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[FLOAT1:.*]].000000e-01 = torch.constant.float 1.000000e-01
@ -171,7 +171,7 @@ func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>)
// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x3x?x?xf32>
// CHECK: %[[T4:.*]] = tensor.from_elements %[[T3]] : tensor<1xindex>
// CHECK: %[[T5:.*]] = tensor.cast %[[T0]] : tensor<?x3x?x?xf32> to tensor<?x3x?x?xf32>
// CHECK: %[[T6:.*]] = "mhlo.batch_norm_inference"(%[[T5]], %[[T2]], %[[T1]], %[[T1]], %[[T2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<?x3x?x?xf32>
// CHECK: %[[T6:.*]] = "stablehlo.batch_norm_inference"(%[[T5]], %[[T2]], %[[T1]], %[[T1]], %[[T2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<?x3x?x?xf32>
// CHECK: %[[T7:.*]] = tensor.cast %[[T6]] : tensor<?x3x?x?xf32> to tensor<?x3x?x?xf32>
// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
// CHECK: return %[[T8]] : !torch.vtensor<[?,3,?,?],f32>
@ -192,19 +192,19 @@ func.func @torch.aten.batch_norm$inference(%arg0: !torch.vtensor<[?,3,?,?],f32>)
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor<?x3x?x?xf32>
// CHECK: %none = torch.constant.none
// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32>
// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32>
// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<3xf32>
// CHECK: %[[VAL_3:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<3xf32>
// CHECK: %true = torch.constant.bool true
// CHECK: %float1.000000e-01 = torch.constant.float 1.000000e-01
// CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05
// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x3x?x?xf32>
// CHECK: %[[VAL_6:.*]] = tensor.from_elements %[[VAL_5]] : tensor<1xindex>
// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
// CHECK: %[[VAL_8:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_7]], %[[VAL_6]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<3xf32>
// CHECK: %[[VAL_9:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VAL_10:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_9]], %[[VAL_6]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<3xf32>
// CHECK: %[[VAL_11:.*]], %[[VAL_12:.*]], %[[VAL_13:.*]] = "mhlo.batch_norm_training"(%[[VAL_1]], %[[VAL_8]], %[[VAL_10]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>)
// CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f32>
// CHECK: %[[VAL_8:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_6]], dims = [] : (tensor<f32>, tensor<1xindex>) -> tensor<3xf32>
// CHECK: %[[VAL_9:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VAL_10:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_9]], %[[VAL_6]], dims = [] : (tensor<f32>, tensor<1xindex>) -> tensor<3xf32>
// CHECK: %[[VAL_11:.*]], %[[VAL_12:.*]], %[[VAL_13:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_8]], %[[VAL_10]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>)
// CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
// CHECK: return %[[VAL_14]] : !torch.vtensor<[?,3,?,?],f32>
func.func @torch.aten.batch_norm$no_bias_weight(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
@ -222,28 +222,28 @@ func.func @torch.aten.batch_norm$no_bias_weight(%arg0: !torch.vtensor<[?,3,?,?],
// CHECK-LABEL: func @torch.aten.native_layer_norm(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,7,4,5],f32>) -> !torch.vtensor<[3,7,4,5],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,7,4,5],f32> -> tensor<3x7x4x5xf32>
// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<4x5xf32>
// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<4x5xf32>
// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<4x5xf32>
// CHECK: %[[VAL_3:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<4x5xf32>
// CHECK: %int4 = torch.constant.int 4
// CHECK: %int5 = torch.constant.int 5
// CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05
// CHECK: %true = torch.constant.bool true
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<[1, 21, 20]> : tensor<3xi64>
// CHECK: %[[VAL_6:.*]] = mhlo.dynamic_reshape %[[VAL_1]], %[[VAL_5]] : (tensor<3x7x4x5xf32>, tensor<3xi64>) -> tensor<1x21x20xf32>
// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor<21xf32>
// CHECK: %[[VAL_8:.*]] = mhlo.constant dense<0.000000e+00> : tensor<21xf32>
// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]] = "mhlo.batch_norm_training"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) -> (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>)
// CHECK: %[[VAL_12:.*]] = mhlo.constant dense<[3, 7, 4, 5]> : tensor<4xi64>
// CHECK: %[[VAL_13:.*]] = mhlo.dynamic_reshape %[[VAL_9]], %[[VAL_12]] : (tensor<1x21x20xf32>, tensor<4xi64>) -> tensor<3x7x4x5xf32>
// CHECK: %[[VAL_14:.*]] = mhlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64>
// CHECK: %[[VAL_15:.*]] = mhlo.dynamic_reshape %[[VAL_10]], %[[VAL_14]] : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32>
// CHECK: %[[VAL_16:.*]] = mhlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64>
// CHECK: %[[VAL_17:.*]] = mhlo.dynamic_reshape %[[VAL_11]], %[[VAL_16]] : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32>
// CHECK: %[[VAL_18:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_3]]) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32>
// CHECK: %[[VAL_19:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_2]]) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32>
// CHECK: %[[VAL_20:.*]] = mhlo.multiply %[[VAL_13]], %[[VAL_18]] : tensor<3x7x4x5xf32>
// CHECK: %[[VAL_21:.*]] = mhlo.add %[[VAL_20]], %[[VAL_19]] : tensor<3x7x4x5xf32>
// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<[1, 21, 20]> : tensor<3xi64>
// CHECK: %[[VAL_6:.*]] = stablehlo.dynamic_reshape %[[VAL_1]], %[[VAL_5]] : (tensor<3x7x4x5xf32>, tensor<3xi64>) -> tensor<1x21x20xf32>
// CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<21xf32>
// CHECK: %[[VAL_8:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<21xf32>
// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]] = "stablehlo.batch_norm_training"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) -> (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>)
// CHECK: %[[VAL_12:.*]] = stablehlo.constant dense<[3, 7, 4, 5]> : tensor<4xi64>
// CHECK: %[[VAL_13:.*]] = stablehlo.dynamic_reshape %[[VAL_9]], %[[VAL_12]] : (tensor<1x21x20xf32>, tensor<4xi64>) -> tensor<3x7x4x5xf32>
// CHECK: %[[VAL_14:.*]] = stablehlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64>
// CHECK: %[[VAL_15:.*]] = stablehlo.dynamic_reshape %[[VAL_10]], %[[VAL_14]] : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32>
// CHECK: %[[VAL_16:.*]] = stablehlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64>
// CHECK: %[[VAL_17:.*]] = stablehlo.dynamic_reshape %[[VAL_11]], %[[VAL_16]] : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32>
// CHECK: %[[VAL_18:.*]] = stablehlo.broadcast_in_dim %[[VAL_3]], dims = [2, 3] : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32>
// CHECK: %[[VAL_19:.*]] = stablehlo.broadcast_in_dim %[[VAL_2]], dims = [2, 3] : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32>
// CHECK: %[[VAL_20:.*]] = stablehlo.multiply %[[VAL_13]], %[[VAL_18]] : tensor<3x7x4x5xf32>
// CHECK: %[[VAL_21:.*]] = stablehlo.add %[[VAL_20]], %[[VAL_19]] : tensor<3x7x4x5xf32>
// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21:.*]] : tensor<3x7x4x5xf32> -> !torch.vtensor<[3,7,4,5],f32>
// CHECK: return %[[VAL_22]] : !torch.vtensor<[3,7,4,5],f32>
func.func @torch.aten.native_layer_norm(%arg0: !torch.vtensor<[3,7,4,5],f32>) -> !torch.vtensor<[3,7,4,5],f32> {
@ -267,8 +267,8 @@ func.func @torch.aten.native_layer_norm(%arg0: !torch.vtensor<[3,7,4,5],f32>) ->
// CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[ARG0]], %[[ARG1]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],si32>) -> !torch.list<vtensor>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T2:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : (tensor<?x?xi32>) -> tensor<?x?xf32>
// CHECK: %[[T4:.*]] = "mhlo.concatenate"(%[[T1]], %[[T3]]) {dimension = 0 : i64} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T3:.*]] = stablehlo.convert %[[T2]] : (tensor<?x?xi32>) -> tensor<?x?xf32>
// CHECK: %[[T4:.*]] = stablehlo.concatenate %[[T1]], %[[T3]], dim = 0 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.cat$convert(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],f32> {
@ -287,7 +287,7 @@ func.func @torch.aten.cat$convert(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torc
// CHECK: %[[VAL_0:.*]] = torch.prim.ListConstruct %[[ARG_0]], %[[ARG_1]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list<vtensor>
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = "mhlo.concatenate"(%[[VAL_1]], %[[VAL_2]]) {dimension = 0 : i64} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = stablehlo.concatenate %[[VAL_1]], %[[VAL_2]], dim = 0 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.cat(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {

View File

@ -1,4 +1,4 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s
// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s
// CHECK-LABEL: func.func @torch.aten.gelu(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@ -7,12 +7,12 @@
// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) {value = 1.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T2:.*]] = "chlo.constant_like"(%[[T0]]) {value = 2.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T3:.*]] = "chlo.constant_like"(%[[T0]]) {value = 5.000000e-01 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T4:.*]] = mhlo.rsqrt %[[T2]] : tensor<?x?xf32>
// CHECK: %[[T5:.*]] = mhlo.multiply %[[T0]], %[[T4]] : tensor<?x?xf32>
// CHECK: %[[T4:.*]] = stablehlo.rsqrt %[[T2]] : tensor<?x?xf32>
// CHECK: %[[T5:.*]] = stablehlo.multiply %[[T0]], %[[T4]] : tensor<?x?xf32>
// CHECK: %[[T6:.*]] = chlo.erf %[[T5]] : tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[T7:.*]] = mhlo.add %[[T6]], %[[T1]] : tensor<?x?xf32>
// CHECK: %[[T8:.*]] = mhlo.multiply %[[T7]], %[[T3]] : tensor<?x?xf32>
// CHECK: %[[T9:.*]] = mhlo.multiply %[[T0]], %[[T8]] : tensor<?x?xf32>
// CHECK: %[[T7:.*]] = stablehlo.add %[[T6]], %[[T1]] : tensor<?x?xf32>
// CHECK: %[[T8:.*]] = stablehlo.multiply %[[T7]], %[[T3]] : tensor<?x?xf32>
// CHECK: %[[T9:.*]] = stablehlo.multiply %[[T0]], %[[T8]] : tensor<?x?xf32>
// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T10]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.gelu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@ -26,7 +26,7 @@ func.func @torch.aten.gelu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
// CHECK-LABEL: func.func @torch.aten.tanh$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = mhlo.tanh %[[T0]] : tensor<?x?xf32>
// CHECK: %[[T1:.*]] = stablehlo.tanh %[[T0]] : tensor<?x?xf32>
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@ -39,7 +39,7 @@ func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte
// CHECK-LABEL: func.func @torch.aten.log$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = mhlo.log %[[T0]] : tensor<?x?xf32>
// CHECK: %[[T1:.*]] = stablehlo.log %[[T0]] : tensor<?x?xf32>
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.log$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@ -52,7 +52,7 @@ func.func @torch.aten.log$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten
// CHECK-LABEL: func.func @torch.aten.exp$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = mhlo.exponential %[[T0]] : tensor<?x?xf32>
// CHECK: %[[T1:.*]] = stablehlo.exponential %[[T0]] : tensor<?x?xf32>
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.exp$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@ -65,7 +65,7 @@ func.func @torch.aten.exp$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten
// CHECK-LABEL: func.func @torch.aten.neg$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = mhlo.negate %[[T0]] : tensor<?x?xf32>
// CHECK: %[[T1:.*]] = stablehlo.negate %[[T0]] : tensor<?x?xf32>
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.neg$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@ -78,7 +78,7 @@ func.func @torch.aten.neg$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten
// CHECK-LABEL: func.func @torch.aten.rsqrt$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = mhlo.rsqrt %[[T0]] : tensor<?x?xf32>
// CHECK: %[[T1:.*]] = stablehlo.rsqrt %[[T0]] : tensor<?x?xf32>
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@ -91,7 +91,7 @@ func.func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vt
// CHECK-LABEL: func.func @torch.aten.sigmoid$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = mhlo.logistic %[[T0]] : tensor<?x?xf32>
// CHECK: %[[T1:.*]] = stablehlo.logistic %[[T0]] : tensor<?x?xf32>
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.sigmoid$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@ -108,8 +108,8 @@ func.func @torch.aten.sigmoid$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T4:.*]] = chlo.broadcast_add %[[T0]], %[[T3]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
@ -130,11 +130,11 @@ func.func @torch.aten.addscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
// CHECK: %[[T3:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[FROM_ELEMENTS_0:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64>
// CHECK: %[[T5:.*]] = mhlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T6:.*]] = mhlo.reshape %[[T5]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T5:.*]] = stablehlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T6:.*]] = stablehlo.reshape %[[T5]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T7:.*]] = chlo.broadcast_multiply %[[T4]], %[[T6]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: %[[T8:.*]] = chlo.broadcast_add %[[T0]], %[[T7]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
@ -171,8 +171,8 @@ func.func @torch.aten.addtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64>
// CHECK: %[[T3:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T5:.*]] = chlo.broadcast_multiply %[[T1]], %[[T4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T6:.*]] = chlo.broadcast_add %[[T0]], %[[T5]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
@ -190,7 +190,7 @@ func.func @torch.aten.addtensor$alpha(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[T2:.*]] = mhlo.convert %[[T0]] : (tensor<?x?xi32>) -> tensor<?x?xi64>
// CHECK: %[[T2:.*]] = stablehlo.convert %[[T0]] : (tensor<?x?xi32>) -> tensor<?x?xi64>
// CHECK: %[[T3:.*]] = chlo.broadcast_add %[[T2]], %[[T1]] : (tensor<?x?xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
// CHECK: return %[[T4]] : !torch.vtensor<[?,?],si64>
@ -209,8 +209,8 @@ func.func @torch.aten.addtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T4:.*]] = chlo.broadcast_subtract %[[T0]], %[[T3]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
@ -230,8 +230,8 @@ func.func @torch.aten.subscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T4:.*]] = chlo.broadcast_subtract %[[T3]], %[[T0]] : (tensor<f32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
@ -252,11 +252,11 @@ func.func @torch.aten.rsubscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !tor
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
// CHECK: %[[T3:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[FROM_ELEMENTS_0:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64>
// CHECK: %[[T5:.*]] = mhlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T6:.*]] = mhlo.reshape %[[T5]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T5:.*]] = stablehlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T6:.*]] = stablehlo.reshape %[[T5]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T7:.*]] = chlo.broadcast_multiply %[[T4]], %[[T6]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: %[[T8:.*]] = chlo.broadcast_subtract %[[T0]], %[[T7]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
@ -293,8 +293,8 @@ func.func @torch.aten.subtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64>
// CHECK: %[[T3:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T5:.*]] = chlo.broadcast_multiply %[[T1]], %[[T4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T6:.*]] = chlo.broadcast_subtract %[[T0]], %[[T5]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
@ -312,7 +312,7 @@ func.func @torch.aten.subtensor$alpha(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[T2:.*]] = mhlo.convert %[[T0]] : (tensor<?x?xi32>) -> tensor<?x?xi64>
// CHECK: %[[T2:.*]] = stablehlo.convert %[[T0]] : (tensor<?x?xi32>) -> tensor<?x?xi64>
// CHECK: %[[T3:.*]] = chlo.broadcast_subtract %[[T2]], %[[T1]] : (tensor<?x?xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
// CHECK: return %[[T4]] : !torch.vtensor<[?,?],si64>
@ -330,8 +330,8 @@ func.func @torch.aten.subtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1
// CHECK: %[[INT9:.*]] = torch.constant.int 9
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T4:.*]] = chlo.broadcast_multiply %[[T0]], %[[T3]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
@ -363,8 +363,8 @@ func.func @torch.aten.multensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
// CHECK: %[[INT9:.*]] = torch.constant.int 9
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T4:.*]] = chlo.broadcast_divide %[[T0]], %[[T3]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
@ -396,8 +396,8 @@ func.func @torch.aten.divtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT3]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T4:.*]] = chlo.broadcast_compare %[[T0]], %[[T3]] {compare_type = #chlo<comparison_type FLOAT>, comparison_direction = #chlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],i1>
@ -471,7 +471,7 @@ func.func @torch.aten.ne.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T2:.*]] = "mhlo.transpose"(%[[T0]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x64xf32>) -> tensor<64x4xf32>
// CHECK: %[[T2:.*]] = stablehlo.transpose %[[T0]], dims = [1, 0] : (tensor<4x64xf32>) -> tensor<64x4xf32>
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<64x4xf32> -> !torch.vtensor<[64,4],f32>
// CHECK: return %[[T3]] : !torch.vtensor<[64,4],f32>
func.func @torch.aten.permute$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[64,4],f32> {
@ -488,7 +488,7 @@ func.func @torch.aten.permute$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) {value = 0.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T2:.*]] = mhlo.maximum %[[T0]], %[[T1]] : tensor<?x?xf32>
// CHECK: %[[T2:.*]] = stablehlo.maximum %[[T0]], %[[T1]] : tensor<?x?xf32>
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T3]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.relu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@ -503,11 +503,11 @@ func.func @torch.aten.relu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = torch_c.to_f64 %[[ARG1]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xf64>
// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xf64>) -> tensor<1xf32>
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xf64>) -> tensor<1xf32>
// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[FROM_ELEMENTS_0:.*]] = tensor.from_elements %[[T1]] : tensor<1xf64>
// CHECK: %[[T4:.*]] = mhlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xf64>) -> tensor<1xf32>
// CHECK: %[[T5:.*]] = mhlo.reshape %[[T4]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T4:.*]] = stablehlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xf64>) -> tensor<1xf32>
// CHECK: %[[T5:.*]] = stablehlo.reshape %[[T4]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T6:.*]] = chlo.broadcast_multiply %[[T3]], %[[T5]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: %[[T7:.*]] = chlo.broadcast_add %[[T0]], %[[T6]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
@ -525,8 +525,8 @@ func.func @torch.aten.addscalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T2:.*]] = torch_c.to_f64 %[[ARG2]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]] : tensor<1xf64>
// CHECK: %[[T3:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xf64>) -> tensor<1xf32>
// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xf64>) -> tensor<1xf32>
// CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T5:.*]] = chlo.broadcast_multiply %[[T1]], %[[T4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T6:.*]] = chlo.broadcast_add %[[T0]], %[[T5]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
@ -543,8 +543,8 @@ func.func @torch.aten.addtensor$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[ARG1]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T4:.*]] = chlo.broadcast_multiply %[[T0]], %[[T3]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
@ -560,8 +560,8 @@ func.func @torch.aten.mulscalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[ARG1]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T4:.*]] = chlo.broadcast_divide %[[T0]], %[[T3]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
@ -577,8 +577,8 @@ func.func @torch.aten.divscalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[ARG1]]
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T4:.*]] = chlo.broadcast_compare %[[T0]], %[[T3]] {compare_type = #chlo<comparison_type FLOAT>, comparison_direction = #chlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],i1>
@ -595,10 +595,10 @@ func.func @torch.aten.gt.scalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[STR:.*]] = torch.constant.str "trunc"
// CHECK: %[[T2:.*]] = chlo.broadcast_divide %[[T0]], %[[T1]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[T3:.*]] = mhlo.sign %[[T2]] : tensor<?x?x?x?xf32>
// CHECK: %[[T4:.*]] = mhlo.abs %[[T2]] : tensor<?x?x?x?xf32>
// CHECK: %[[T5:.*]] = mhlo.floor %[[T4]] : tensor<?x?x?x?xf32>
// CHECK: %[[T6:.*]] = mhlo.multiply %[[T3]], %[[T5]] : tensor<?x?x?x?xf32>
// CHECK: %[[T3:.*]] = stablehlo.sign %[[T2]] : tensor<?x?x?x?xf32>
// CHECK: %[[T4:.*]] = stablehlo.abs %[[T2]] : tensor<?x?x?x?xf32>
// CHECK: %[[T5:.*]] = stablehlo.floor %[[T4]] : tensor<?x?x?x?xf32>
// CHECK: %[[T6:.*]] = stablehlo.multiply %[[T3]], %[[T5]] : tensor<?x?x?x?xf32>
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[T7]] : !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten.div.Tensor_mode$trunc(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
@ -615,7 +615,7 @@ func.func @torch.aten.div.Tensor_mode$trunc(%arg0: !torch.vtensor<[?,?,?,?],f32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[STR:.*]] = torch.constant.str "floor"
// CHECK: %[[T2:.*]] = chlo.broadcast_divide %[[T0]], %[[T1]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[T3:.*]] = mhlo.floor %[[T2]] : tensor<?x?x?x?xf32>
// CHECK: %[[T3:.*]] = stablehlo.floor %[[T2]] : tensor<?x?x?x?xf32>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[T4]] : !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten.div.Tensor_mode$floor(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {

View File

@ -1,4 +1,4 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s
// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s
// CHECK-LABEL: func.func @torch.aten.index_select$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,4],f32>, %[[ARG1:.*]]: !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,4],f32> {
@ -10,8 +10,8 @@
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x4xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
// CHECK: %[[T5:.*]] = "mhlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #mhlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false} : (tensor<?x4xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<2x4xf32>
// CHECK: %[[T6:.*]] = mhlo.convert %[[T5]] : tensor<2x4xf32>
// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false} : (tensor<?x4xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<2x4xf32>
// CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor<2x4xf32>
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32>
// CHECK: return %[[T7]] : !torch.vtensor<[2,4],f32>
func.func @torch.aten.index_select$basic(%arg0: !torch.vtensor<[?,4],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,4],f32> {
@ -31,8 +31,8 @@ func.func @torch.aten.index_select$basic(%arg0: !torch.vtensor<[?,4],f32>, %arg1
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
// CHECK: %[[T5:.*]] = "mhlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #mhlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false} : (tensor<?x?xf32>, tensor<?xi64>, tensor<2xi64>) -> tensor<?x?xf32>
// CHECK: %[[T6:.*]] = mhlo.convert %[[T5]] : tensor<?x?xf32>
// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false} : (tensor<?x?xf32>, tensor<?xi64>, tensor<2xi64>) -> tensor<?x?xf32>
// CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor<?x?xf32>
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T7]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.embedding$basic(%weight: !torch.vtensor<[?,?],f32>, %indices: !torch.vtensor<[?], si64>) -> !torch.vtensor<[?,?],f32> {
@ -53,8 +53,8 @@ func.func @torch.aten.embedding$basic(%weight: !torch.vtensor<[?,?],f32>, %indic
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
// CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
// CHECK: %[[T5:.*]] = "mhlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #mhlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false} : (tensor<?x?xf32>, tensor<?x1xi64>, tensor<2xi64>) -> tensor<?x1x?xf32>
// CHECK: %[[T6:.*]] = mhlo.convert %[[T5]] : tensor<?x1x?xf32>
// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false} : (tensor<?x?xf32>, tensor<?x1xi64>, tensor<2xi64>) -> tensor<?x1x?xf32>
// CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor<?x1x?xf32>
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x1x?xf32> -> !torch.vtensor<[?,1,?],f32>
// CHECK: return %[[T7]] : !torch.vtensor<[?,1,?],f32>
func.func @torch.aten.embedding$rank_two_indices(%weight: !torch.vtensor<[?,?],f32>, %indices: !torch.vtensor<[?,1], si64>) -> !torch.vtensor<[?,1,?],f32> {

View File

@ -1,10 +1,10 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s
// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s
// CHECK-LABEL: func.func @torch.aten.mm$basic$static(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,3],f32>, %[[ARG1:.*]]: !torch.vtensor<[3,3],f32>) -> !torch.vtensor<[2,3],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[3,3],f32> -> tensor<3x3xf32>
// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<2x3xf32>, tensor<3x3xf32>) -> tensor<2x3xf32>
// CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor<2x3xf32>, tensor<3x3xf32>) -> tensor<2x3xf32>
// CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<2x3xf32> to tensor<2x3xf32>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
// CHECK: return %[[T4]] : !torch.vtensor<[2,3],f32>
@ -19,7 +19,7 @@ func.func @torch.aten.mm$basic$static(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3],f32>, %[[ARG1:.*]]: !torch.vtensor<[3,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,3],f32> -> tensor<?x3xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[3,?],f32> -> tensor<3x?xf32>
// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<?x3xf32>, tensor<3x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor<?x3xf32>, tensor<3x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T4]] : !torch.vtensor<[?,?],f32>
@ -44,8 +44,8 @@ func.func @torch.aten.mm$basic$dynamic(%arg0: !torch.vtensor<[?,3],f32>, %arg1:
// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C2]] : tensor<10x4x5xf32>
// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<10x4x5xf32>, tensor<3xi64>) -> tensor<10x4x5xf32>
// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<10x3x4xf32>, tensor<10x4x5xf32>) -> tensor<10x3x5xf32>
// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [0, 1, 2] : (tensor<10x4x5xf32>, tensor<3xi64>) -> tensor<10x4x5xf32>
// CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<10x3x4xf32>, tensor<10x4x5xf32>) -> tensor<10x3x5xf32>
// CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<10x3x5xf32> to tensor<10x3x5xf32>
// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<10x3x5xf32> -> !torch.vtensor<[10,3,5],f32>
// CHECK: return %[[T12]] : !torch.vtensor<[10,3,5],f32>
@ -70,8 +70,8 @@ func.func @torch.aten.bmm$basic$static(%arg0: !torch.vtensor<[10,3,4],f32>, %arg
// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C2]] : tensor<?x4x?xf32>
// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<?x4x?xf32>, tensor<3xi64>) -> tensor<?x4x?xf32>
// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<?x?x4xf32>, tensor<?x4x?xf32>) -> tensor<?x?x?xf32>
// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [0, 1, 2] : (tensor<?x4x?xf32>, tensor<3xi64>) -> tensor<?x4x?xf32>
// CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<?x?x4xf32>, tensor<?x4x?xf32>) -> tensor<?x?x?xf32>
// CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[T12]] : !torch.vtensor<[?,?,?],f32>
@ -96,8 +96,8 @@ func.func @torch.aten.bmm$basic$dynamic(%arg0: !torch.vtensor<[?,?,4],f32>, %arg
// CHECK: %[[T6:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<256x120xf32>
// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T0]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x120xf32>, tensor<3xi64>) -> tensor<4x256x120xf32>
// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T9]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<4x256x120xf32>, tensor<4x120x256xf32>) -> tensor<4x256x256xf32>
// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T0]], %[[T8]], dims = [1, 2] : (tensor<256x120xf32>, tensor<3xi64>) -> tensor<4x256x120xf32>
// CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T9]], %[[T1]]) {dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<4x256x120xf32>, tensor<4x120x256xf32>) -> tensor<4x256x256xf32>
// CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<4x256x256xf32> to tensor<4x256x256xf32>
// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x256x256xf32> -> !torch.vtensor<[4,256,256],f32>
// CHECK: return %[[T12]] : !torch.vtensor<[4,256,256],f32>
@ -122,8 +122,8 @@ func.func @torch.aten.matmul$basic$static(%arg0: !torch.vtensor<[256,120],f32>,
// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<256x?xf32>
// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x?xf32>, tensor<3xi64>) -> tensor<4x256x?xf32>
// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<4x?x256xf32>, tensor<4x256x?xf32>) -> tensor<4x?x?xf32>
// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [1, 2] : (tensor<256x?xf32>, tensor<3xi64>) -> tensor<4x256x?xf32>
// CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<4x?x256xf32>, tensor<4x256x?xf32>) -> tensor<4x?x?xf32>
// CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<4x?x?xf32> to tensor<4x?x?xf32>
// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x?x?xf32> -> !torch.vtensor<[4,?,?],f32>
// CHECK: return %[[T12]] : !torch.vtensor<[4,?,?],f32>
@ -145,8 +145,8 @@ func.func @torch.aten.matmul$basic$dynamic(%arg0: !torch.vtensor<[4,?,256],f32>,
// CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor<256xf32>
// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64
// CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64>
// CHECK: %[[T7:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32>
// CHECK: %[[T8:.*]] = "mhlo.dot_general"(%[[T0]], %[[T7]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<1x?x256xf32>, tensor<1x256xf32>) -> tensor<1x?xf32>
// CHECK: %[[T7:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T6]], dims = [1] : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32>
// CHECK: %[[T8:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T7]]) {dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<1x?x256xf32>, tensor<1x256xf32>) -> tensor<1x?xf32>
// CHECK: %[[T9:.*]] = tensor.cast %[[T8]] : tensor<1x?xf32> to tensor<1x?xf32>
// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<1x?xf32> -> !torch.vtensor<[1,?],f32>
// CHECK: return %[[T10]] : !torch.vtensor<[1,?],f32>
@ -168,8 +168,8 @@ func.func @torch.aten.matmul$3dx1d(%arg0: !torch.vtensor<[1,?,256],f32>, %arg1:
// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0_0]] : tensor<256xf32>
// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64
// CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64>
// CHECK: %[[T7:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T0]], %[[T6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, tensor<2xi64>) -> tensor<?x256xf32>
// CHECK: %[[T8:.*]] = "mhlo.dot_general"(%[[T7]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [1]>} : (tensor<?x256xf32>, tensor<?x256x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T7:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T0]], %[[T6]], dims = [1] : (tensor<256xf32>, tensor<2xi64>) -> tensor<?x256xf32>
// CHECK: %[[T8:.*]] = "stablehlo.dot_general"(%[[T7]], %[[T1]]) {dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [1]>} : (tensor<?x256xf32>, tensor<?x256x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T9:.*]] = tensor.cast %[[T8]] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T10]] : !torch.vtensor<[?,?],f32>
@ -184,7 +184,7 @@ func.func @torch.aten.matmul$1dx3d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256],f32>) -> !torch.vtensor<[?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,256],f32> -> tensor<?x256xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32>
// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<?x256xf32>, tensor<256xf32>) -> tensor<?xf32>
// CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor<?x256xf32>, tensor<256xf32>) -> tensor<?xf32>
// CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<?xf32> to tensor<?xf32>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?xf32> -> !torch.vtensor<[?],f32>
// CHECK: return %[[T4]] : !torch.vtensor<[?],f32>
@ -199,7 +199,7 @@ func.func @torch.aten.matmul$2dx1d(%arg0: !torch.vtensor<[?,256],f32>, %arg1: !t
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256,?],f32> -> tensor<256x?xf32>
// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<256xf32>, tensor<256x?xf32>) -> tensor<?xf32>
// CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor<256xf32>, tensor<256x?xf32>) -> tensor<?xf32>
// CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<?xf32> to tensor<?xf32>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?xf32> -> !torch.vtensor<[?],f32>
// CHECK: return %[[T4]] : !torch.vtensor<[?],f32>
@ -214,7 +214,7 @@ func.func @torch.aten.matmul$1dx2d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256],f32>) -> !torch.vtensor<[],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32>
// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<256xf32>, tensor<256xf32>) -> tensor<f32>
// CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor<256xf32>, tensor<256xf32>) -> tensor<f32>
// CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<f32> to tensor<f32>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<f32> -> !torch.vtensor<[],f32>
// CHECK: return %[[T4]] : !torch.vtensor<[],f32>
@ -228,7 +228,7 @@ func.func @torch.aten.matmul$1dx1d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor
// CHECK-LABEL: func.func @torch.aten.matmul$proj(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,256],f32>) -> !torch.vtensor<[?,?,256],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,256],f32> -> tensor<?x?x256xf32>
// CHECK: %[[T1:.*]] = mhlo.constant dense<1.000000e+00> : tensor<256x256xf32>
// CHECK: %[[T1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<256x256xf32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x256xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
@ -239,8 +239,8 @@ func.func @torch.aten.matmul$1dx1d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor
// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<256x256xf32>
// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x256xf32>, tensor<3xi64>) -> tensor<?x256x256xf32>
// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<?x?x256xf32>, tensor<?x256x256xf32>) -> tensor<?x?x256xf32>
// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [1, 2] : (tensor<256x256xf32>, tensor<3xi64>) -> tensor<?x256x256xf32>
// CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<?x?x256xf32>, tensor<?x256x256xf32>) -> tensor<?x?x256xf32>
// CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<?x?x256xf32> to tensor<?x?x256xf32>
// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<?x?x256xf32> -> !torch.vtensor<[?,?,256],f32>
// CHECK: return %[[T12]] : !torch.vtensor<[?,?,256],f32>
@ -255,8 +255,8 @@ func.func @torch.aten.matmul$proj(%arg0: !torch.vtensor<[?,?,256],f32>) -> !torc
// CHECK-LABEL: func.func @torch.aten.mm$proj(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,256],f32>) -> !torch.vtensor<[?,256],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,256],f32> -> tensor<?x256xf32>
// CHECK: %[[T1:.*]] = mhlo.constant dense<1.000000e+00> : tensor<256x256xf32>
// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<?x256xf32>, tensor<256x256xf32>) -> tensor<?x256xf32>
// CHECK: %[[T1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<256x256xf32>
// CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor<?x256xf32>, tensor<256x256xf32>) -> tensor<?x256xf32>
// CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<?x256xf32> to tensor<?x256xf32>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x256xf32> -> !torch.vtensor<[?,256],f32>
// CHECK: return %[[T4]] : !torch.vtensor<[?,256],f32>
@ -284,7 +284,7 @@ func.func @torch.aten.mm$proj(%arg0: !torch.vtensor<[?,256],f32>) -> !torch.vten
// CHECK: %[[T_11:.*]] = torch.prim.ListConstruct %[[T_7]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_12:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[T_13:.*]] = torch.constant.bool false
// CHECK: %[[T_14:.*]] = mhlo.convolution(%[[T_0]], %[[T_1]])
// CHECK: %[[T_14:.*]] = stablehlo.convolution(%[[T_0]], %[[T_1]])
// CHECK-SAME{LITERAL}: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 1], pad = [[4, 4], [2, 2]], rhs_dilate = [3, 1]} {batch_group_count = 1 : i64, feature_group_count = 3 : i64} : (tensor<?x?x?x?xf32>, tensor<?x?x3x3xf32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[T_15:.*]] = torch_c.from_builtin_tensor %[[T_14]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[T_15]] : !torch.vtensor<[?,?,?,?],f32>
@ -321,14 +321,14 @@ func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !
// CHECK: %[[T_6:.*]] = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_7:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %false = torch.constant.bool false
// CHECK: %[[T_8:.*]] = mhlo.convolution(%[[T_0]], %[[T_1]])
// CHECK: %[[T_8:.*]] = stablehlo.convolution(%[[T_0]], %[[T_1]])
// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 1], pad = [[4, 4], [2, 2]], rhs_dilate = [3, 1]} {batch_group_count = 1 : i64, feature_group_count = 3 : i64} : (tensor<?x?x?x?xf32>, tensor<?x?x3x3xf32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
// CHECK: %[[T_9:.*]] = tensor.dim %[[T_2]], %[[IDX_0]] : tensor<?xf32>
// CHECK: %[[T_10:.*]] = arith.index_cast %[[T_9]] : index to i64
// CHECK: %[[VAL_0:.*]] = arith.constant 1 : i64
// CHECK: %[[T_11:.*]] = tensor.from_elements %[[T_10]], %[[VAL_0]], %[[VAL_0]] : tensor<3xi64>
// CHECK: %[[T_12:.*]] = mhlo.dynamic_reshape %[[T_2]], %[[T_11]] : (tensor<?xf32>, tensor<3xi64>) -> tensor<?x1x1xf32>
// CHECK: %[[T_12:.*]] = stablehlo.dynamic_reshape %[[T_2]], %[[T_11]] : (tensor<?xf32>, tensor<3xi64>) -> tensor<?x1x1xf32>
// CHECK: %[[T_13:.*]] = chlo.broadcast_add %[[T_8]], %[[T_12]] : (tensor<?x?x?x?xf32>, tensor<?x1x1xf32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[T_14:.*]] = torch_c.from_builtin_tensor %[[T_13]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[T_14]] : !torch.vtensor<[?,?,?,?],f32>
@ -360,8 +360,8 @@ func.func @torch.aten.convolution$bias(%arg0: !torch.vtensor<[?,?,?,?],f32>, %ar
// CHECK: %[[T_2:.*]] = torch_c.to_i64 %int1
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_5:.*]] = "mhlo.reverse"(%[[T_1]]) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x4x3x3xf32>) -> tensor<2x4x3x3xf32>
// CHECK: %[[T_6:.*]] = mhlo.convolution(%[[T_0]], %[[T_5]])
// CHECK: %[[T_5:.*]] = stablehlo.reverse %[[T_1]], dims = [2, 3] : tensor<2x4x3x3xf32>
// CHECK: %[[T_6:.*]] = stablehlo.convolution(%[[T_0]], %[[T_5]])
// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<2x4x3x3xf32>) -> tensor<1x4x9x9xf32>
// CHECK: %[[T_7:.*]] = torch_c.from_builtin_tensor %[[T_6]] : tensor<1x4x9x9xf32> -> !torch.vtensor<[1,4,9,9],f32>
// CHECK: return %[[T_7]] : !torch.vtensor<[1,4,9,9],f32>
@ -392,8 +392,8 @@ func.func @torch.aten.convolution$transposed_basic(%arg0: !torch.vtensor<[1,2,7,
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_6:.*]] = "mhlo.reverse"(%1) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x4x3x3xf32>) -> tensor<2x4x3x3xf32>
// CHECK: %[[T_7:.*]] = mhlo.convolution(%[[T_0]], %[[T_6]])
// CHECK: %[[T_6:.*]] = stablehlo.reverse %1, dims = [2, 3] : tensor<2x4x3x3xf32>
// CHECK: %[[T_7:.*]] = stablehlo.convolution(%[[T_0]], %[[T_6]])
// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<2x4x3x3xf32>) -> tensor<1x4x15x15xf32>
// CHECK: %[[T_8:.*]] = torch_c.from_builtin_tensor %[[T_7]] : tensor<1x4x15x15xf32> -> !torch.vtensor<[1,4,15,15],f32>
// CHECK: return %[[T_8]] : !torch.vtensor<[1,4,15,15],f32>
@ -426,11 +426,11 @@ func.func @torch.aten.convolution$transposed_stride(%arg0: !torch.vtensor<[1,2,7
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_6:.*]] = "mhlo.reverse"(%[[T_1]]) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x4x3x3xf32>) -> tensor<2x4x3x3xf32>
// CHECK: %[[T_7:.*]] = mhlo.convolution(%[[T_0]], %[[T_6]])
// CHECK: %[[T_6:.*]] = stablehlo.reverse %[[T_1]], dims = [2, 3] : tensor<2x4x3x3xf32>
// CHECK: %[[T_7:.*]] = stablehlo.convolution(%[[T_0]], %[[T_6]])
// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<2x4x3x3xf32>) -> tensor<1x4x15x15xf32>
// CHECK: %[[T_8:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[T_9:.*]] = "mhlo.pad"(%[[T_7]], %[[T_8]]) {edge_padding_high = dense<[0, 0, 1, 1]> : vector<4xi64>, edge_padding_low = dense<0> : vector<4xi64>, interior_padding = dense<0> : vector<4xi64>} : (tensor<1x4x15x15xf32>, tensor<f32>) -> tensor<1x4x16x16xf32>
// CHECK: %[[T_8:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[T_9:.*]] = stablehlo.pad %[[T_7]], %[[T_8]], low = [0, 0, 0, 0], high = [0, 0, 1, 1], interior = [0, 0, 0, 0] : (tensor<1x4x15x15xf32>, tensor<f32>) -> tensor<1x4x16x16xf32>
// CHECK: %[[T_10:.*]] = torch_c.from_builtin_tensor %[[T_9:.*]] : tensor<1x4x16x16xf32> -> !torch.vtensor<[1,4,16,16],f32>
// CHECK: return %[[T_10]] : !torch.vtensor<[1,4,16,16],f32>
func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor<[1,2,7,7],f32>, %arg1: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,16,16],f32> {
@ -462,7 +462,7 @@ func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_6:.*]] = "mhlo.reverse"(%1) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x2x3x3xf32>) -> tensor<2x2x3x3xf32>
// CHECK: %[[T_6:.*]] = stablehlo.reverse %1, dims = [2, 3] : tensor<2x2x3x3xf32>
// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
// CHECK: %[[T_7:.*]] = tensor.dim %[[T_6]], %[[IDX_0]] : tensor<2x2x3x3xf32>
// CHECK: %[[T_8:.*]] = arith.index_cast %[[T_7]] : index to i64
@ -479,11 +479,11 @@ func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor
// CHECK: %[[T_15:.*]] = arith.divsi %[[T_8]], %[[T_24]] : i64
// CHECK: %[[T_16:.*]] = arith.muli %[[T_10]], %[[T_24]] : i64
// CHECK: %[[T_17:.*]] = tensor.from_elements %[[T_24]], %[[T_15]], %[[T_10]], %[[T_12]], %[[T_14]] : tensor<5xi64>
// CHECK: %[[T_18:.*]] = mhlo.dynamic_reshape %[[T_6]], %[[T_17]] : (tensor<2x2x3x3xf32>, tensor<5xi64>) -> tensor<2x1x2x3x3xf32>
// CHECK: %[[T_19:.*]] = "mhlo.transpose"(%[[T_18]]) {permutation = dense<[1, 0, 2, 3, 4]> : tensor<5xi64>} : (tensor<2x1x2x3x3xf32>) -> tensor<1x2x2x3x3xf32>
// CHECK: %[[T_18:.*]] = stablehlo.dynamic_reshape %[[T_6]], %[[T_17]] : (tensor<2x2x3x3xf32>, tensor<5xi64>) -> tensor<2x1x2x3x3xf32>
// CHECK: %[[T_19:.*]] = stablehlo.transpose %[[T_18]], dims = [1, 0, 2, 3, 4] : (tensor<2x1x2x3x3xf32>) -> tensor<1x2x2x3x3xf32>
// CHECK: %[[T_20:.*]] = tensor.from_elements %[[T_15]], %[[T_16]], %[[T_12]], %[[T_14]] : tensor<4xi64>
// CHECK: %[[T_21:.*]] = mhlo.dynamic_reshape %[[T_19]], %[[T_20]] : (tensor<1x2x2x3x3xf32>, tensor<4xi64>) -> tensor<1x4x3x3xf32>
// CHECK: %[[T_22:.*]] = mhlo.convolution(%[[T_0]], %[[T_21]])
// CHECK: %[[T_21:.*]] = stablehlo.dynamic_reshape %[[T_19]], %[[T_20]] : (tensor<1x2x2x3x3xf32>, tensor<4xi64>) -> tensor<1x4x3x3xf32>
// CHECK: %[[T_22:.*]] = stablehlo.convolution(%[[T_0]], %[[T_21]])
// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 2 : i64} : (tensor<1x2x7x7xf32>, tensor<1x4x3x3xf32>) -> tensor<1x4x15x15xf32>
// CHECK: %[[T_23:.*]] = torch_c.from_builtin_tensor %[[T_22]] : tensor<1x4x15x15xf32> -> !torch.vtensor<[1,4,15,15],f32>
// CHECK: return %[[T_23]] : !torch.vtensor<[1,4,15,15],f32>

View File

@ -1,2 +1,2 @@
if not config.enable_mhlo:
if not config.enable_stablehlo:
config.unsupported = True

View File

@ -1,4 +1,4 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s
// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s
// -----
@ -13,11 +13,11 @@
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_6:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
// CHECK: %[[VAL_7:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_6]]) ({
// CHECK: %[[VAL_6:.*]] = stablehlo.constant dense<-3.40282347E+38> : tensor<f32>
// CHECK: %[[VAL_7:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_6]]) ({
// CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
// CHECK: %[[VAL_10:.*]] = mhlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
// CHECK: mhlo.return %[[VAL_10]] : tensor<f32>
// CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
// CHECK: stablehlo.return %[[VAL_10]] : tensor<f32>
// CHECK: }) {padding = dense<0> : tensor<4x2xi64>, window_dilations = dense<[1, 1, 2, 1]> : tensor<4xi64>, window_dimensions = dense<[1, 1, 2, 2]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[VAL_11]] : !torch.vtensor<[?,?,?,?],f32>
@ -45,11 +45,11 @@ func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
// CHECK: %[[VAL_6:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<-3.40282347E+38> : tensor<f32>
// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
// CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
// CHECK: %[[VAL_10:.*]] = mhlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
// CHECK: mhlo.return %[[VAL_10]] : tensor<f32>
// CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
// CHECK: stablehlo.return %[[VAL_10]] : tensor<f32>
// CHECK: })
// CHECK-SAME{LITERAL}: {padding = dense<[[0, 0], [0, 0], [2, 2], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<[1, 1, 2, 1]> : tensor<4xi64>, window_dimensions = dense<[1, 1, 2, 2]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
@ -80,7 +80,7 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -
// CHECK: %[[T2:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T4:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T5:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor<f32>
// CHECK: %[[T5:.*]] = stablehlo.constant dense<-3.40282347E+38> : tensor<f32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x?xf32>
// CHECK: %[[T6:.*]] = arith.index_cast %[[DIM]] : index to i64
@ -93,18 +93,18 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T6]], %[[T7]], %[[T8]] : tensor<3xi64>
// CHECK: %[[T9:.*]] = arith.muli %[[T8]], %[[T7]] : i64
// CHECK: %[[FROM_ELEMENTS_2:.*]] = tensor.from_elements %[[T6]], %[[T9]] : tensor<2xi64>
// CHECK: %[[T10:.*]] = "mhlo.dynamic_iota"(%[[FROM_ELEMENTS_2]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor<?x?xi64>
// CHECK: %[[T11:.*]] = mhlo.dynamic_reshape %[[T10]], %[[FROM_ELEMENTS]] : (tensor<?x?xi64>, tensor<3xi64>) -> tensor<?x?x?xi64>
// CHECK: %[[T12:.*]] = mhlo.constant dense<0> : tensor<i64>
// CHECK: %[[T13:.*]]:2 = "mhlo.reduce_window"(%[[T0]], %[[T11]], %[[T5]], %[[T12]]) ({
// CHECK: %[[T10:.*]] = stablehlo.dynamic_iota %[[FROM_ELEMENTS_2]], dim = 1 : (tensor<2xi64>) -> tensor<?x?xi64>
// CHECK: %[[T11:.*]] = stablehlo.dynamic_reshape %[[T10]], %[[FROM_ELEMENTS]] : (tensor<?x?xi64>, tensor<3xi64>) -> tensor<?x?x?xi64>
// CHECK: %[[T12:.*]] = stablehlo.constant dense<0> : tensor<i64>
// CHECK: %[[T13:.*]]:2 = "stablehlo.reduce_window"(%[[T0]], %[[T11]], %[[T5]], %[[T12]]) ({
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<i64>, %[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<i64>):
// CHECK: %[[T16:.*]] = mhlo.compare GE, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[T17:.*]] = mhlo.select %[[T16]], %[[ARG1]], %[[ARG3]] : tensor<i1>, tensor<f32>
// CHECK: %[[T18:.*]] = mhlo.compare EQ, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[T19:.*]] = mhlo.minimum %[[ARG2]], %[[ARG4]] : tensor<i64>
// CHECK: %[[T20:.*]] = mhlo.select %[[T16]], %[[ARG2]], %[[ARG4]] : tensor<i1>, tensor<i64>
// CHECK: %[[T21:.*]] = mhlo.select %[[T18]], %[[T19]], %[[T20]] : tensor<i1>, tensor<i64>
// CHECK: mhlo.return %[[T17]], %[[T21]] : tensor<f32>, tensor<i64>
// CHECK: %[[T16:.*]] = stablehlo.compare GE, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[T17:.*]] = stablehlo.select %[[T16]], %[[ARG1]], %[[ARG3]] : tensor<i1>, tensor<f32>
// CHECK: %[[T18:.*]] = stablehlo.compare EQ, %[[ARG1]], %[[ARG3]], FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[T19:.*]] = stablehlo.minimum %[[ARG2]], %[[ARG4]] : tensor<i64>
// CHECK: %[[T20:.*]] = stablehlo.select %[[T16]], %[[ARG2]], %[[ARG4]] : tensor<i1>, tensor<i64>
// CHECK: %[[T21:.*]] = stablehlo.select %[[T18]], %[[T19]], %[[T20]] : tensor<i1>, tensor<i64>
// CHECK: stablehlo.return %[[T17]], %[[T21]] : tensor<f32>, tensor<i64>
// CHECK: }) {padding = dense<0> : tensor<3x2xi64>, window_dilations = dense<1> : tensor<3xi64>, window_dimensions = dense<[1, 3, 3]> : tensor<3xi64>, window_strides = dense<[1, 2, 2]> : tensor<3xi64>} : (tensor<?x?x?xf32>, tensor<?x?x?xi64>, tensor<f32>, tensor<i64>) -> (tensor<?x?x?xf32>, tensor<?x?x?xi64>)
// CHECK: %[[T14:.*]] = torch_c.from_builtin_tensor %[[T13]]#0 : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[T15:.*]] = torch_c.from_builtin_tensor %[[T13]]#1 : tensor<?x?x?xi64> -> !torch.vtensor<[?,?,?],si64>
@ -136,13 +136,13 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32>
// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VAL_6:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
// CHECK: ^bb0(%[[IVAL_0:.*]]: tensor<f32>, %[[IVAL_1:.*]]: tensor<f32>):
// CHECK: %[[IVAL_2:.*]] = mhlo.add %[[IVAL_0]], %[[IVAL_1]] : tensor<f32>
// CHECK: mhlo.return %[[IVAL_2]] : tensor<f32>
// CHECK: %[[IVAL_2:.*]] = stablehlo.add %[[IVAL_0]], %[[IVAL_1]] : tensor<f32>
// CHECK: stablehlo.return %[[IVAL_2]] : tensor<f32>
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
// CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f32>
// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor<?x?x?x?xf32>
// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : index to i64
@ -156,14 +156,14 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32>
// CHECK: %[[VAL_14:.*]] = tensor.dim %[[VAL_1]], %[[IDX_3]] : tensor<?x?x?x?xf32>
// CHECK: %[[VAL_15:.*]] = arith.index_cast %[[VAL_14]] : index to i64
// CHECK: %[[VAL_16:.*]] = tensor.from_elements %[[VAL_9]], %[[VAL_11]], %[[VAL_13]], %[[VAL_15]] : tensor<4xi64>
// CHECK: %[[VAL_17:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_7]], %[[VAL_16]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_18:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VAL_19:.*]] = "mhlo.reduce_window"(%[[VAL_17]], %[[VAL_18]]) ({
// CHECK: %[[VAL_17:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_16]], dims = [] : (tensor<f32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_18:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VAL_19:.*]] = "stablehlo.reduce_window"(%[[VAL_17]], %[[VAL_18]]) ({
// CHECK: ^bb0(%[[IVAL_3:.*]]: tensor<f32>, %[[IVAL_4:.*]]: tensor<f32>):
// CHECK: %[[IVAL_5:.*]] = mhlo.add %[[IVAL_3]], %[[IVAL_4]] : tensor<f32>
// CHECK: mhlo.return %[[IVAL_5]] : tensor<f32>
// CHECK: %[[IVAL_5:.*]] = stablehlo.add %[[IVAL_3]], %[[IVAL_4]] : tensor<f32>
// CHECK: stablehlo.return %[[IVAL_5]] : tensor<f32>
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_20:.*]] = mhlo.divide %[[VAL_6]], %[[VAL_19]] : tensor<?x?x?x?xf32>
// CHECK: %[[VAL_20:.*]] = stablehlo.divide %[[VAL_6]], %[[VAL_19]] : tensor<?x?x?x?xf32>
// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[VAL_21]] : !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten.avg_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
@ -193,14 +193,14 @@ func.func @torch.aten.avg_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T2:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T4:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[T5:.*]] = "mhlo.reduce_window"(%[[T0]], %[[T4]]) ({
// CHECK: %[[T4:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[T5:.*]] = "stablehlo.reduce_window"(%[[T0]], %[[T4]]) ({
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK: %[[T10:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK: mhlo.return %[[T10]] : tensor<f32>
// CHECK: %[[T10:.*]] = stablehlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK: stablehlo.return %[[T10]] : tensor<f32>
// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[T6:.*]] = mhlo.constant dense<9> : tensor<i64>
// CHECK: %[[T7:.*]] = mhlo.convert %[[T6]] : (tensor<i64>) -> tensor<f32>
// CHECK: %[[T6:.*]] = stablehlo.constant dense<9> : tensor<i64>
// CHECK: %[[T7:.*]] = stablehlo.convert %[[T6]] : (tensor<i64>) -> tensor<f32>
// CHECK: %[[T8:.*]] = chlo.broadcast_divide %[[T5]], %[[T7]] : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[T9]] : !torch.vtensor<[?,?,?,?],f32>

View File

@ -1,4 +1,4 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s
// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s
// CHECK-LABEL: func.func @torch.aten.slice.strided$slice_like(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
@ -42,7 +42,7 @@
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T10]], %[[C0_I64_5]], %[[C0_I64_5]] : tensor<3xi64>
// CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T21]], %[[T18]], %[[T19]] : tensor<3xi64>
// CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64>
// CHECK: %[[T22:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32>
// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32>
// CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[T23]] : !torch.vtensor<[?,?,?],f32>
func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
@ -97,7 +97,7 @@ func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T10]], %[[C0_I64_5]], %[[C0_I64_5]] : tensor<3xi64>
// CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T21]], %[[T18]], %[[T19]] : tensor<3xi64>
// CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64>
// CHECK: %[[T22:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<2x65x256xf32>
// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<2x65x256xf32>
// CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<2x65x256xf32> -> !torch.vtensor<[2,65,256],f32>
// CHECK: return %[[T23]] : !torch.vtensor<[2,65,256],f32>
func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[2,65,256],f32> {
@ -152,7 +152,7 @@ func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,6
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_5]], %[[T10]], %[[C0_I64_5]] : tensor<3xi64>
// CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T17]], %[[T21]], %[[T19]] : tensor<3xi64>
// CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64>
// CHECK: %[[T22:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x1x?xf32>
// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x1x?xf32>
// CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<?x1x?xf32> -> !torch.vtensor<[?,1,?],f32>
// CHECK: return %[[T23]] : !torch.vtensor<[?,1,?],f32>
func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,1,?],f32> {
@ -207,7 +207,7 @@ func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>)
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_5]], %[[T10]], %[[C0_I64_5]] : tensor<3xi64>
// CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T17]], %[[T21]], %[[T19]] : tensor<3xi64>
// CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64>
// CHECK: %[[T22:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x1x256xf32>
// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x1x256xf32>
// CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<4x1x256xf32> -> !torch.vtensor<[4,1,256],f32>
// CHECK: return %[[T23]] : !torch.vtensor<[4,1,256],f32>
func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,1,256],f32> {
@ -247,7 +247,7 @@ func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,2
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_4]], %[[C0_I64]], %[[C0_I64_4]] : tensor<3xi64>
// CHECK: %[[FROM_ELEMENTS_5:.*]] = tensor.from_elements %[[T3]], %[[T7]], %[[T5]] : tensor<3xi64>
// CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64>
// CHECK: %[[T8:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS]]_6 : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32>
// CHECK: %[[T8:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS]]_6 : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32>
// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[T9]] : !torch.vtensor<[?,?,?],f32>
func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
@ -287,7 +287,7 @@ func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>)
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_4]], %[[C0_I64]], %[[C0_I64_4]] : tensor<3xi64>
// CHECK: %[[FROM_ELEMENTS_5:.*]] = tensor.from_elements %[[T3]], %[[T7]], %[[T5]] : tensor<3xi64>
// CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64>
// CHECK: %[[T8:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS]]_6 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x33x256xf32>
// CHECK: %[[T8:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS]]_6 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x33x256xf32>
// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<4x33x256xf32> -> !torch.vtensor<[4,33,256],f32>
// CHECK: return %[[T9]] : !torch.vtensor<[4,33,256],f32>
func.func @torch.aten.slice.none.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,33,256],f32> {
@ -313,8 +313,8 @@ func.func @torch.aten.slice.none.static$slice_like(%arg0: !torch.vtensor<[4,65,2
// CHECK: %[[T5:.*]] = arith.muli %[[T4]], %[[T3]] : i64
// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : i64 to index
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]], %[[T3]] : tensor<2xi64>
// CHECK: %[[T7:.*]] = mhlo.compute_reshape_shape %[[T6]], %[[FROM_ELEMENTS]] : (index, tensor<2xi64>) -> tensor<2xi64>
// CHECK: %[[T8:.*]] = mhlo.dynamic_reshape %[[T0]], %[[T7]] : (tensor<?x?x?x?xf32>, tensor<2xi64>) -> tensor<?x224xf32>
// CHECK: %[[T7:.*]] = stablehlo.compute_reshape_shape %[[T6]], %[[FROM_ELEMENTS]] : (index, tensor<2xi64>) -> tensor<2xi64>
// CHECK: %[[T8:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[T7]] : (tensor<?x?x?x?xf32>, tensor<2xi64>) -> tensor<?x224xf32>
// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x224xf32> -> !torch.vtensor<[?,224],f32>
// CHECK: return %[[T9]] : !torch.vtensor<[?,224],f32>
func.func @torch.aten.view$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,224],f32> {
@ -346,8 +346,8 @@ func.func @torch.aten.view$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
// CHECK: %[[T9:.*]] = arith.muli %[[T8]], %[[T5]] : i64
// CHECK: %[[T10:.*]] = arith.index_cast %[[T9]] : i64 to index
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]], %[[T3]], %[[T4]], %[[T5]] : tensor<4xi64>
// CHECK: %[[T11:.*]] = mhlo.compute_reshape_shape %[[T10]], %[[FROM_ELEMENTS]] : (index, tensor<4xi64>) -> tensor<4xi64>
// CHECK: %[[T12:.*]] = mhlo.dynamic_reshape %[[T0]], %[[T11]] : (tensor<?x?x?x?x?xf32>, tensor<4xi64>) -> tensor<?x120x4x64xf32>
// CHECK: %[[T11:.*]] = stablehlo.compute_reshape_shape %[[T10]], %[[FROM_ELEMENTS]] : (index, tensor<4xi64>) -> tensor<4xi64>
// CHECK: %[[T12:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[T11]] : (tensor<?x?x?x?x?xf32>, tensor<4xi64>) -> tensor<?x120x4x64xf32>
// CHECK: %[[T13:.*]] = torch_c.from_builtin_tensor %[[T12]] : tensor<?x120x4x64xf32> -> !torch.vtensor<[?,120,4,64],f32>
// CHECK: return %[[T13]] : !torch.vtensor<[?,120,4,64],f32>
func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.vtensor<[?,120,4,64],f32> {
@ -367,7 +367,7 @@ func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[T2:.*]] = mhlo.reshape %[[T0]] : (tensor<f32>) -> tensor<1xf32>
// CHECK: %[[T2:.*]] = stablehlo.reshape %[[T0]] : (tensor<f32>) -> tensor<1xf32>
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
// CHECK: return %[[T3]] : !torch.vtensor<[1],f32>
func.func @torch.aten.view$to_rank1(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
@ -383,7 +383,7 @@ func.func @torch.aten.view$to_rank1(%arg0: !torch.vtensor<[],f32>) -> !torch.vte
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1],f32>) -> !torch.vtensor<[],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1],f32> -> tensor<1xf32>
// CHECK: %[[T1:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[T2:.*]] = mhlo.reshape %[[T0]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T2:.*]] = stablehlo.reshape %[[T0]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<f32> -> !torch.vtensor<[],f32>
// CHECK: return %[[T3]] : !torch.vtensor<[],f32>
func.func @torch.aten.view$to_rank0(%arg0: !torch.vtensor<[1],f32>) -> !torch.vtensor<[],f32> {
@ -425,7 +425,7 @@ func.func @torch.aten.squeeze.dim$0$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32
// CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<?x1x?x1x?xf32>
// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<4xi64>
// CHECK: %[[T5:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x1x?x1x?xf32>, tensor<4xi64>) -> tensor<?x?x1x?xf32>
// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x1x?x1x?xf32>, tensor<4xi64>) -> tensor<?x?x1x?xf32>
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?x1x?xf32> -> !torch.vtensor<[?,?,1,?],f32>
// CHECK: return %[[T6]] : !torch.vtensor<[?,?,1,?],f32>
func.func @torch.aten.squeeze.dim$1(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,?,1,?],f32> {
@ -453,7 +453,7 @@ func.func @torch.aten.squeeze.dim$1(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !
// CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<?x1x?x1x?xf32>
// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<4xi64>
// CHECK: %[[T5:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x1x?x1x?xf32>, tensor<4xi64>) -> tensor<?x1x?x?xf32>
// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x1x?x1x?xf32>, tensor<4xi64>) -> tensor<?x1x?x?xf32>
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x1x?x?xf32> -> !torch.vtensor<[?,1,?,?],f32>
// CHECK: return %[[T6]] : !torch.vtensor<[?,1,?,?],f32>
func.func @torch.aten.squeeze.dim$from_end(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,1,?,?],f32> {
@ -477,7 +477,7 @@ func.func @torch.aten.squeeze.dim$from_end(%arg0: !torch.vtensor<[?,1,?,1,?],f32
// CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<2x1x2x1x2xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[DIM_1]] : index to i64
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]] : tensor<3xi64>
// CHECK: %[[T4:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<2x1x2x1x2xf32>, tensor<3xi64>) -> tensor<2x2x2xf32>
// CHECK: %[[T4:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<2x1x2x1x2xf32>, tensor<3xi64>) -> tensor<2x2x2xf32>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<2x2x2xf32> -> !torch.vtensor<[2,2,2],f32>
// CHECK: return %[[T5]] : !torch.vtensor<[2,2,2],f32>
func.func @torch.aten.squeeze$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,2,2],f32> {
@ -505,7 +505,7 @@ func.func @torch.aten.squeeze$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) ->
// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64
// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<5xi64>
// CHECK: %[[T5:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<1x?x?x?x?xf32>
// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<1x?x?x?x?xf32>
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<1x?x?x?x?xf32> -> !torch.vtensor<[1,?,?,?,?],f32>
// CHECK: return %[[T6]] : !torch.vtensor<[1,?,?,?,?],f32>
func.func @torch.aten.unsqueeze$dim$0(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1,?,?,?,?],f32> {
@ -534,7 +534,7 @@ func.func @torch.aten.unsqueeze$dim$0(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !
// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64
// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[C1_I64]], %[[T2]], %[[T3]], %[[T4]] : tensor<5xi64>
// CHECK: %[[T5:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<?x1x?x?x?xf32>
// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<?x1x?x?x?xf32>
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x1x?x?x?xf32> -> !torch.vtensor<[?,1,?,?,?],f32>
// CHECK: return %[[T6]] : !torch.vtensor<[?,1,?,?,?],f32>
func.func @torch.aten.unsqueeze$dim$1(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,1,?,?,?],f32> {
@ -563,7 +563,7 @@ func.func @torch.aten.unsqueeze$dim$1(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !
// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64
// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]], %[[C1_I64]], %[[T4]] : tensor<5xi64>
// CHECK: %[[T5:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<?x?x?x1x?xf32>
// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<?x?x?x1x?xf32>
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?x?x1x?xf32> -> !torch.vtensor<[?,?,?,1,?],f32>
// CHECK: return %[[T6]] : !torch.vtensor<[?,?,?,1,?],f32>
func.func @torch.aten.unsqueeze$from_end(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,1,?],f32> {

View File

@ -17,7 +17,7 @@ config.llvm_exe_ext = "@EXEEXT@"
config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@"
config.python_executable = "@Python3_EXECUTABLE@"
config.enable_jit_ir_importer = @TORCH_MLIR_ENABLE_JIT_IR_IMPORTER@
config.enable_mhlo = @TORCH_MLIR_ENABLE_MHLO@
config.enable_stablehlo = @TORCH_MLIR_ENABLE_STABLEHLO@
import lit.llvm
lit.llvm.initialize(lit_config, config)

View File

@ -268,7 +268,7 @@ gentbl_cc_library(
(
[
"-gen-pass-decls",
"-DTORCH_MLIR_ENABLE_MHLO",
"-DTORCH_MLIR_ENABLE_STABLEHLO",
],
"include/torch-mlir/Conversion/Passes.h.inc",
),
@ -434,13 +434,13 @@ cc_library(
)
cc_library(
name = "TorchMLIRTorchToMhlo",
name = "TorchMLIRTorchToStablehlo",
srcs = glob([
"lib/Conversion/*.h",
"lib/Conversion/TorchToMhlo/*.h",
"lib/Conversion/TorchToMhlo/*.cpp",
"lib/Conversion/TorchToStablehlo/*.h",
"lib/Conversion/TorchToStablehlo/*.cpp",
]),
hdrs = glob(["include/torch-mlir/Conversion/TorchToMhlo/*.h"]),
hdrs = glob(["include/torch-mlir/Conversion/TorchToStablehlo/*.h"]),
strip_include_prefix = "include",
deps = [
":TorchMLIRConversionPassesIncGen",
@ -465,8 +465,8 @@ cc_library(
":TorchMLIRTorchConversionToMLProgram",
":TorchMLIRTorchToArith",
":TorchMLIRTorchToLinalg",
":TorchMLIRTorchToMhlo",
":TorchMLIRTorchToSCF",
":TorchMLIRTorchToStablehlo",
":TorchMLIRTorchToTMTensor",
":TorchMLIRTorchToTosa",
],
@ -489,8 +489,8 @@ cc_library(
":TorchMLIRTorchPasses",
":TorchMLIRTorchToArith",
":TorchMLIRTorchToLinalg",
":TorchMLIRTorchToMhlo",
":TorchMLIRTorchToSCF",
":TorchMLIRTorchToStablehlo",
":TorchMLIRTorchToTMTensor",
":TorchMLIRTorchToTosa",
"@llvm-project//mlir:ConversionPasses",

View File

@ -23,7 +23,7 @@ expand_template(
# All disabled, but required to substituted because they are not in quotes.
"@MLIR_ENABLE_BINDINGS_PYTHON@": "0",
"@TORCH_MLIR_ENABLE_JIT_IR_IMPORTER@": "0",
"@TORCH_MLIR_ENABLE_MHLO@": "0",
"@TORCH_MLIR_ENABLE_STABLEHLO@": "0",
},
template = "lit.site.cfg.py.in",
)