mlir: bump llvm tag to 5380e3 (#856)

In addition to updating the llvm-project submodule, this patch also:

1. updates shape functions and tests so that `func` and `call`
   operations refer to the `func` dialect
2. avoid duplicate registration of dialects
pull/861/head snapshot-20220516.455
Ashay Rane 2022-05-16 12:54:35 -07:00 committed by GitHub
parent cfc1a6515c
commit bb52a460cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
64 changed files with 1368 additions and 1369 deletions

View File

@ -1,7 +1,7 @@
// RUN: torch-mlir-dialects-opt -split-input-file -tm-tensor-bufferize %s | FileCheck %s
// -----
// CHECK-LABEL: func @scan_1d_inclusive(
// CHECK-LABEL: func.func @scan_1d_inclusive(
// CHECK-SAME: %[[IN_TENSOR:.*]]: tensor<128xi32>, %[[OUT_TENSOR:.*]]: tensor<128xi32>,
// CHECK-SAME: %[[ACC_TENSOR:.*]]: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
// CHECK: %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : memref<128xi32>
@ -16,7 +16,7 @@
// CHECK: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32>
// CHECK: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref<i32>
// CHECK: return %[[OUT_TENSOR_NEW]], %[[ACC_TENSOR_NEW]] : tensor<128xi32>, tensor<i32>
func @scan_1d_inclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
func.func @scan_1d_inclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
%ret_out, %ret_acc = tm_tensor.scan dimension(0) inclusive(true)
ins(%in : tensor<128xi32>) outs(%out, %acc: tensor<128xi32>, tensor<i32>) {
^bb0(%arg0 : i32, %arg1 : i32):
@ -27,7 +27,7 @@ func @scan_1d_inclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tenso
}
// -----
// CHECK-LABEL: func @scan_1d_exclusive(
// CHECK-LABEL: func.func @scan_1d_exclusive(
// CHECK-SAME: %[[IN_TENSOR:.*]]: tensor<128xi32>, %[[OUT_TENSOR:.*]]: tensor<128xi32>,
// CHECK-SAME: %[[ACC_TENSOR:.*]]: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
// CHECK: %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : memref<128xi32>
@ -44,7 +44,7 @@ func @scan_1d_inclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tenso
// CHECK: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32>
// CHECK: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref<i32>
// CHECK: return %[[OUT_TENSOR_NEW]], %[[ACC_TENSOR_NEW]] : tensor<128xi32>, tensor<i32>
func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
func.func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
%ret_out, %ret_acc = tm_tensor.scan dimension(0) inclusive(false)
ins(%in : tensor<128xi32>) outs(%out, %acc: tensor<128xi32>, tensor<i32>) {
^bb0(%arg0 : i32, %arg1 : i32):
@ -55,7 +55,7 @@ func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tenso
}
// -----
// CHECK-LABEL: func @scatter_update_scalar_1D(
// CHECK-LABEL: func.func @scatter_update_scalar_1D(
// CHECK-SAME: %[[ORIG_TENSOR:.*]]: tensor<8xi32>,
// CHECK-SAME: %[[INDICES_TENSOR:.*]]: tensor<3x1xi32>,
// CHECK-SAME: %[[UPDATES_TENSOR:.*]]: tensor<3xi32>) -> tensor<8xi32> {
@ -71,7 +71,7 @@ func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tenso
// CHECK: }
// CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32>
// CHECK: return %[[OUT_TENSOR]] : tensor<8xi32>
func @scatter_update_scalar_1D(
func.func @scatter_update_scalar_1D(
%original: tensor<8xi32>, %indices: tensor<3x1xi32>,
%updates: tensor<3xi32>) -> tensor<8xi32> {
%0 = tm_tensor.scatter unique_indices(true)
@ -83,7 +83,7 @@ func @scatter_update_scalar_1D(
return %0 : tensor<8xi32>
}
// CHECK-LABEL: func @scatter_add_scalar_1D(
// CHECK-LABEL: func.func @scatter_add_scalar_1D(
// CHECK-SAME: %[[ORIG_TENSOR:.*]]: tensor<8xi32>,
// CHECK-SAME: %[[INDICES_TENSOR:.*]]: tensor<3x1xi32>,
// CHECK-SAME: %[[UPDATES_TENSOR:.*]]: tensor<3xi32>) -> tensor<8xi32> {
@ -101,7 +101,7 @@ func @scatter_update_scalar_1D(
// CHECK: }
// CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32>
// CHECK: return %[[OUT_TENSOR]] : tensor<8xi32>
func @scatter_add_scalar_1D(
func.func @scatter_add_scalar_1D(
%original: tensor<8xi32>, %indices: tensor<3x1xi32>,
%updates: tensor<3xi32>) -> tensor<8xi32> {
%0 = tm_tensor.scatter unique_indices(true)

View File

@ -1,7 +1,7 @@
// RUN: torch-mlir-dialects-opt -canonicalize -split-input-file %s | FileCheck %s
// CHECK-LABEL: func @tensor.cast(
func @tensor.cast(%arg0: tensor<128xi32>) -> tensor<128xi32> {
// CHECK-LABEL: func.func @tensor.cast(
func.func @tensor.cast(%arg0: tensor<128xi32>) -> tensor<128xi32> {
%init = linalg.init_tensor [128] : tensor<128xi32>
%c0 = linalg.init_tensor [] : tensor<i32>

View File

@ -1,6 +1,6 @@
// RUN: torch-mlir-dialects-opt -split-input-file -tm-tensor-to-loops %s | FileCheck %s
func @scan_1d_inclusive(%0: memref<128xi32>, %1: memref<128xi32>) {
func.func @scan_1d_inclusive(%0: memref<128xi32>, %1: memref<128xi32>) {
%c0 = memref.alloc() : memref<i32>
tm_tensor.scan dimension(0) inclusive(true)
ins(%0 : memref<128xi32>) outs(%1, %c0 : memref<128xi32>, memref<i32>) {
@ -10,7 +10,7 @@ func @scan_1d_inclusive(%0: memref<128xi32>, %1: memref<128xi32>) {
}
return
}
// CHECK-LABEL: func @scan_1d_inclusive
// CHECK-LABEL: func.func @scan_1d_inclusive
// CHECK-SAME: %[[BUFI:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[BUFO:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
@ -33,7 +33,7 @@ func @scan_1d_inclusive(%0: memref<128xi32>, %1: memref<128xi32>) {
// -----
func @scan_1d_exclusive(%0: memref<128xi32>, %1: memref<128xi32>) {
func.func @scan_1d_exclusive(%0: memref<128xi32>, %1: memref<128xi32>) {
%c0 = memref.alloc() : memref<i32>
tm_tensor.scan dimension(0) inclusive(false)
ins(%0 : memref<128xi32>) outs(%1, %c0 : memref<128xi32>, memref<i32>) {
@ -43,7 +43,7 @@ func @scan_1d_exclusive(%0: memref<128xi32>, %1: memref<128xi32>) {
}
return
}
// CHECK-LABEL: func @scan_1d_exclusive
// CHECK-LABEL: func.func @scan_1d_exclusive
// CHECK-SAME: %[[BUFI:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[BUFO:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
@ -66,7 +66,7 @@ func @scan_1d_exclusive(%0: memref<128xi32>, %1: memref<128xi32>) {
// -----
func @scan_2d(%0: memref<16x32xi32>, %1: memref<16x32xi32>) {
func.func @scan_2d(%0: memref<16x32xi32>, %1: memref<16x32xi32>) {
%t0 = memref.alloc() : memref<32xi32>
tm_tensor.scan dimension(0) inclusive(true)
ins(%0 : memref<16x32xi32>) outs(%1, %t0 : memref<16x32xi32>, memref<32xi32>) {
@ -76,7 +76,7 @@ func @scan_2d(%0: memref<16x32xi32>, %1: memref<16x32xi32>) {
}
return
}
// CHECK-LABEL: func @scan_2d
// CHECK-LABEL: func.func @scan_2d
// CHECK-SAME: %[[BUFI:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[BUFO:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
@ -102,7 +102,7 @@ func @scan_2d(%0: memref<16x32xi32>, %1: memref<16x32xi32>) {
// -----
func @scatter_update_scalar_1D(
func.func @scatter_update_scalar_1D(
%original: memref<8xi32>, %indices: memref<3x1xi32>,
%updates: memref<3xi32>) {
tm_tensor.scatter unique_indices(true)
@ -113,7 +113,7 @@ func @scatter_update_scalar_1D(
}
return
}
// CHECK-LABEL: func @scatter_update_scalar_1D
// CHECK-LABEL: func.func @scatter_update_scalar_1D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
@ -128,7 +128,7 @@ func @scatter_update_scalar_1D(
// -----
func @scatter_add_scalar_2D(
func.func @scatter_add_scalar_2D(
%original: memref<4x3xi32>, %indices: memref<3x2xi32>,
%updates: memref<3xi32>) {
tm_tensor.scatter unique_indices(true)
@ -140,7 +140,7 @@ func @scatter_add_scalar_2D(
}
return
}
// CHECK-LABEL: func @scatter_add_scalar_2D
// CHECK-LABEL: func.func @scatter_add_scalar_2D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
@ -159,7 +159,7 @@ func @scatter_add_scalar_2D(
// -----
func @scatter_update_slice_2D(
func.func @scatter_update_slice_2D(
%original: memref<4x3xi32>, %indices: memref<2x1xi32>,
%updates: memref<2x3xi32>) {
tm_tensor.scatter unique_indices(true)
@ -170,7 +170,7 @@ func @scatter_update_slice_2D(
}
return
}
// CHECK: func @scatter_update_slice_2D
// CHECK: func.func @scatter_update_slice_2D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
@ -189,7 +189,7 @@ func @scatter_update_slice_2D(
// -----
func @scatter_add_scalar_1D(
func.func @scatter_add_scalar_1D(
%original: memref<8xi32>, %indices: memref<3x1xi32>,
%updates: memref<3xi32>) {
tm_tensor.scatter unique_indices(true)
@ -201,7 +201,7 @@ func @scatter_add_scalar_1D(
}
return
}
// CHECK-LABEL: func @scatter_add_scalar_1D
// CHECK-LABEL: func.func @scatter_add_scalar_1D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
@ -218,7 +218,7 @@ func @scatter_add_scalar_1D(
// -----
func @scatter_add_slice_2D(
func.func @scatter_add_slice_2D(
%original: memref<4x3xi32>, %indices: memref<2x1xi32>,
%updates: memref<2x3xi32>) {
tm_tensor.scatter unique_indices(true)
@ -230,7 +230,7 @@ func @scatter_add_slice_2D(
}
return
}
// CHECK: func @scatter_add_slice_2D
// CHECK: func.func @scatter_add_slice_2D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
@ -248,7 +248,7 @@ func @scatter_add_slice_2D(
// -----
func @scatter_update_scalar_dynamic_1D(
func.func @scatter_update_scalar_dynamic_1D(
%original: memref<?xi32>, %indices: memref<?x1xi32>,
%updates: memref<?xi32>) {
tm_tensor.scatter unique_indices(true)
@ -259,7 +259,7 @@ func @scatter_update_scalar_dynamic_1D(
}
return
}
// CHECK-LABEL: func @scatter_update_scalar_dynamic_1D
// CHECK-LABEL: func.func @scatter_update_scalar_dynamic_1D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
@ -274,7 +274,7 @@ func @scatter_update_scalar_dynamic_1D(
// -----
func @scatter_add_scalar_dynamic_2D(
func.func @scatter_add_scalar_dynamic_2D(
%original: memref<?x?xi32>, %indices: memref<?x2xi32>,
%updates: memref<?xi32>) {
tm_tensor.scatter unique_indices(true)
@ -286,7 +286,7 @@ func @scatter_add_scalar_dynamic_2D(
}
return
}
// CHECK-LABEL: func @scatter_add_scalar_dynamic_2D
// CHECK-LABEL: func.func @scatter_add_scalar_dynamic_2D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
@ -305,7 +305,7 @@ func @scatter_add_scalar_dynamic_2D(
// -----
func @scatter_update_slice_dynamic_2D(
func.func @scatter_update_slice_dynamic_2D(
%original: memref<?x?xi32>, %indices: memref<?x1xi32>,
%updates: memref<?x?xi32>) {
tm_tensor.scatter unique_indices(true)
@ -316,7 +316,7 @@ func @scatter_update_slice_dynamic_2D(
}
return
}
// CHECK: func @scatter_update_slice_dynamic_2D
// CHECK: func.func @scatter_update_slice_dynamic_2D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
@ -333,7 +333,7 @@ func @scatter_update_slice_dynamic_2D(
// -----
func @scatter_partial_slices(%arg0: memref<2x64x12xf32>, %arg1: memref<2x3xi32>, %arg2: memref<2x1x12xf32>) {
func.func @scatter_partial_slices(%arg0: memref<2x64x12xf32>, %arg1: memref<2x3xi32>, %arg2: memref<2x1x12xf32>) {
tm_tensor.scatter
unique_indices(true)
ins(%arg2, %arg1 : memref<2x1x12xf32>, memref<2x3xi32>)
@ -344,7 +344,7 @@ func @scatter_partial_slices(%arg0: memref<2x64x12xf32>, %arg1: memref<2x3xi32>,
return
}
// CHECK-LABEL: func @scatter_partial_slices
// CHECK-LABEL: func.func @scatter_partial_slices
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]

View File

@ -1,6 +1,6 @@
// RUN: torch-mlir-dialects-opt -split-input-file -verify-diagnostics %s
func @scatter_mixed_tensor_memref(
func.func @scatter_mixed_tensor_memref(
%update : memref<?x?xf32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
// expected-error @+1 {{expected inputs and outputs to be RankedTensorType or scalar}}
@ -16,7 +16,7 @@ func @scatter_mixed_tensor_memref(
// -----
func @scatter_mixed_tensor_memref(
func.func @scatter_mixed_tensor_memref(
%update : tensor<?x?xf32>, %indices : memref<?x1xi32>,
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
// expected-error @+1 {{expected inputs and outputs to be RankedTensorType or scalar}}
@ -32,7 +32,7 @@ func @scatter_mixed_tensor_memref(
// -----
func @scatter_extra_outputs(
func.func @scatter_extra_outputs(
%update : tensor<?x?xf32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
// expected-error @+1 {{expected number of outputs to be same as the number of results}}
@ -48,7 +48,7 @@ func @scatter_extra_outputs(
// -----
func @scatter_mixed_tensor_memref(
func.func @scatter_mixed_tensor_memref(
%update : tensor<?x?xf32>, %indices : tensor<?x1xi32>,
%original : memref<?x?xf32>) -> tensor<?x?xf32> {
// expected-error @+1 {{expected inputs and outputs to be RankedTensorType or scalar}}
@ -64,7 +64,7 @@ func @scatter_mixed_tensor_memref(
// -----
func @scatter_output_type_mismatch(
func.func @scatter_output_type_mismatch(
%update : tensor<?x?xf32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xf32>) -> tensor<4x?xf32> {
// expected-error @+1 {{expected type of `outs` operand #0 'tensor<?x?xf32>' to be same as result type 'tensor<4x?xf32>'}}
@ -80,7 +80,7 @@ func @scatter_output_type_mismatch(
// -----
func @scatter_mixed_tensor_memref(
func.func @scatter_mixed_tensor_memref(
%update : memref<?x?xf32>, %indices : tensor<?x1xi32>,
%original : memref<?x?xf32>) {
// expected-error @+1 {{expected inputs and outputs to be MemRefType or scalar}}
@ -96,7 +96,7 @@ func @scatter_mixed_tensor_memref(
// -----
func @scatter_mixed_tensor_memref(
func.func @scatter_mixed_tensor_memref(
%update : memref<?x?xf32>, %indices : memref<?x1xi32>,
%original : tensor<?x?xf32>) {
// expected-error @+1 {{expected inputs and outputs to be MemRefType or scalar}}
@ -112,7 +112,7 @@ func @scatter_mixed_tensor_memref(
// -----
func @scatter_dim_mismatch(
func.func @scatter_dim_mismatch(
%update : tensor<?x?xf32>, %indices : tensor<48x1xi32>,
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
// expected-error @+1 {{mismatch in shape of indices and update value at dim#0}}
@ -128,7 +128,7 @@ func @scatter_dim_mismatch(
// -----
func @scatter_dim_mismatch(
func.func @scatter_dim_mismatch(
%update : tensor<64x?xf32>, %indices : tensor<48x1xi32>,
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
// expected-error @+1 {{mismatch in shape of indices and update value at dim#0}}
@ -144,7 +144,7 @@ func @scatter_dim_mismatch(
// -----
func @scatter_dim_mismatch(
func.func @scatter_dim_mismatch(
%update : tensor<?x?x?x?xf32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
// expected-error @+1 {{op update value rank exceeds the rank of the original value}}
@ -160,7 +160,7 @@ func @scatter_dim_mismatch(
// -----
func @scatter_dim_mismatch(
func.func @scatter_dim_mismatch(
%update : tensor<?x4xf32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
// expected-error @+1 {{mismatch in shape of update value dim#1 and original value at dim#1}}
@ -176,7 +176,7 @@ func @scatter_dim_mismatch(
// -----
func @scatter_region_type_mismatch(
func.func @scatter_region_type_mismatch(
%update : tensor<?x?xi32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi32>) -> tensor<?x?xi32> {
// expected-error @+1 {{expected region to have scalar argument of integer or float types}}
@ -193,7 +193,7 @@ func @scatter_region_type_mismatch(
// -----
func @scatter_region_type_mismatch(
func.func @scatter_region_type_mismatch(
%update : tensor<?x?xi32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi32>) -> tensor<?x?xi32> {
// expected-error @+1 {{mismatch in argument 0 of region 'i64' and element type of update value 'i32'}}
@ -210,7 +210,7 @@ func @scatter_region_type_mismatch(
// -----
func @scatter_region_type_mismatch(
func.func @scatter_region_type_mismatch(
%update : tensor<?x?xi32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi32>) -> tensor<?x?xi32> {
// expected-error @+1 {{mismatch in argument 1 of region 'i64' and element type of original value 'i32'}}
@ -227,7 +227,7 @@ func @scatter_region_type_mismatch(
// -----
func @scatter_region_type_mismatch(
func.func @scatter_region_type_mismatch(
%update : tensor<?x?xi32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi64>) -> tensor<?x?xi64> {
// expected-error @+1 {{mismatch in region argument types 'i32' and 'i64'}}
@ -244,7 +244,7 @@ func @scatter_region_type_mismatch(
// -----
func @scatter_region_type_mismatch(
func.func @scatter_region_type_mismatch(
%update : tensor<?x?xi64>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi64>) -> tensor<?x?xi64> {
// expected-error @+1 {{expected region to have two arguments}}
@ -261,7 +261,7 @@ func @scatter_region_type_mismatch(
// -----
func @scatter_yield_mismatch(
func.func @scatter_yield_mismatch(
%update : tensor<?x?xi64>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi64>) -> tensor<?x?xi64> {
%0 = tm_tensor.scatter unique_indices(true)
@ -278,7 +278,7 @@ func @scatter_yield_mismatch(
// -----
func @scatter_yield_mismatch(
func.func @scatter_yield_mismatch(
%update : tensor<?x?xi64>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi64>) -> tensor<?x?xi64> {
%0 = tm_tensor.scatter unique_indices(true)
@ -295,7 +295,7 @@ func @scatter_yield_mismatch(
// -----
func @scatter_index_depth_dynamic(
func.func @scatter_index_depth_dynamic(
%update : tensor<?x?xi64>, %indices : tensor<?x?xi32>,
%original : tensor<?x?xi64>) -> tensor<?x?xi64> {
// expected-error @+1 {{expected index depth is static}}
@ -312,7 +312,7 @@ func @scatter_index_depth_dynamic(
// -----
func @scatter_original_rank_mismatch(
func.func @scatter_original_rank_mismatch(
%update : tensor<?xi64>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi64>) -> tensor<?x?xi64> {
// expected-error @+1 {{op index depth and update value does not cover rank of original value}}

@ -1 +1 @@
Subproject commit e1318078a4e160eb723bcbcfcdcc9a1b618f7067
Subproject commit 5380e30e047bbac9b2cceb69162eb8db1e1a7abf

File diff suppressed because it is too large Load Diff

View File

@ -116,7 +116,6 @@ ModuleBuilder::ModuleBuilder(pybind11::object contextObj)
unknownLoc(mlirLocationUnknownGet(context)) {
// TODO: Rework this once dialect registration C-APIs are in place.
// https://reviews.llvm.org/D88162
mlirRegisterAllDialects(context);
torchMlirRegisterAllDialects(context);
registerPythonSysStderrDiagnosticHandler(context);

View File

@ -1,6 +1,6 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
// CHECK-LABEL: func @torch.aten.mm$basic(
// CHECK-LABEL: func.func @torch.aten.mm$basic(
// CHECK-SAME: %[[LHS_VTENSOR:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[RHS_VTENSOR:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32> {
// CHECK: %[[LHS:.*]] = torch_c.to_builtin_tensor %[[LHS_VTENSOR]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
@ -22,7 +22,7 @@
// CHECK: %[[CASTED:.*]] = tensor.cast %[[MATMUL]] : tensor<?x?xf32> to tensor<?x2xf32>
// CHECK: %[[RESULT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor<?x2xf32> -> !torch.vtensor<[?,2],f32>
// CHECK: return %[[RESULT_VTENSOR]] : !torch.vtensor<[?,2],f32>
func @torch.aten.mm$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32> {
func.func @torch.aten.mm$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32> {
%0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,2],f32>
return %0 : !torch.vtensor<[?,2],f32>
}
@ -30,7 +30,7 @@ func @torch.aten.mm$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtenso
// -----
// If the operands are missing dtype, we cannot lower it.
func @torch.aten.mm$no_convert$missing_dtype(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
func.func @torch.aten.mm$no_convert$missing_dtype(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
// expected-error@+1 {{failed to legalize}}
%0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor, !torch.vtensor -> !torch.vtensor
return %0 : !torch.vtensor
@ -40,7 +40,7 @@ func @torch.aten.mm$no_convert$missing_dtype(%arg0: !torch.vtensor, %arg1: !torc
// Correctly handle the case that operands are statically the wrong rank
// (rank 1 vs rank 2 expected for matmul.)
func @torch.aten.mm$no_convert$wrong_rank(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
func.func @torch.aten.mm$no_convert$wrong_rank(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
// expected-error@+1 {{failed to legalize}}
%0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
@ -49,7 +49,7 @@ func @torch.aten.mm$no_convert$wrong_rank(%arg0: !torch.vtensor<[?],f32>, %arg1:
// -----
// If the result is missing dtype, we cannot lower it.
func @torch.aten.mm$no_convert$result_missing_dtype(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor {
func.func @torch.aten.mm$no_convert$result_missing_dtype(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor {
// expected-error@+1 {{failed to legalize}}
%0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor
return %0 : !torch.vtensor
@ -57,20 +57,20 @@ func @torch.aten.mm$no_convert$result_missing_dtype(%arg0: !torch.vtensor<[?,?],
// -----
// CHECK-LABEL: func @torch.aten.Int.Tensor$zero_rank
// CHECK-LABEL: func.func @torch.aten.Int.Tensor$zero_rank
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[],si64>) -> !torch.int {
// CHECK: %[[I:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],si64> -> tensor<i64>
// CHECK: %[[EXT:.*]] = tensor.extract %[[I]][] : tensor<i64>
// CHECK: %[[RET:.*]] = torch_c.from_i64 %[[EXT]]
// CHECK: return %[[RET]] : !torch.int
func @torch.aten.Int.Tensor$zero_rank(%arg0: !torch.vtensor<[],si64>) -> !torch.int {
func.func @torch.aten.Int.Tensor$zero_rank(%arg0: !torch.vtensor<[],si64>) -> !torch.int {
%0 = torch.aten.Int.Tensor %arg0 : !torch.vtensor<[],si64> -> !torch.int
return %0 : !torch.int
}
// -----
// CHECK-LABEL: func @torch.aten.Int.Tensor$non_zero_rank
// CHECK-LABEL: func.func @torch.aten.Int.Tensor$non_zero_rank
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.int {
// CHECK: %[[I:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
@ -88,27 +88,27 @@ func @torch.aten.Int.Tensor$zero_rank(%arg0: !torch.vtensor<[],si64>) -> !torch.
// CHECK: %[[EXT:.*]] = tensor.extract %[[I]][%[[ZERO]], %[[ZERO]]] : tensor<?x?xi64>
// CHECK: %[[RET:.*]] = torch_c.from_i64 %[[EXT]]
// CHECK: return %[[RET]] : !torch.int
func @torch.aten.Int.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.int {
func.func @torch.aten.Int.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.int {
%0 = torch.aten.Int.Tensor %arg0 : !torch.vtensor<[?,?],si64> -> !torch.int
return %0 : !torch.int
}
// -----
// CHECK-LABEL: func @torch.aten.Float.Tensor$zero_rank
// CHECK-LABEL: func.func @torch.aten.Float.Tensor$zero_rank
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[],f64>) -> !torch.float {
// CHECK: %[[F:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f64> -> tensor<f64>
// CHECK: %[[EXT:.*]] = tensor.extract %[[F]][] : tensor<f64>
// CHECK: %[[RET:.*]] = torch_c.from_f64 %[[EXT]]
// CHECK: return %[[RET]] : !torch.float
func @torch.aten.Float.Tensor$zero_rank(%arg0: !torch.vtensor<[],f64>) -> !torch.float {
func.func @torch.aten.Float.Tensor$zero_rank(%arg0: !torch.vtensor<[],f64>) -> !torch.float {
%0 = torch.aten.Float.Tensor %arg0 : !torch.vtensor<[],f64> -> !torch.float
return %0 : !torch.float
}
// -----
// CHECK-LABEL: func @torch.aten.Float.Tensor$non_zero_rank
// CHECK-LABEL: func.func @torch.aten.Float.Tensor$non_zero_rank
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[?,?],f64>) -> !torch.float {
// CHECK: %[[F:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f64> -> tensor<?x?xf64>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
@ -126,27 +126,27 @@ func @torch.aten.Float.Tensor$zero_rank(%arg0: !torch.vtensor<[],f64>) -> !torch
// CHECK: %[[EXT:.*]] = tensor.extract %[[F]][%[[ZERO]], %[[ZERO]]] : tensor<?x?xf64>
// CHECK: %[[RET:.*]] = torch_c.from_f64 %[[EXT]]
// CHECK: return %[[RET]] : !torch.float
func @torch.aten.Float.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],f64>) -> !torch.float {
func.func @torch.aten.Float.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],f64>) -> !torch.float {
%0 = torch.aten.Float.Tensor %arg0 : !torch.vtensor<[?,?],f64> -> !torch.float
return %0 : !torch.float
}
// -----
// CHECK-LABEL: func @torch.aten.Bool.Tensor$zero_rank
// CHECK-LABEL: func.func @torch.aten.Bool.Tensor$zero_rank
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[],i1>) -> !torch.bool {
// CHECK: %[[B:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],i1> -> tensor<i1>
// CHECK: %[[EXT:.*]] = tensor.extract %[[B]][] : tensor<i1>
// CHECK: %[[RES:.*]] = torch_c.from_i1 %[[EXT]]
// CHECK: return %[[RES]] : !torch.bool
func @torch.aten.Bool.Tensor$zero_rank(%arg0: !torch.vtensor<[],i1>) -> !torch.bool {
func.func @torch.aten.Bool.Tensor$zero_rank(%arg0: !torch.vtensor<[],i1>) -> !torch.bool {
%0 = torch.aten.Bool.Tensor %arg0 : !torch.vtensor<[],i1> -> !torch.bool
return %0 : !torch.bool
}
// -----
// CHECK-LABEL: func @torch.aten.Bool.Tensor$non_zero_rank
// CHECK-LABEL: func.func @torch.aten.Bool.Tensor$non_zero_rank
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.bool {
// CHECK: %[[B:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
@ -164,59 +164,59 @@ func @torch.aten.Bool.Tensor$zero_rank(%arg0: !torch.vtensor<[],i1>) -> !torch.b
// CHECK: %[[EXT:.*]] = tensor.extract %[[I]][%[[ZERO]], %[[ZERO]]] : tensor<?x?xi1>
// CHECK: %[[RET:.*]] = torch_c.from_i1 %[[EXT]]
// CHECK: return %[[RET]] : !torch.bool
func @torch.aten.Bool.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],i1>) -> !torch.bool {
func.func @torch.aten.Bool.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],i1>) -> !torch.bool {
%0 = torch.aten.Bool.Tensor %arg0 : !torch.vtensor<[?,?],i1> -> !torch.bool
return %0 : !torch.bool
}
// -----
// CHECK: func @torch.prim.NumToTensor.Scalar$basic(%[[IN:.*]]: !torch.int) -> !torch.vtensor<[],si64> {
// CHECK: func.func @torch.prim.NumToTensor.Scalar$basic(%[[IN:.*]]: !torch.int) -> !torch.vtensor<[],si64> {
// CHECK: %[[INI64:.*]] = torch_c.to_i64 %[[IN]]
// CHECK: %[[NEWVEC:.*]] = linalg.init_tensor [] : tensor<i64>
// CHECK: %[[FILLVEC:.*]] = linalg.fill ins(%[[INI64]] : i64) outs(%[[NEWVEC]] : tensor<i64>) -> tensor<i64>
// CHECK: %[[OUTVEC:.*]] = torch_c.from_builtin_tensor %[[FILLVEC]] : tensor<i64> -> !torch.vtensor<[],si64>
// CHECK: return %[[OUTVEC]] : !torch.vtensor<[],si64>
func @torch.prim.NumToTensor.Scalar$basic(%arg0: !torch.int) -> !torch.vtensor<[],si64> {
func.func @torch.prim.NumToTensor.Scalar$basic(%arg0: !torch.int) -> !torch.vtensor<[],si64> {
%0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.int -> !torch.vtensor<[],si64>
return %0 : !torch.vtensor<[],si64>
}
// -----
// CHECK-LABEL: func @torch.tensor_static_info_cast$basic(
// CHECK-LABEL: func.func @torch.tensor_static_info_cast$basic(
// CHECK-SAME: %[[VALUE_T:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[4],f32> {
// CHECK: %[[T:.*]] = torch_c.to_builtin_tensor %[[VALUE_T]] : !torch.vtensor<[?],f32> -> tensor<?xf32>
// CHECK: %[[T_CAST:.*]] = tensor.cast %[[T]] : tensor<?xf32> to tensor<4xf32>
// CHECK: %[[VALUE_T_CAST:.*]] = torch_c.from_builtin_tensor %[[T_CAST]] : tensor<4xf32> -> !torch.vtensor<[4],f32>
// CHECK: return %[[VALUE_T_CAST]] : !torch.vtensor<[4],f32>
func @torch.tensor_static_info_cast$basic(%t: !torch.vtensor<[?],f32>) -> !torch.vtensor<[4],f32> {
func.func @torch.tensor_static_info_cast$basic(%t: !torch.vtensor<[?],f32>) -> !torch.vtensor<[4],f32> {
%t_cast = torch.tensor_static_info_cast %t : !torch.vtensor<[?],f32> to !torch.vtensor<[4],f32>
return %t_cast : !torch.vtensor<[4],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.neg
// CHECK-LABEL: func.func @torch.aten.neg
// CHECK: linalg.generic {{.*}} {
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: f32, %{{.*}}: f32):
// CHECK-NEXT: %[[NEG:.*]] = arith.negf %[[LHS]] : f32
// CHECK-NEXT: linalg.yield %[[NEG]] : f32
// CHECK-NEXT: } -> tensor<?x?xf32>
func @torch.aten.neg(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
func.func @torch.aten.neg(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.neg.bf16
// CHECK-LABEL: func.func @torch.aten.neg.bf16
// CHECK: linalg.generic {{.*}} {
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: bf16, %{{.*}}: bf16):
// CHECK-NEXT: %[[NEG:.*]] = arith.negf %[[LHS]] : bf16
// CHECK-NEXT: linalg.yield %[[NEG]] : bf16
// CHECK-NEXT: } -> tensor<?x?xbf16>
func @torch.aten.neg.bf16(%arg0: !torch.vtensor<[?,?],bf16>) -> !torch.vtensor<[?,?],bf16> {
func.func @torch.aten.neg.bf16(%arg0: !torch.vtensor<[?,?],bf16>) -> !torch.vtensor<[?,?],bf16> {
%0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],bf16> -> !torch.vtensor<[?,?],bf16>
return %0 : !torch.vtensor<[?,?],bf16>
}

View File

@ -1,7 +1,7 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -mlir-print-local-scope -verify-diagnostics | FileCheck %s
// CHECK-LABEL: func @elementwise$unary(
// CHECK-LABEL: func.func @elementwise$unary(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: %[[INIT_TENSOR:.*]] = linalg.init_tensor [] : tensor<f32>
@ -14,12 +14,12 @@
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor<f32> -> !torch.vtensor<[],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[],f32>
// CHECK: }
func @elementwise$unary(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
func.func @elementwise$unary(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
%0 = torch.aten.tanh %arg0 : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
return %0 : !torch.vtensor<[],f32>
}
// CHECK-LABEL: func @elementwise$binary(
// CHECK-LABEL: func.func @elementwise$binary(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[BUILTIN_ARG0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
@ -41,23 +41,23 @@ func @elementwise$unary(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32>
// CHECK: %[[CASTED:.*]] = tensor.cast %[[GENERIC:.*]] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func @elementwise$binary(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
func.func @elementwise$binary(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// CHECK-LABEL: func @elementwise$ternary(
// CHECK-LABEL: func.func @elementwise$ternary(
// CHECK: linalg.generic {indexing_maps = [
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d1, d2)>,
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d2)>,
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
func @elementwise$ternary(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>, %arg2: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?],f32> {
func.func @elementwise$ternary(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>, %arg2: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?],f32> {
%0 = torch.aten.lerp.Tensor %arg0, %arg1, %arg2 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?],f32>
}
// CHECK-LABEL: func @elementwise$with_scalar_capture(
// CHECK-LABEL: func.func @elementwise$with_scalar_capture(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],f32> {
// CHECK: %[[C1:.*]] = torch.constant.int 1
@ -69,18 +69,18 @@ func @elementwise$ternary(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vten
// CHECK: %[[RES:.*]] = arith.addf %[[LHS]], %[[SCALED]] : f32
// CHECK: linalg.yield %[[RES]] : f32
// CHECK: } -> tensor<?xf32>
func @elementwise$with_scalar_capture(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],f32> {
func.func @elementwise$with_scalar_capture(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],f32> {
%int1 = torch.constant.int 1
%0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[?],f32>
return %0 : !torch.vtensor<[?],f32>
}
// CHECK-LABEL: func @elementwise$static_1(
// CHECK-LABEL: func.func @elementwise$static_1(
// CHECK: linalg.generic {indexing_maps = [
// CHECK-SAME: affine_map<(d0) -> (d0)>,
// CHECK-SAME: affine_map<(d0) -> (0)>,
// CHECK-SAME: affine_map<(d0) -> (d0)>]
func @elementwise$static_1(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[1],f32>) -> !torch.vtensor<[?],f32> {
func.func @elementwise$static_1(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[1],f32>) -> !torch.vtensor<[?],f32> {
%1 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[?],f32>
return %1 : !torch.vtensor<[?],f32>
}

View File

@ -2,7 +2,7 @@
// -----
// CHECK-LABEL: func @torch.aten.flatten.using_ints$basic(
// CHECK-LABEL: func.func @torch.aten.flatten.using_ints$basic(
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2,3,3,5],f32> -> tensor<3x3x2x2x3x3x5xf32>
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3, 4], [5], [6]] : tensor<3x3x2x2x3x3x5xf32> into tensor<3x3x12x3x5xf32>
@ -10,7 +10,7 @@
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[DYNAMIC]] : tensor<3x3x?x3x5xf32> -> !torch.vtensor<[3,3,?,3,5],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[3,3,?,3,5],f32>
func @torch.aten.flatten.using_ints$basic(%arg0: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
func.func @torch.aten.flatten.using_ints$basic(%arg0: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
%int2 = torch.constant.int 2
%int4 = torch.constant.int 4
%0 = torch.aten.flatten.using_ints %arg0, %int2, %int4 : !torch.vtensor<[3,3,2,2,3,3,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,3,?,3,5],f32>
@ -19,7 +19,7 @@ func @torch.aten.flatten.using_ints$basic(%arg0: !torch.vtensor<[3,3,2,2,3,3,5],
// -----
// CHECK-LABEL: func @torch.aten.flatten.using_ints$basic_negative(
// CHECK-LABEL: func.func @torch.aten.flatten.using_ints$basic_negative(
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2,3,3,5],f32> -> tensor<3x3x2x2x3x3x5xf32>
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3, 4], [5], [6]] : tensor<3x3x2x2x3x3x5xf32> into tensor<3x3x12x3x5xf32>
@ -27,7 +27,7 @@ func @torch.aten.flatten.using_ints$basic(%arg0: !torch.vtensor<[3,3,2,2,3,3,5],
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[DYNAMIC]] : tensor<3x3x?x3x5xf32> -> !torch.vtensor<[3,3,?,3,5],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[3,3,?,3,5],f32>
func @torch.aten.flatten.using_ints$basic_negative(%arg0: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
func.func @torch.aten.flatten.using_ints$basic_negative(%arg0: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
%int-5 = torch.constant.int -5
%int-3 = torch.constant.int -3
%0 = torch.aten.flatten.using_ints %arg0, %int-5, %int-3 : !torch.vtensor<[3,3,2,2,3,3,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,3,?,3,5],f32>
@ -36,7 +36,7 @@ func @torch.aten.flatten.using_ints$basic_negative(%arg0: !torch.vtensor<[3,3,2,
// -----
// CHECK-LABEL: func @torch.aten.flatten.using_ints$flatten_front(
// CHECK-LABEL: func.func @torch.aten.flatten.using_ints$flatten_front(
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2],f32> -> tensor<3x3x2x2xf32>
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1, 2], [3]] : tensor<3x3x2x2xf32> into tensor<18x2xf32>
@ -44,7 +44,7 @@ func @torch.aten.flatten.using_ints$basic_negative(%arg0: !torch.vtensor<[3,3,2,
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[DYNAMIC]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func @torch.aten.flatten.using_ints$flatten_front(%arg0: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,?],f32> {
func.func @torch.aten.flatten.using_ints$flatten_front(%arg0: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,?],f32> {
%int0 = torch.constant.int 0
%int2 = torch.constant.int 2
%0 = torch.aten.flatten.using_ints %arg0, %int0, %int2 : !torch.vtensor<[3,3,2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
@ -53,7 +53,7 @@ func @torch.aten.flatten.using_ints$flatten_front(%arg0: !torch.vtensor<[3,3,2,2
// -----
// CHECK-LABEL: func @torch.aten.flatten.using_ints$flatten_back(
// CHECK-LABEL: func.func @torch.aten.flatten.using_ints$flatten_back(
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,12],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2],f32> -> tensor<3x3x2x2xf32>
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1, 2, 3]] : tensor<3x3x2x2xf32> into tensor<3x12xf32>
@ -61,7 +61,7 @@ func @torch.aten.flatten.using_ints$flatten_front(%arg0: !torch.vtensor<[3,3,2,2
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[DYNAMIC]] : tensor<?x12xf32> -> !torch.vtensor<[?,12],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,12],f32>
func @torch.aten.flatten.using_ints$flatten_back(%arg0: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,12],f32> {
func.func @torch.aten.flatten.using_ints$flatten_back(%arg0: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,12],f32> {
%int1 = torch.constant.int 1
%int-1 = torch.constant.int -1
%0 = torch.aten.flatten.using_ints %arg0, %int1, %int-1 : !torch.vtensor<[3,3,2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,12],f32>
@ -70,14 +70,14 @@ func @torch.aten.flatten.using_ints$flatten_back(%arg0: !torch.vtensor<[3,3,2,2]
// -----
// CHECK-LABEL: func @torch.aten.flatten.using_ints$rank0(
// CHECK-LABEL: func.func @torch.aten.flatten.using_ints$rank0(
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: %[[COLLAPSED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[COLLAPSED]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>
func @torch.aten.flatten.using_ints$rank0(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
func.func @torch.aten.flatten.using_ints$rank0(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
%int0 = torch.constant.int 0
%0 = torch.aten.flatten.using_ints %arg0, %int0, %int0 : !torch.vtensor<[],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
return %0 : !torch.vtensor<[1],f32>

View File

@ -3,61 +3,61 @@
// -----
// CHECK-LABEL: func @torch.aten.unsqueeze$basic(
// CHECK-LABEL: func.func @torch.aten.unsqueeze$basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>
func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
func.func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
%int0 = torch.constant.int 0
%0 = torch.aten.unsqueeze %arg0, %int0 : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[1],f32>
return %0 : !torch.vtensor<[1],f32>
}
// CHECK-LABEL: func @torch.aten.unsqueeze$basic_negative(
// CHECK-LABEL: func.func @torch.aten.unsqueeze$basic_negative(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>
func @torch.aten.unsqueeze$basic_negative(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
func.func @torch.aten.unsqueeze$basic_negative(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
%int-1 = torch.constant.int -1
%0 = torch.aten.unsqueeze %arg0, %int-1 : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[1],f32>
return %0 : !torch.vtensor<[1],f32>
}
// CHECK-LABEL: func @torch.aten.unsqueeze$higher_rank_front(
// CHECK-LABEL: func.func @torch.aten.unsqueeze$higher_rank_front(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1], [2], [3]] : tensor<2x3x4xf32> into tensor<1x2x3x4xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<1x2x3x4xf32> -> !torch.vtensor<[1,2,3,4],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1,2,3,4],f32>
func @torch.aten.unsqueeze$higher_rank_front(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> {
func.func @torch.aten.unsqueeze$higher_rank_front(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> {
%int0 = torch.constant.int 0
%0 = torch.aten.unsqueeze %arg0, %int0 : !torch.vtensor<[2,3,4],f32>, !torch.int -> !torch.vtensor<[1,2,3,4],f32>
return %0 : !torch.vtensor<[1,2,3,4],f32>
}
// CHECK-LABEL: func @torch.aten.unsqueeze$higher_rank_back(
// CHECK-LABEL: func.func @torch.aten.unsqueeze$higher_rank_back(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,1],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] : tensor<2x3x4xf32> into tensor<2x3x4x1xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<2x3x4x1xf32> -> !torch.vtensor<[2,3,4,1],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,4,1],f32>
func @torch.aten.unsqueeze$higher_rank_back(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,1],f32> {
func.func @torch.aten.unsqueeze$higher_rank_back(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,1],f32> {
%int-1 = torch.constant.int -1
%0 = torch.aten.unsqueeze %arg0, %int-1 : !torch.vtensor<[2,3,4],f32>, !torch.int -> !torch.vtensor<[2,3,4,1],f32>
return %0 : !torch.vtensor<[2,3,4,1],f32>
}
// CHECK-LABEL: func @torch.aten.unsqueeze$higher_rank_middle(
// CHECK-LABEL: func.func @torch.aten.unsqueeze$higher_rank_middle(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1,4],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] : tensor<2x3x4xf32> into tensor<2x3x1x4xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<2x3x1x4xf32> -> !torch.vtensor<[2,3,1,4],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,1,4],f32>
func @torch.aten.unsqueeze$higher_rank_middle(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1,4],f32> {
func.func @torch.aten.unsqueeze$higher_rank_middle(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1,4],f32> {
%int2 = torch.constant.int 2
%0 = torch.aten.unsqueeze %arg0, %int2 : !torch.vtensor<[2,3,4],f32>, !torch.int -> !torch.vtensor<[2,3,1,4],f32>
return %0 : !torch.vtensor<[2,3,1,4],f32>

View File

@ -1,6 +1,6 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-scf | FileCheck %s
// CHECK-LABEL: func @torch.prim.if(
// CHECK-LABEL: func.func @torch.prim.if(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.bool) -> !torch.int {
// CHECK: %[[VAL_1:.*]] = torch_c.to_i1 %[[VAL_0]]
// CHECK: %[[VAL_2:.*]] = torch.constant.int 2
@ -14,7 +14,7 @@
// CHECK: }
// CHECK: %[[VAL_7:.*]] = torch_c.from_i64 %[[VAL_8:.*]]
// CHECK: return %[[VAL_7]] : !torch.int
func @torch.prim.if(%arg0: !torch.bool) -> !torch.int {
func.func @torch.prim.if(%arg0: !torch.bool) -> !torch.int {
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%0 = torch.prim.If %arg0 -> (!torch.int) {
@ -25,7 +25,7 @@ func @torch.prim.if(%arg0: !torch.bool) -> !torch.int {
return %0 : !torch.int
}
// CHECK-LABEL: func @aten.prim.if$nested(
// CHECK-LABEL: func.func @aten.prim.if$nested(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.bool,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.bool) -> !torch.int {
// CHECK: %[[VAL_2:.*]] = torch_c.to_i1 %[[VAL_0]]
@ -48,7 +48,7 @@ func @torch.prim.if(%arg0: !torch.bool) -> !torch.int {
// CHECK: }
// CHECK: %[[VAL_13:.*]] = torch_c.from_i64 %[[VAL_14:.*]]
// CHECK: return %[[VAL_13]] : !torch.int
func @aten.prim.if$nested(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.int {
func.func @aten.prim.if$nested(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.int {
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%int4 = torch.constant.int 4
@ -65,7 +65,7 @@ func @aten.prim.if$nested(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.int
return %0 : !torch.int
}
// CHECK-LABEL: func @torch.prim.loop$while
// CHECK-LABEL: func.func @torch.prim.loop$while
// CHECK-SAME: (%[[ARG0:.*]]: !torch.int) -> !torch.float {
// CHECK: %[[TORCH_FLOAT_VAL:.*]] = torch.constant.float
// CHECK-NEXT: %[[FLOAT_VAL:.*]] = torch_c.to_f64 %[[TORCH_FLOAT_VAL]]
@ -86,7 +86,7 @@ func @aten.prim.if$nested(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.int
// CHECK-NEXT: }
// CHECK-NEXT: %[[TORCH_LOOP:.*]] = torch_c.from_f64 %[[LOOP]]
// CHECK-NEXT: return %[[TORCH_LOOP]] : !torch.float
func @torch.prim.loop$while(%arg0: !torch.int) -> !torch.float {
func.func @torch.prim.loop$while(%arg0: !torch.int) -> !torch.float {
%float3.200000e00 = torch.constant.float 3.200000e+00
%int9223372036854775807 = torch.constant.int 9223372036854775807
%0 = torch.aten.lt.float_int %float3.200000e00, %arg0 : !torch.float, !torch.int -> !torch.bool
@ -99,7 +99,7 @@ func @torch.prim.loop$while(%arg0: !torch.int) -> !torch.float {
return %1 : !torch.float
}
// CHECK-LABEL: func @torch.prim.loop$while_with_multiple_values
// CHECK-LABEL: func.func @torch.prim.loop$while_with_multiple_values
// CHECK-SAME: () -> (!torch.float, !torch.float) {
// CHECK: %[[TORCH_FLOAT_VAL_0:.*]] = torch.constant.float
// CHECK-NEXT: %[[FLOAT_VAL_0:.*]] = torch_c.to_f64 %[[TORCH_FLOAT_VAL_0]]
@ -127,7 +127,7 @@ func @torch.prim.loop$while(%arg0: !torch.int) -> !torch.float {
// CHECK-NEXT: %[[TORCH_LOOP_0:.*]] = torch_c.from_f64 %[[LOOP]]#0
// CHECK-NEXT: %[[TORCH_LOOP_1:.*]] = torch_c.from_f64 %[[LOOP]]#1
// CHECK-NEXT: return %[[TORCH_LOOP_0]], %[[TORCH_LOOP_1]] : !torch.float, !torch.float
func @torch.prim.loop$while_with_multiple_values() -> (!torch.float, !torch.float) {
func.func @torch.prim.loop$while_with_multiple_values() -> (!torch.float, !torch.float) {
%float3.200000e00 = torch.constant.float 3.200000e+00
%int9223372036854775807 = torch.constant.int 9223372036854775807
%float9.0 = torch.constant.float 9.0
@ -143,7 +143,7 @@ func @torch.prim.loop$while_with_multiple_values() -> (!torch.float, !torch.floa
return %1#0, %1#1 : !torch.float, !torch.float
}
// CHECK-LABEL: func @torch.prim.Loop$for
// CHECK-LABEL: func.func @torch.prim.Loop$for
// CHECK-SAME: (%[[TORCH_ARG0:.*]]: !torch.int) -> !torch.float {
// CHECK: %[[ARG0:.*]] = torch_c.to_i64 %[[TORCH_ARG0]]
// CHECK-NEXT: %{{.*}} = torch.constant.bool true
@ -164,7 +164,7 @@ func @torch.prim.loop$while_with_multiple_values() -> (!torch.float, !torch.floa
// CHECK-NEXT: %[[RETURN:.*]] = torch_c.from_f64 %[[LOOP]]
// CHECK-NEXT: return %[[RETURN]] : !torch.float
// CHECK-NEXT: }
func @torch.prim.Loop$for(%arg0: !torch.int) -> !torch.float {
func.func @torch.prim.Loop$for(%arg0: !torch.int) -> !torch.float {
%true = torch.constant.bool true
%float0.000000e00 = torch.constant.float 0.000000e+00
%0 = torch.prim.Loop %arg0, %true, init(%float0.000000e00) {
@ -175,7 +175,7 @@ func @torch.prim.Loop$for(%arg0: !torch.int) -> !torch.float {
return %0 : !torch.float
}
// CHECK-LABEL: func @torch.prim.Loop$for_with_multiple_results
// CHECK-LABEL: func.func @torch.prim.Loop$for_with_multiple_results
// CHECK-SAME: (%[[TORCH_ARG0:.*]]: !torch.int) -> (!torch.float, !torch.float) {
// CHECK: %[[ARG0:.*]] = torch_c.to_i64 %[[TORCH_ARG0]]
// CHECK-NEXT: %{{.*}} = torch.constant.bool true
@ -202,7 +202,7 @@ func @torch.prim.Loop$for(%arg0: !torch.int) -> !torch.float {
// CHECK-NEXT: %[[RETURN_1:.*]] = torch_c.from_f64 %[[LOOP]]#1
// CHECK-NEXT: return %[[RETURN_0]], %[[RETURN_1]] : !torch.float, !torch.float
// CHECK-NEXT: }
func @torch.prim.Loop$for_with_multiple_results(%arg0: !torch.int) -> (!torch.float, !torch.float) {
func.func @torch.prim.Loop$for_with_multiple_results(%arg0: !torch.int) -> (!torch.float, !torch.float) {
%true = torch.constant.bool true
%float0.000000e00 = torch.constant.float 0.000000e+00
%float9.0 = torch.constant.float 9.0

View File

@ -1,19 +1,19 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-std | FileCheck %s
// CHECK-LABEL: func @torch.aten.dim(
// CHECK-LABEL: func.func @torch.aten.dim(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<*,f32>) -> !torch.int {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<*,f32> -> tensor<*xf32>
// CHECK: %[[RANK:.*]] = tensor.rank %[[BUILTIN_TENSOR]] : tensor<*xf32>
// CHECK: %[[RANK_I64:.*]] = arith.index_cast %[[RANK]] : index to i64
// CHECK: %[[RANK_TORCH_INT:.*]] = torch_c.from_i64 %[[RANK_I64]]
// CHECK: return %[[RANK_TORCH_INT]] : !torch.int
func @torch.aten.dim(%arg0: !torch.vtensor<*,f32>) -> !torch.int {
func.func @torch.aten.dim(%arg0: !torch.vtensor<*,f32>) -> !torch.int {
%0 = torch.aten.dim %arg0 : !torch.vtensor<*,f32> -> !torch.int
return %0 : !torch.int
}
// CHECK-LABEL: func @torch.runtime.assert(
// CHECK-LABEL: func.func @torch.runtime.assert(
// CHECK-SAME: %[[X:.*]]: !torch.int,
// CHECK-SAME: %[[Y:.*]]: !torch.int) {
// CHECK: %[[X_I64:.*]] = torch_c.to_i64 %[[X]]
@ -21,13 +21,13 @@ func @torch.aten.dim(%arg0: !torch.vtensor<*,f32>) -> !torch.int {
// CHECK: %[[CMP:.*]] = arith.cmpi ne, %[[X_I64]], %[[Y_I64]] : i64
// CHECK: assert %[[CMP]], "x must not be equal to y"
// CHECK: return
func @torch.runtime.assert(%arg0: !torch.int, %arg1: !torch.int) {
func.func @torch.runtime.assert(%arg0: !torch.int, %arg1: !torch.int) {
%0 = torch.aten.ne.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool
torch.runtime.assert %0, "x must not be equal to y"
return
}
// CHECK-LABEL: func @torch.aten.ne.int(
// CHECK-LABEL: func.func @torch.aten.ne.int(
// CHECK-SAME: %[[LHS:.*]]: !torch.int,
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {
// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
@ -35,12 +35,12 @@ func @torch.runtime.assert(%arg0: !torch.int, %arg1: !torch.int) {
// CHECK: %[[CMP:.*]] = arith.cmpi ne, %[[LHS_I64]], %[[RHS_I64]] : i64
// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]]
// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool
func @torch.aten.ne.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool {
func.func @torch.aten.ne.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool {
%0 = torch.aten.ne.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool
return %0 : !torch.bool
}
// CHECK-LABEL: func @torch.aten.eq.int(
// CHECK-LABEL: func.func @torch.aten.eq.int(
// CHECK-SAME: %[[LHS:.*]]: !torch.int,
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {
// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
@ -48,12 +48,12 @@ func @torch.aten.ne.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool {
// CHECK: %[[CMP:.*]] = arith.cmpi eq, %[[LHS_I64]], %[[RHS_I64]] : i64
// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]]
// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool
func @torch.aten.eq.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool {
func.func @torch.aten.eq.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool {
%0 = torch.aten.eq.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool
return %0 : !torch.bool
}
// CHECK-LABEL: func @torch.aten.gt.int(
// CHECK-LABEL: func.func @torch.aten.gt.int(
// CHECK-SAME: %[[LHS:.*]]: !torch.int,
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {
// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
@ -61,48 +61,48 @@ func @torch.aten.eq.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool {
// CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[LHS_I64]], %[[RHS_I64]] : i64
// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]]
// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool
func @torch.aten.gt.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool {
func.func @torch.aten.gt.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool {
%0 = torch.aten.gt.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool
return %0 : !torch.bool
}
// CHECK-LABEL: func @torch.vtensor.literal() -> !torch.vtensor<[],f32> {
// CHECK-LABEL: func.func @torch.vtensor.literal() -> !torch.vtensor<[],f32> {
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VTENSOR:.*]] = torch_c.from_builtin_tensor %[[CST]] : tensor<f32> -> !torch.vtensor<[],f32>
// CHECK: return %[[VTENSOR]] : !torch.vtensor<[],f32>
func @torch.vtensor.literal() -> !torch.vtensor<[],f32> {
func.func @torch.vtensor.literal() -> !torch.vtensor<[],f32> {
%0 = torch.vtensor.literal(dense<0.0> : tensor<f32>) : !torch.vtensor<[],f32>
return %0 : !torch.vtensor<[],f32>
}
// CHECK-LABEL: func @torch.constant.bool() -> !torch.bool {
// CHECK-LABEL: func.func @torch.constant.bool() -> !torch.bool {
// CHECK: %[[CST:.*]] = arith.constant true
// CHECK: %[[BOOL:.*]] = torch_c.from_i1 %[[CST]]
// CHECK: return %[[BOOL]] : !torch.bool
func @torch.constant.bool() -> !torch.bool {
func.func @torch.constant.bool() -> !torch.bool {
%true = torch.constant.bool true
return %true : !torch.bool
}
// CHECK-LABEL: func @torch.constant.float() -> !torch.float {
// CHECK-LABEL: func.func @torch.constant.float() -> !torch.float {
// CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f64
// CHECK: %[[FLOAT:.*]] = torch_c.from_f64 %[[CST]]
// CHECK: return %[[FLOAT]] : !torch.float
func @torch.constant.float() -> !torch.float {
func.func @torch.constant.float() -> !torch.float {
%float = torch.constant.float 1.000000e+00
return %float : !torch.float
}
// CHECK-LABEL: func @torch.constant.int() -> !torch.int {
// CHECK-LABEL: func.func @torch.constant.int() -> !torch.int {
// CHECK: %[[CST:.*]] = arith.constant 1 : i64
// CHECK: %[[INT:.*]] = torch_c.from_i64 %[[CST]]
// CHECK: return %[[INT]] : !torch.int
func @torch.constant.int() -> !torch.int {
func.func @torch.constant.int() -> !torch.int {
%int1 = torch.constant.int 1
return %int1 : !torch.int
}
// CHECK-LABEL: func @torch.aten.add.int(
// CHECK-LABEL: func.func @torch.aten.add.int(
// CHECK-SAME: %[[LHS:.*]]: !torch.int,
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.int {
// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
@ -110,12 +110,12 @@ func @torch.constant.int() -> !torch.int {
// CHECK: %[[ADD:.*]] = arith.addi %[[LHS_I64:.*]], [[RHS_I64:.*]] : i64
// CHECK: %[[OUT:.*]] = torch_c.from_i64 %[[INT:.*]]
// CHECK: return %[[OUT:.*]] : !torch.int
func @torch.aten.add.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
func.func @torch.aten.add.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
%0 = torch.aten.add.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int
return %0 : !torch.int
}
// CHECK-LABEL: func @torch.aten.sub.int(
// CHECK-LABEL: func.func @torch.aten.sub.int(
// CHECK-SAME: %[[LHS:.*]]: !torch.int,
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.int {
// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
@ -123,12 +123,12 @@ func @torch.aten.add.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
// CHECK: %[[SUB:.*]] = arith.subi %[[LHS_I64:.*]], [[RHS_I64:.*]] : i64
// CHECK: %[[OUT:.*]] = torch_c.from_i64 %[[INT:.*]]
// CHECK: return %[[OUT:.*]] : !torch.int
func @torch.aten.sub.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
func.func @torch.aten.sub.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
%0 = torch.aten.sub.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int
return %0 : !torch.int
}
// CHECK-LABEL: func @torch.aten.sub.float(
// CHECK-LABEL: func.func @torch.aten.sub.float(
// CHECK-SAME: %[[LHS:.*]]: !torch.float,
// CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.float {
// CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]]
@ -136,12 +136,12 @@ func @torch.aten.sub.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
// CHECK: %[[SUB:.*]] = arith.subf %[[LHS_F64:.*]], [[RHS_F64:.*]] : f64
// CHECK: %[[OUT:.*]] = torch_c.from_f64 %[[SUB:.*]]
// CHECK: return %[[OUT:.*]] : !torch.float
func @torch.aten.sub.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.float {
func.func @torch.aten.sub.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.float {
%0 = torch.aten.sub.float %arg0, %arg1 : !torch.float, !torch.float -> !torch.float
return %0 : !torch.float
}
// CHECK-LABEL: func @torch.aten.mul.int(
// CHECK-LABEL: func.func @torch.aten.mul.int(
// CHECK-SAME: %[[LHS:.*]]: !torch.int,
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.int {
// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
@ -149,12 +149,12 @@ func @torch.aten.sub.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.f
// CHECK: %[[MUL:.*]] = arith.muli %[[LHS_I64:.*]], [[RHS_I64:.*]] : i64
// CHECK: %[[OUT:.*]] = torch_c.from_i64 %[[MUL:.*]]
// CHECK: return %[[OUT:.*]] : !torch.int
func @torch.aten.mul.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
func.func @torch.aten.mul.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
%0 = torch.aten.mul.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int
return %0 : !torch.int
}
// CHECK-LABEL: func @torch.aten.div.float(
// CHECK-LABEL: func.func @torch.aten.div.float(
// CHECK-SAME: %[[LHS:.*]]: !torch.float,
// CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.float {
// CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]]
@ -162,12 +162,12 @@ func @torch.aten.mul.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
// CHECK: %[[SUB:.*]] = arith.divf %[[LHS_F64:.*]], [[RHS_F64:.*]] : f64
// CHECK: %[[OUT:.*]] = torch_c.from_f64 %[[SUB:.*]]
// CHECK: return %[[OUT:.*]] : !torch.float
func @torch.aten.div.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.float {
func.func @torch.aten.div.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.float {
%0 = torch.aten.div.float %arg0, %arg1 : !torch.float, !torch.float -> !torch.float
return %0 : !torch.float
}
// CHECK-LABEL: func @torch.aten.ge.float(
// CHECK-LABEL: func.func @torch.aten.ge.float(
// CHECK-SAME: %[[LHS:.*]]: !torch.float,
// CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.bool {
// CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]]
@ -175,12 +175,12 @@ func @torch.aten.div.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.f
// CHECK: %[[CMP:.*]] = arith.cmpf uge, %[[LHS_F64]], %[[RHS_F64]] : f64
// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]]
// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool
func @torch.aten.ge.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.bool {
func.func @torch.aten.ge.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.bool {
%0 = torch.aten.ge.float %arg0, %arg1 : !torch.float, !torch.float -> !torch.bool
return %0 : !torch.bool
}
// CHECK-LABEL: func @torch.aten.ge.float_int(
// CHECK-LABEL: func.func @torch.aten.ge.float_int(
// CHECK-SAME: %[[LHS:.*]]: !torch.float,
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {
// CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]]
@ -189,12 +189,12 @@ func @torch.aten.ge.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.bo
// CHECK: %[[CMP:.*]] = arith.cmpf uge, %[[LHS_F64]], %[[RHS_F64]] : f64
// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]]
// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool
func @torch.aten.ge.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.bool {
func.func @torch.aten.ge.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.bool {
%0 = torch.aten.ge.float_int %arg0, %arg1 : !torch.float, !torch.int -> !torch.bool
return %0 : !torch.bool
}
// CHECK-LABEL: func @torch.aten.ne.float_int(
// CHECK-LABEL: func.func @torch.aten.ne.float_int(
// CHECK-SAME: %[[LHS:.*]]: !torch.float,
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {
// CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]]
@ -203,24 +203,24 @@ func @torch.aten.ge.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.
// CHECK: %[[CMP:.*]] = arith.cmpf une, %[[LHS_F64]], %[[RHS_F64]] : f64
// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]]
// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool
func @torch.aten.ne.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.bool {
func.func @torch.aten.ne.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.bool {
%0 = torch.aten.ne.float_int %arg0, %arg1 : !torch.float, !torch.int -> !torch.bool
return %0 : !torch.bool
}
// CHECK-LABEL: func @torch.aten.ceil.float(
// CHECK-LABEL: func.func @torch.aten.ceil.float(
// CHECK-SAME: %[[ARG:.*]]: !torch.float) -> !torch.int {
// CHECK: %[[ARG_F64:.*]] = torch_c.to_f64 %[[ARG]]
// CHECK: %[[CEIL:.*]] = math.ceil %[[ARG_F64]] : f64
// CHECK: %[[CEIL_I64:.*]] = arith.fptosi %[[CEIL]] : f64 to i64
// CHECK: %[[OUT:.*]] = torch_c.from_i64 %[[CEIL_I64]]
// CHECK: return %[[OUT]] : !torch.int
func @torch.aten.ceil.float(%arg0: !torch.float) -> !torch.int {
func.func @torch.aten.ceil.float(%arg0: !torch.float) -> !torch.int {
%0 = torch.aten.ceil.float %arg0 : !torch.float -> !torch.int
return %0 : !torch.int
}
// CHECK-LABEL: func @torch.aten.gt.float_int(
// CHECK-LABEL: func.func @torch.aten.gt.float_int(
// CHECK-SAME: %[[LHS:.*]]: !torch.float,
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {
// CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]]
@ -229,7 +229,7 @@ func @torch.aten.ceil.float(%arg0: !torch.float) -> !torch.int {
// CHECK: %[[CMP:.*]] = arith.cmpf ugt, %[[LHS_F64]], %[[RHS_F64]] : f64
// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]]
// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool
func @torch.aten.gt.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.bool {
func.func @torch.aten.gt.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.bool {
%0 = torch.aten.gt.float_int %arg0, %arg1 : !torch.float, !torch.int -> !torch.bool
return %0 : !torch.bool
}

View File

@ -1,38 +1,38 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file -verify-diagnostics | FileCheck %s
// CHECK-LABEL: func @torch.aten.tanh$basic(
// CHECK-LABEL: func.func @torch.aten.tanh$basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.tanh"(%[[ARG_BUILTIN]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.sigmoid$basic(
// CHECK-LABEL: func.func @torch.aten.sigmoid$basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.sigmoid"(%[[ARG_BUILTIN]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func @torch.aten.sigmoid$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
func.func @torch.aten.sigmoid$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.sigmoid %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.relu$basic(
// CHECK-LABEL: func.func @torch.aten.relu$basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.clamp"(%[[ARG_BUILTIN]]) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func @torch.aten.relu$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
func.func @torch.aten.relu$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.relu %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
@ -40,100 +40,100 @@ func @torch.aten.relu$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<
// -----
// CHECK-LABEL: func @torch.aten.log$basic(
// CHECK-LABEL: func.func @torch.aten.log$basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.log"(%[[ARG_BUILTIN]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func @torch.aten.log$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
func.func @torch.aten.log$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.log %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.exp$basic(
// CHECK-LABEL: func.func @torch.aten.exp$basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.exp"(%[[ARG_BUILTIN]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func @torch.aten.exp$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
func.func @torch.aten.exp$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.exp %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.neg$basic(
// CHECK-LABEL: func.func @torch.aten.neg$basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.negate"(%[[ARG_BUILTIN]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func @torch.aten.neg$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
func.func @torch.aten.neg$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.floor$basic(
// CHECK-LABEL: func.func @torch.aten.floor$basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.floor"(%[[ARG_BUILTIN]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func @torch.aten.floor$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
func.func @torch.aten.floor$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.floor %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.bitwise_not$basic(
// CHECK-LABEL: func.func @torch.aten.bitwise_not$basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.bitwise_not"(%[[ARG_BUILTIN]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func @torch.aten.bitwise_not$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
func.func @torch.aten.bitwise_not$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.bitwise_not %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.ceil$basic(
// CHECK-LABEL: func.func @torch.aten.ceil$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = "tosa.ceil"(%[[VAL_1]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func @torch.aten.ceil$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
func.func @torch.aten.ceil$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.ceil %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.reciprocal$basic(
// CHECK-LABEL: func.func @torch.aten.reciprocal$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = "tosa.reciprocal"(%[[VAL_1]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func @torch.aten.reciprocal$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
func.func @torch.aten.reciprocal$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.reciprocal %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.add$basic(
// CHECK-LABEL: func.func @torch.aten.add$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
@ -145,7 +145,7 @@ func @torch.aten.reciprocal$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vt
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func @torch.aten.add$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
func.func @torch.aten.add$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
%int1 = torch.constant.int 1
%0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32>, !torch.int -> !torch.vtensor<[?, ?],f32>
return %0 : !torch.vtensor<[?, ?],f32>
@ -153,7 +153,7 @@ func @torch.aten.add$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vten
// -----
// CHECK-LABEL: func @torch.aten.sub$basic(
// CHECK-LABEL: func.func @torch.aten.sub$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
@ -165,7 +165,7 @@ func @torch.aten.add$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vten
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func @torch.aten.sub$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
func.func @torch.aten.sub$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
%int1 = torch.constant.int 1
%0 = torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32>, !torch.int -> !torch.vtensor<[?, ?],f32>
return %0 : !torch.vtensor<[?, ?],f32>
@ -173,7 +173,7 @@ func @torch.aten.sub$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vten
// -----
// CHECK-LABEL: func @torch.aten.mul$basic(
// CHECK-LABEL: func.func @torch.aten.mul$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
@ -181,14 +181,14 @@ func @torch.aten.sub$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vten
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.mul"(%[[ARG0_BUILTIN]], %[[ARG1_BUILTIN]]) {shift = 0 : i32} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func @torch.aten.mul$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
func.func @torch.aten.mul$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
%0 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32> -> !torch.vtensor<[?, ?],f32>
return %0 : !torch.vtensor<[?, ?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.div$basic(
// CHECK-LABEL: func.func @torch.aten.div$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
@ -197,14 +197,14 @@ func @torch.aten.mul$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vten
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.mul"(%[[ARG0_BUILTIN]], %[[RCP]]) {shift = 0 : i32} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func @torch.aten.div$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
func.func @torch.aten.div$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
%0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32> -> !torch.vtensor<[?, ?],f32>
return %0 : !torch.vtensor<[?, ?],f32>
}
// -----
// CHECK-LABEL: func @test_reduce_mean_dim$basic(
// CHECK-LABEL: func.func @test_reduce_mean_dim$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[ARG1:.*]] = torch.constant.int 0
@ -217,7 +217,7 @@ func @torch.aten.div$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vten
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.mul"(%[[RESHAPE_SUM]], %[[CONST]]) {shift = 0 : i32} : (tensor<?x?x?xf32>, tensor<f32>) -> tensor<?x?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?],f32>
func @test_reduce_mean_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
func.func @test_reduce_mean_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
%dim0 = torch.constant.int 0
%reducedims = torch.prim.ListConstruct %dim0 : (!torch.int) -> !torch.list<int>
%keepdims = torch.constant.bool false
@ -228,7 +228,7 @@ func @test_reduce_mean_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
// -----
// CHECK-LABEL: func @test_reduce_sum_dims$basic(
// CHECK-LABEL: func.func @test_reduce_sum_dims$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[ARG1_BUILTIN:.*]] = torch.constant.none
@ -239,7 +239,7 @@ func @test_reduce_mean_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[SUM]]) {new_shape = [-1, -1, -1]} : (tensor<1x?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?],f32>
func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
func.func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
%none = torch.constant.none
%false = torch.constant.bool false
%int0 = torch.constant.int 0
@ -250,7 +250,7 @@ func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
// -----
// CHECK-LABEL: func @test_reduce_sum$basic(
// CHECK-LABEL: func.func @test_reduce_sum$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1],f32> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[ARG1_BUILTIN:.*]] = torch.constant.none
@ -261,7 +261,7 @@ func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) {new_shape = [1]} : (tensor<1x1x1x1xf32>) -> tensor<1xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>
func @test_reduce_sum$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1],f32> {
func.func @test_reduce_sum$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1],f32> {
%none = torch.constant.none
%0 = torch.aten.sum %arg0, %none : !torch.vtensor<[?,?,?,?],f32>, !torch.none -> !torch.vtensor<[1],f32>
return %0 : !torch.vtensor<[1],f32>
@ -269,7 +269,7 @@ func @test_reduce_sum$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vten
// -----
// CHECK-LABEL: func @test_reduce_all$basic(
// CHECK-LABEL: func.func @test_reduce_all$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],i1> -> tensor<?x?x?x?xi1>
// CHECK: %[[REDUCE1:.*]] = "tosa.reduce_all"(%[[ARG0_BUILTIN]]) {axis = 0 : i64} : (tensor<?x?x?x?xi1>) -> tensor<1x?x?x?xi1>
@ -279,14 +279,14 @@ func @test_reduce_sum$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vten
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) {new_shape = [1]} : (tensor<1x1x1x1xi1>) -> tensor<1xi1>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xi1> -> !torch.vtensor<[1],i1>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],i1>
func @test_reduce_all$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> {
func.func @test_reduce_all$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> {
%0 = torch.aten.all %arg0 : !torch.vtensor<[?,?,?,?],i1> -> !torch.vtensor<[1],i1>
return %0 : !torch.vtensor<[1],i1>
}
// -----
// CHECK-LABEL: func @test_reduce_any_dim$basic(
// CHECK-LABEL: func.func @test_reduce_any_dim$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[?,?,?],i1> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],i1> -> tensor<?x?x?x?xi1>
// CHECK: %[[ARG1:.*]] = torch.constant.int 0
@ -295,7 +295,7 @@ func @test_reduce_all$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtens
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE]]) {new_shape = [-1, -1, -1]} : (tensor<1x?x?x?xi1>) -> tensor<?x?x?xi1>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?x?xi1> -> !torch.vtensor<[?,?,?],i1>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?],i1>
func @test_reduce_any_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[?,?,?],i1> {
func.func @test_reduce_any_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[?,?,?],i1> {
%int0 = torch.constant.int 0
%false = torch.constant.bool false
%0 = torch.aten.any.dim %arg0, %int0, %false : !torch.vtensor<[?,?,?,?],i1>, !torch.int, !torch.bool -> !torch.vtensor<[?,?,?],i1>
@ -304,7 +304,7 @@ func @test_reduce_any_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.v
// -----
// CHECK-LABEL: func @test_reduce_any$basic(
// CHECK-LABEL: func.func @test_reduce_any$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],i1> -> tensor<?x?x?x?xi1>
// CHECK: %[[REDUCE1:.*]] = "tosa.reduce_any"(%[[ARG0_BUILTIN]]) {axis = 0 : i64} : (tensor<?x?x?x?xi1>) -> tensor<1x?x?x?xi1>
@ -314,28 +314,28 @@ func @test_reduce_any_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.v
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) {new_shape = [1]} : (tensor<1x1x1x1xi1>) -> tensor<1xi1>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xi1> -> !torch.vtensor<[1],i1>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],i1>
func @test_reduce_any$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> {
func.func @test_reduce_any$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> {
%0 = torch.aten.any %arg0 : !torch.vtensor<[?,?,?,?],i1> -> !torch.vtensor<[1],i1>
return %0 : !torch.vtensor<[1],i1>
}
// -----
// CHECK-LABEL: func @torch.aten.rsqrt$basic(
// CHECK-LABEL: func.func @torch.aten.rsqrt$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = "tosa.rsqrt"(%[[VAL_1]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
func.func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.rsqrt %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.maximum$basic(
// CHECK-LABEL: func.func @torch.aten.maximum$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
@ -344,14 +344,14 @@ func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func @torch.aten.maximum$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
func.func @torch.aten.maximum$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.maximum %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.minimum$basic(
// CHECK-LABEL: func.func @torch.aten.minimum$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
@ -360,14 +360,14 @@ func @torch.aten.maximum$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.v
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func @torch.aten.minimum$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
func.func @torch.aten.minimum$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.minimum %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.pow.Tensor_Scalar$basic(
// CHECK-LABEL: func.func @torch.aten.pow.Tensor_Scalar$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00
@ -376,7 +376,7 @@ func @torch.aten.minimum$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.v
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func @torch.aten.pow.Tensor_Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
func.func @torch.aten.pow.Tensor_Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%fp0 = torch.constant.float 3.123400e+00
%0 = torch.aten.pow.Tensor_Scalar %arg0, %fp0 : !torch.vtensor<[?,?],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
@ -384,7 +384,7 @@ func @torch.aten.pow.Tensor_Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !t
// -----
// CHECK-LABEL: func @torch.aten.rsub.Scalar$basic(
// CHECK-LABEL: func.func @torch.aten.rsub.Scalar$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00
@ -396,7 +396,7 @@ func @torch.aten.pow.Tensor_Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !t
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
func.func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%other = torch.constant.float 3.123400e+00
%alpha = torch.constant.float 6.432100e+00
%0 = torch.aten.rsub.Scalar %arg0, %other, %alpha : !torch.vtensor<[?,?],f32>, !torch.float, !torch.float -> !torch.vtensor<[?,?],f32>
@ -405,7 +405,7 @@ func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v
// -----
// CHECK-LABEL: func @torch.aten.rsub.Scalar$basic(
// CHECK-LABEL: func.func @torch.aten.rsub.Scalar$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00
@ -417,7 +417,7 @@ func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
func.func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%other = torch.constant.float 3.123400e+00
%alpha = torch.constant.int 1
%0 = torch.aten.rsub.Scalar %arg0, %other, %alpha : !torch.vtensor<[?,?],f32>, !torch.float, !torch.int -> !torch.vtensor<[?,?],f32>
@ -426,7 +426,7 @@ func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v
// -----
// CHECK-LABEL: func @torch.aten.gt.Tensor$basic(
// CHECK-LABEL: func.func @torch.aten.gt.Tensor$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
@ -435,14 +435,14 @@ func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
// CHECK: }
func @torch.aten.gt.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
func.func @torch.aten.gt.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
%0 = torch.aten.gt.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1>
return %0 : !torch.vtensor<[?,?],i1>
}
// -----
// CHECK-LABEL: func @torch.aten.lt.Tensor$basic(
// CHECK-LABEL: func.func @torch.aten.lt.Tensor$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
@ -451,14 +451,14 @@ func @torch.aten.gt.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
// CHECK: }
func @torch.aten.lt.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
func.func @torch.aten.lt.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
%0 = torch.aten.lt.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1>
return %0 : !torch.vtensor<[?,?],i1>
}
// -----
// CHECK-LABEL: func @torch.aten.eq.Tensor$basic(
// CHECK-LABEL: func.func @torch.aten.eq.Tensor$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
@ -467,14 +467,14 @@ func @torch.aten.lt.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
// CHECK: }
func @torch.aten.eq.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
func.func @torch.aten.eq.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
%0 = torch.aten.eq.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1>
return %0 : !torch.vtensor<[?,?],i1>
}
// -----
// CHECK-LABEL: func @torch.aten.reshape$basic(
// CHECK-LABEL: func.func @torch.aten.reshape$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int -1
@ -483,7 +483,7 @@ func @torch.aten.eq.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?xf32> -> !torch.vtensor<[?],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?],f32>
// CHECK: }
func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?],f32> {
func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?],f32> {
%dim0 = torch.constant.int -1
%shape = torch.prim.ListConstruct %dim0 : (!torch.int) -> !torch.list<int>
%0 = torch.aten.reshape %arg0, %shape : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?],f32>
@ -492,7 +492,7 @@ func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.v
// -----
// CHECK-LABEL: func @torch.aten.native_batch_norm$basic(
// CHECK-LABEL: func.func @torch.aten.native_batch_norm$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,4,3],f32>) -> !torch.vtensor<[10,4,3],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[10,4,3],f32> -> tensor<10x4x3xf32>
// CHECK: %[[VAL_2:.*]] = "tosa.const"() {value = dense<[5.000000e-01, 4.000000e-01, 3.000000e-01, 6.000000e-01]> : tensor<4xf32>} : () -> tensor<4xf32>
@ -515,7 +515,7 @@ func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.v
// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<10x4x3xf32> -> !torch.vtensor<[10,4,3],f32>
// CHECK: return %[[VAL_19]] : !torch.vtensor<[10,4,3],f32>
// CHECK: }
func @torch.aten.native_batch_norm$basic(%arg0: !torch.vtensor<[10,4,3],f32> ) -> !torch.vtensor<[10,4,3],f32> {
func.func @torch.aten.native_batch_norm$basic(%arg0: !torch.vtensor<[10,4,3],f32> ) -> !torch.vtensor<[10,4,3],f32> {
%0 = torch.vtensor.literal(dense<[5.000000e-01, 4.000000e-01, 3.000000e-01, 6.000000e-01]> : tensor<4xf32>) : !torch.vtensor<[4],f32>
%1 = torch.vtensor.literal(dense<[3.000000e+00, 2.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<4xf32>) : !torch.vtensor<[4],f32>
%float1.000000e-01 = torch.constant.float 1.000000e-01
@ -528,7 +528,7 @@ func @torch.aten.native_batch_norm$basic(%arg0: !torch.vtensor<[10,4,3],f32> ) -
// -----
// CHECK-LABEL: func @forward(
// CHECK-LABEL: func.func @forward(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,3,8,9,3,4],f32>) -> !torch.vtensor<[10,3,?,4],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[10,3,8,9,3,4],f32> -> tensor<10x3x8x9x3x4xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 4
@ -538,7 +538,7 @@ func @torch.aten.native_batch_norm$basic(%arg0: !torch.vtensor<[10,4,3],f32> ) -
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<10x3x?x4xf32> -> !torch.vtensor<[10,3,?,4],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[10,3,?,4],f32>
// CHECK: }
func @forward(%arg0: !torch.vtensor<[10,3,8,9,3,4],f32> ) -> !torch.vtensor<[10,3,?,4],f32> {
func.func @forward(%arg0: !torch.vtensor<[10,3,8,9,3,4],f32> ) -> !torch.vtensor<[10,3,?,4],f32> {
%int4 = torch.constant.int 4
%int2 = torch.constant.int 2
%0 = torch.aten.flatten.using_ints %arg0, %int2, %int4 : !torch.vtensor<[10,3,8,9,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[10,3,?,4],f32>
@ -547,7 +547,7 @@ func @forward(%arg0: !torch.vtensor<[10,3,8,9,3,4],f32> ) -> !torch.vtensor<[10,
// -----
// CHECK-LABEL: func @forward(
// CHECK-LABEL: func.func @forward(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,2,2,3],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,2,3],f32>,
// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[2,2,3],f32>) -> !torch.vtensor<[5,2,2,3],f32> {
@ -584,7 +584,7 @@ func @forward(%arg0: !torch.vtensor<[10,3,8,9,3,4],f32> ) -> !torch.vtensor<[10,
// CHECK: %[[VAL_33:.*]] = torch_c.from_builtin_tensor %[[VAL_32]] : tensor<5x2x2x3xf32> -> !torch.vtensor<[5,2,2,3],f32>
// CHECK: return %[[VAL_33]] : !torch.vtensor<[5,2,2,3],f32>
// CHECK: }
func @forward(%arg0: !torch.vtensor<[5,2,2,3],f32> , %arg1: !torch.vtensor<[2,2,3],f32> , %arg2: !torch.vtensor<[2,2,3],f32> ) -> !torch.vtensor<[5,2,2,3],f32> {
func.func @forward(%arg0: !torch.vtensor<[5,2,2,3],f32> , %arg1: !torch.vtensor<[2,2,3],f32> , %arg2: !torch.vtensor<[2,2,3],f32> ) -> !torch.vtensor<[5,2,2,3],f32> {
%float5.000000e-01 = torch.constant.float 5.000000e-01
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
@ -595,7 +595,7 @@ func @forward(%arg0: !torch.vtensor<[5,2,2,3],f32> , %arg1: !torch.vtensor<[2,2,
// -----
// CHECK-LABEL: func @torch.aten.ne.Tensor$basic(
// CHECK-LABEL: func.func @torch.aten.ne.Tensor$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
@ -605,14 +605,14 @@ func @forward(%arg0: !torch.vtensor<[5,2,2,3],f32> , %arg1: !torch.vtensor<[2,2,
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],i1>
// CHECK: }
func @torch.aten.ne.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
func.func @torch.aten.ne.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
%0 = torch.aten.ne.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1>
return %0 : !torch.vtensor<[?,?],i1>
}
// -----
// CHECK-LABEL: func @forward(
// CHECK-LABEL: func.func @forward(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,2,4],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,2],f32> -> tensor<3x4x2xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
@ -624,7 +624,7 @@ func @torch.aten.ne.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<3x2x4xf32> -> !torch.vtensor<[3,2,4],f32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[3,2,4],f32>
// CHECK: }
func @forward(%arg0: !torch.vtensor<[3,4,2],f32> ) -> !torch.vtensor<[3,2,4],f32> {
func.func @forward(%arg0: !torch.vtensor<[3,4,2],f32> ) -> !torch.vtensor<[3,2,4],f32> {
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%int0 = torch.constant.int 0
@ -635,7 +635,7 @@ func @forward(%arg0: !torch.vtensor<[3,4,2],f32> ) -> !torch.vtensor<[3,2,4],f32
// -----
// CHECK-LABEL: func @torch.aten.bitwise_and.Tensor$basic(
// CHECK-LABEL: func.func @torch.aten.bitwise_and.Tensor$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
@ -644,14 +644,14 @@ func @forward(%arg0: !torch.vtensor<[3,4,2],f32> ) -> !torch.vtensor<[3,2,4],f32
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi32> -> !torch.vtensor<[?,?],si32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],si32>
// CHECK: }
func @torch.aten.bitwise_and.Tensor$basic(%arg0: !torch.vtensor<[?,?],si32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
func.func @torch.aten.bitwise_and.Tensor$basic(%arg0: !torch.vtensor<[?,?],si32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
%0 = torch.aten.bitwise_and.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32>
return %0 : !torch.vtensor<[?,?],si32>
}
// -----
// CHECK-LABEL: func @torch.aten.log2$basic(
// CHECK-LABEL: func.func @torch.aten.log2$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = "tosa.const"() {value = dense<0.693147182> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
@ -661,14 +661,14 @@ func @torch.aten.bitwise_and.Tensor$basic(%arg0: !torch.vtensor<[?,?],si32>, %ar
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func @torch.aten.log2$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vtensor<[?,?],f32> {
func.func @torch.aten.log2$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.log2 %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.zeros$basic() -> !torch.vtensor<[3,4],f32> {
// CHECK-LABEL: func.func @torch.aten.zeros$basic() -> !torch.vtensor<[3,4],f32> {
// CHECK: %[[VAL_0:.*]] = torch.constant.int 4
// CHECK: %[[VAL_1:.*]] = torch.constant.int 3
// CHECK: %[[VAL_2:.*]] = torch.constant.none
@ -678,7 +678,7 @@ func @torch.aten.log2$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vtensor
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32>
// CHECK: }
func @torch.aten.zeros$basic() -> !torch.vtensor<[3,4],f32> {
func.func @torch.aten.zeros$basic() -> !torch.vtensor<[3,4],f32> {
%int4 = torch.constant.int 4
%int3 = torch.constant.int 3
%none = torch.constant.none
@ -689,7 +689,7 @@ func @torch.aten.zeros$basic() -> !torch.vtensor<[3,4],f32> {
// -----
// CHECK-LABEL: func @torch.aten.unsqueeze$basic(
// CHECK-LABEL: func.func @torch.aten.unsqueeze$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,3],si32>) -> !torch.vtensor<[4,1,3],si32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],si32> -> tensor<4x3xi32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
@ -698,7 +698,7 @@ func @torch.aten.zeros$basic() -> !torch.vtensor<[3,4],f32> {
// CHECK: return %[[VAL_4]] : !torch.vtensor<[4,1,3],si32>
// CHECK: }
func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[4,3],si32> ) -> !torch.vtensor<[4,1,3],si32> {
func.func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[4,3],si32> ) -> !torch.vtensor<[4,1,3],si32> {
%int1 = torch.constant.int 1
%0 = torch.aten.unsqueeze %arg0, %int1 : !torch.vtensor<[4,3],si32>, !torch.int -> !torch.vtensor<[4,1,3],si32>
return %0 : !torch.vtensor<[4,1,3],si32>
@ -706,14 +706,14 @@ func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[4,3],si32> ) -> !torch.v
// -----
// CHECK-LABEL: func @torch.aten.contiguous$basic(
// CHECK-LABEL: func.func @torch.aten.contiguous$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func @torch.aten.contiguous$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vtensor<[?,?],f32> {
func.func @torch.aten.contiguous$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vtensor<[?,?],f32> {
%int0 = torch.constant.int 0
%0 = torch.aten.contiguous %arg0, %int0 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
@ -721,7 +721,7 @@ func @torch.aten.contiguous$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.v
// -----
// CHECK-LABEL: func @torch.aten.ones$basic() -> !torch.vtensor<[3,4],f32> {
// CHECK-LABEL: func.func @torch.aten.ones$basic() -> !torch.vtensor<[3,4],f32> {
// CHECK: %[[VAL_0:.*]] = torch.constant.int 4
// CHECK: %[[VAL_1:.*]] = torch.constant.int 3
// CHECK: %[[VAL_2:.*]] = torch.constant.none
@ -731,7 +731,7 @@ func @torch.aten.contiguous$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.v
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32>
// CHECK: }
func @torch.aten.ones$basic() -> !torch.vtensor<[3,4],f32> {
func.func @torch.aten.ones$basic() -> !torch.vtensor<[3,4],f32> {
%int4 = torch.constant.int 4
%int3 = torch.constant.int 3
%none = torch.constant.none
@ -742,7 +742,7 @@ func @torch.aten.ones$basic() -> !torch.vtensor<[3,4],f32> {
// -----
// CHECK-LABEL: func @torch.aten.dropout$basic(
// CHECK-LABEL: func.func @torch.aten.dropout$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.float 0.000000e+00
@ -751,7 +751,7 @@ func @torch.aten.ones$basic() -> !torch.vtensor<[3,4],f32> {
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func @torch.aten.dropout$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vtensor<[?,?],f32> {
func.func @torch.aten.dropout$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vtensor<[?,?],f32> {
%float0.000000e00 = torch.constant.float 0.000000e+00
%false = torch.constant.bool false
%0 = torch.aten.dropout %arg0, %float0.000000e00, %false : !torch.vtensor<[?,?],f32>, !torch.float, !torch.bool -> !torch.vtensor<[?,?],f32>

View File

@ -4,31 +4,31 @@ torch.class_type @c {
torch.attr "float" : !torch.float
torch.method "calls_free_function", @calls_free_function
}
// CHECK-LABEL: func private
// CHECK-LABEL: func.func private
// CHECK-SAME: @free_function$[[$MONOMORPHIZE_TAG0:.*]](
// CHECK-SAME: %[[F:.*]]: !torch.float) -> !torch.float {
// CHECK: return %[[F]] : !torch.float
// CHECK: }
func private @free_function(%arg0: !torch.float, %arg1: !torch.nn.Module<"c">) -> !torch.float {
func.func private @free_function(%arg0: !torch.float, %arg1: !torch.nn.Module<"c">) -> !torch.float {
return %arg0 : !torch.float
}
// CHECK-LABEL: func private
// CHECK-LABEL: func.func private
// CHECK-SAME: @free_function_no_module_args$[[$MONOMORPHIZE_TAG1:.*]](
// CHECK-SAME: %[[F:.*]]: !torch.float) -> !torch.float {
// CHECK: return %[[F]] : !torch.float
// CHECK: }
func private @free_function_no_module_args(%arg0: !torch.float) -> !torch.float {
func.func private @free_function_no_module_args(%arg0: !torch.float) -> !torch.float {
return %arg0 : !torch.float
}
// CHECK-LABEL: func @calls_free_function() -> !torch.float {
// CHECK-LABEL: func.func @calls_free_function() -> !torch.float {
// CHECK: %[[F1:.*]] = torch.global_slot.get @float : !torch.float
// CHECK: %[[F2:.*]] = call @free_function$[[$MONOMORPHIZE_TAG0]](%[[F1]]) : (!torch.float) -> !torch.float
// CHECK: %[[RET:.*]] = call @free_function_no_module_args$[[$MONOMORPHIZE_TAG1]](%[[F2]]) : (!torch.float) -> !torch.float
// CHECK: return %[[RET]] : !torch.float
// CHECK: }
func private @calls_free_function(%arg0: !torch.nn.Module<"c">) -> !torch.float {
func.func private @calls_free_function(%arg0: !torch.nn.Module<"c">) -> !torch.float {
%0 = torch.prim.GetAttr %arg0["float"] : !torch.nn.Module<"c"> -> !torch.float
%1 = call @free_function(%0, %arg0) : (!torch.float, !torch.nn.Module<"c">) -> !torch.float
%2 = call @free_function_no_module_args(%1) : (!torch.float) -> !torch.float

View File

@ -7,28 +7,28 @@ torch.class_type @c {
torch.method "test_call", @test_call
}
// CHECK-LABEL: func @test_get() -> !torch.float {
// CHECK-LABEL: func.func @test_get() -> !torch.float {
// CHECK: %[[V:.*]] = torch.global_slot.get @float : !torch.float
// CHECK: return %[[V]] : !torch.float
func private @test_get(%arg0: !torch.nn.Module<"c">) -> !torch.float {
func.func private @test_get(%arg0: !torch.nn.Module<"c">) -> !torch.float {
%0 = torch.prim.GetAttr %arg0["float"] : !torch.nn.Module<"c"> -> !torch.float
return %0 : !torch.float
}
// CHECK-LABEL: func @test_set(
// CHECK-LABEL: func.func @test_set(
// CHECK-SAME: %[[A:.*]]: !torch.float) {
// CHECK: torch.global_slot.set @float = %[[A]] : !torch.float
// CHECK: return
func private @test_set(%arg0: !torch.nn.Module<"c">, %arg1: !torch.float) {
func.func private @test_set(%arg0: !torch.nn.Module<"c">, %arg1: !torch.float) {
torch.prim.SetAttr %arg0["float"] = %arg1 : !torch.nn.Module<"c">, !torch.float
return
}
// CHECK-LABEL: func @test_call(
// CHECK-LABEL: func.func @test_call(
// CHECK-SAME: %[[A:.*]]: !torch.float) -> !torch.float {
// CHECK: %[[V:.*]] = call @test_call(%[[A]]) : (!torch.float) -> !torch.float
// CHECK: return %[[V]] : !torch.float
func private @test_call(%arg0: !torch.nn.Module<"c">, %arg1: !torch.float) -> !torch.float {
func.func private @test_call(%arg0: !torch.nn.Module<"c">, %arg1: !torch.float) -> !torch.float {
%0 = call @test_call(%arg0, %arg1) : (!torch.nn.Module<"c">, !torch.float) -> !torch.float
return %0 : !torch.float
}

View File

@ -4,7 +4,7 @@ torch.class_type @parent {
torch.method "module_type_return", @module_type_return
}
func private @module_type_return(%arg0: !torch.nn.Module<"parent">) {
func.func private @module_type_return(%arg0: !torch.nn.Module<"parent">) {
// expected-error @+1 {{unsupported use of a torch.nn.Module. Expected only method calls or attribute get/set}}
torch.prim.ListConstruct %arg0 : (!torch.nn.Module<"parent">) -> !torch.list<nn.Module<"parent">>
return

View File

@ -10,8 +10,8 @@ torch.class_type @parent {
torch.method "method_call", @method_call
}
// CHECK-LABEL: func @get_attr_returns_module_type() -> !torch.float {
func private @get_attr_returns_module_type(%arg0: !torch.nn.Module<"parent">) -> !torch.float {
// CHECK-LABEL: func.func @get_attr_returns_module_type() -> !torch.float {
func.func private @get_attr_returns_module_type(%arg0: !torch.nn.Module<"parent">) -> !torch.float {
%0 = torch.prim.GetAttr %arg0["m"] : !torch.nn.Module<"parent"> -> !torch.nn.Module<"child">
// CHECK-NEXT: %[[V:.*]] = torch.global_slot.get @m.float : !torch.float
%1 = torch.prim.GetAttr %0["float"] : !torch.nn.Module<"child"> -> !torch.float
@ -21,15 +21,15 @@ func private @get_attr_returns_module_type(%arg0: !torch.nn.Module<"parent">) ->
return %1 : !torch.float
}
// CHECK-LABEL: func @module_type_argument(
// CHECK-LABEL: func.func @module_type_argument(
// CHECK-SAME: %[[F:.*]]: !torch.float) -> !torch.none {
func private @module_type_argument(%arg0: !torch.nn.Module<"parent">, %arg1: !torch.nn.Module<"parent">, %arg2: !torch.float, %arg3: !torch.nn.Module<"parent">) -> !torch.none {
func.func private @module_type_argument(%arg0: !torch.nn.Module<"parent">, %arg1: !torch.nn.Module<"parent">, %arg2: !torch.float, %arg3: !torch.nn.Module<"parent">) -> !torch.none {
%0 = torch.constant.none
return %0 : !torch.none
}
// CHECK-LABEL: func @method_call() -> !torch.none {
func private @method_call(%arg0: !torch.nn.Module<"parent">) -> !torch.none {
// CHECK-LABEL: func.func @method_call() -> !torch.none {
func.func private @method_call(%arg0: !torch.nn.Module<"parent">) -> !torch.none {
// CHECK-NEXT: %[[C:.*]] = torch.constant.float 4.300000e+01
%c = torch.constant.float 43.0
// CHECK-NEXT: %[[F:.*]] = call @module_type_argument(%[[C]]) : (!torch.float) -> !torch.none

View File

@ -1,9 +1,9 @@
// RUN: torch-mlir-opt -torch-globalize-object-graph -verify-diagnostics -split-input-file %s
func private @__torch__.Submodule.forward(%arg0: !torch.nn.Module<"__torch__.Submodule">, %arg1: !torch.nn.Module<"__torch__.Submodule">) {
func.func private @__torch__.Submodule.forward(%arg0: !torch.nn.Module<"__torch__.Submodule">, %arg1: !torch.nn.Module<"__torch__.Submodule">) {
return
}
func private @__torch__.TestModule.forward(%arg0: !torch.nn.Module<"__torch__.TestModule">) {
func.func private @__torch__.TestModule.forward(%arg0: !torch.nn.Module<"__torch__.TestModule">) {
%5 = torch.prim.GetAttr %arg0["s1"] : !torch.nn.Module<"__torch__.TestModule"> -> !torch.nn.Module<"__torch__.Submodule">
%6 = torch.prim.GetAttr %arg0["s2"] : !torch.nn.Module<"__torch__.TestModule"> -> !torch.nn.Module<"__torch__.Submodule">
call @__torch__.Submodule.forward(%5, %6) : (!torch.nn.Module<"__torch__.Submodule">, !torch.nn.Module<"__torch__.Submodule">) -> ()

View File

@ -34,10 +34,10 @@ torch.class_type @__torch__.Submodule {
} : !torch.nn.Module<"__torch__.TestModule">
// CHECK-LABEL: func @forward() {
// CHECK-LABEL: func.func @forward() {
// CHECK: call @__torch__.free_function$[[$MONOMORPHIZE_TAG0:.*]]() : () -> ()
// CHECK: call @__torch__.free_function$[[$MONOMORPHIZE_TAG1:.*]]() : () -> ()
func private @__torch__.TestModule.forward(%arg0: !torch.nn.Module<"__torch__.TestModule">) {
func.func private @__torch__.TestModule.forward(%arg0: !torch.nn.Module<"__torch__.TestModule">) {
%4 = torch.prim.GetAttr %arg0["s1"] : !torch.nn.Module<"__torch__.TestModule"> -> !torch.nn.Module<"__torch__.Submodule">
%5 = torch.prim.GetAttr %arg0["s2"] : !torch.nn.Module<"__torch__.TestModule"> -> !torch.nn.Module<"__torch__.Submodule">
call @__torch__.free_function(%4, %5) : (!torch.nn.Module<"__torch__.Submodule">, !torch.nn.Module<"__torch__.Submodule">) -> ()
@ -48,27 +48,27 @@ func private @__torch__.TestModule.forward(%arg0: !torch.nn.Module<"__torch__.Te
}
// s1 called first, then s2
// CHECK-LABEL: func private
// CHECK-LABEL: func.func private
// CHECK-SAME @__torch__.free_function$[[$MONOMORPHIZE_TAG0]]() {
// CHECK: call @s1.forward() : () -> ()
// CHECK: call @s2.forward() : () -> ()
// s2 called first, then s1
// CHECK-LABEL: func private
// CHECK-LABEL: func.func private
// CHECK-SAME: @__torch__.free_function$[[$MONOMORPHIZE_TAG1]]() {
// CHECK: call @s2.forward() : () -> ()
// CHECK: call @s1.forward() : () -> ()
func private @__torch__.free_function(%arg0: !torch.nn.Module<"__torch__.Submodule">, %arg1: !torch.nn.Module<"__torch__.Submodule">) {
func.func private @__torch__.free_function(%arg0: !torch.nn.Module<"__torch__.Submodule">, %arg1: !torch.nn.Module<"__torch__.Submodule">) {
call @__torch__.Submodule.forward(%arg0) : (!torch.nn.Module<"__torch__.Submodule">) -> ()
call @__torch__.Submodule.forward(%arg1) : (!torch.nn.Module<"__torch__.Submodule">) -> ()
return
}
// CHECK-LABEL: func private @s2.forward() {
// CHECK-LABEL: func.func private @s2.forward() {
// CHECK: return
// CHECK-LABEL: func private @s1.forward() {
// CHECK-LABEL: func.func private @s1.forward() {
// CHECK: return
func private @__torch__.Submodule.forward(%arg0: !torch.nn.Module<"__torch__.Submodule">) {
func.func private @__torch__.Submodule.forward(%arg0: !torch.nn.Module<"__torch__.Submodule">) {
return
}

View File

@ -32,31 +32,31 @@ torch.class_type @__torch__.Submodule {
} : !torch.nn.Module<"__torch__.TestModule">
// CHECK-LABEL: func @forward() {
// CHECK-LABEL: func.func @forward() {
// CHECK: call @s1.forward() : () -> ()
// CHECK: call @s2.forward() : () -> ()
// CHECK: return
func private @__torch__.TestModule.forward(%arg0: !torch.nn.Module<"__torch__.TestModule">) {
func.func private @__torch__.TestModule.forward(%arg0: !torch.nn.Module<"__torch__.TestModule">) {
%4 = torch.prim.GetAttr %arg0["s1"] : !torch.nn.Module<"__torch__.TestModule"> -> !torch.nn.Module<"__torch__.Submodule">
%5 = torch.prim.GetAttr %arg0["s2"] : !torch.nn.Module<"__torch__.TestModule"> -> !torch.nn.Module<"__torch__.Submodule">
call @__torch__.Submodule.forward(%4) : (!torch.nn.Module<"__torch__.Submodule">) -> ()
call @__torch__.Submodule.forward(%5) : (!torch.nn.Module<"__torch__.Submodule">) -> ()
return
}
// CHECK-LABEL: func private @s1.forward() {
// CHECK-LABEL: func.func private @s1.forward() {
// CHECK: %[[C3:.*]] = torch.constant.int 3
// CHECK: %[[N:.*]] = torch.global_slot.get @s1.n : !torch.int
// CHECK: %[[NEWVAL:.*]] = torch.aten.add.int %[[N]], %[[C3]] : !torch.int, !torch.int -> !torch.int
// CHECK: torch.global_slot.set @s1.n = %[[NEWVAL]] : !torch.int
// CHECK: return
// CHECK-LABEL: func private @s2.forward() {
// CHECK-LABEL: func.func private @s2.forward() {
// CHECK: %[[C3:.*]] = torch.constant.int 3
// CHECK: %[[N:.*]] = torch.global_slot.get @s2.n : !torch.int
// CHECK: %[[NEWVAL:.*]] = torch.aten.add.int %[[N]], %[[C3]] : !torch.int, !torch.int -> !torch.int
// CHECK: torch.global_slot.set @s2.n = %[[NEWVAL]] : !torch.int
// CHECK: return
func private @__torch__.Submodule.forward(%arg0: !torch.nn.Module<"__torch__.Submodule">) {
func.func private @__torch__.Submodule.forward(%arg0: !torch.nn.Module<"__torch__.Submodule">) {
%int3 = torch.constant.int 3
%5 = torch.prim.GetAttr %arg0["n"] : !torch.nn.Module<"__torch__.Submodule"> -> !torch.int
%6 = torch.aten.add.int %5, %int3 : !torch.int, !torch.int -> !torch.int

View File

@ -6,8 +6,8 @@ torch.class_type @c {
torch.method private "forward", @method
}
// CHECK: func private @forward() {
func private @method(%arg0: !torch.nn.Module<"c">) {
// CHECK: func.func private @forward() {
func.func private @method(%arg0: !torch.nn.Module<"c">) {
return
}

View File

@ -1,22 +1,22 @@
// RUN: torch-mlir-opt -torch-adjust-calling-conventions -allow-unregistered-dialect -split-input-file %s | FileCheck %s
// CHECK-LABEL: func @basic(
// CHECK-LABEL: func.func @basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
// CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
// CHECK: %[[NONVAL_TENSOR:.*]] = torch.copy.to_tensor %[[ERASED]] : !torch.tensor
// CHECK: return %[[NONVAL_TENSOR]] : !torch.tensor
func @basic(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[2,3,?],f32>}) -> !torch.tensor {
func.func @basic(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[2,3,?],f32>}) -> !torch.tensor {
return %arg0 : !torch.tensor
}
// CHECK-LABEL: func @no_type_bound(
// CHECK-LABEL: func.func @no_type_bound(
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: return %[[ARG]] : !torch.tensor
func @no_type_bound(%arg0: !torch.tensor) -> !torch.tensor {
func.func @no_type_bound(%arg0: !torch.tensor) -> !torch.tensor {
return %arg0 : !torch.tensor
}
// CHECK-LABEL: func @call(
// CHECK-LABEL: func.func @call(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
// CHECK: %[[ARG_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
// CHECK: %[[ARG_NONVAL:.*]] = torch.copy.to_tensor %[[ARG_ERASED]] : !torch.tensor
@ -24,31 +24,31 @@ func @no_type_bound(%arg0: !torch.tensor) -> !torch.tensor {
// CHECK: %[[CALL_ARG:.*]] = torch.copy.to_vtensor %[[INFO_ADDED]] : !torch.vtensor<[2,3,?],f32>
// CHECK: %[[CALL_RES:.*]] = call @call(%[[CALL_ARG]]) : (!torch.vtensor<[2,3,?],f32>) -> !torch.tensor
// CHECK: return %[[ARG_NONVAL]] : !torch.tensor
func @call(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[2,3,?],f32>}) -> !torch.tensor {
func.func @call(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[2,3,?],f32>}) -> !torch.tensor {
%0 = call @call(%arg0) : (!torch.tensor) -> !torch.tensor
return %arg0 : !torch.tensor
}
// CHECK-LABEL: func @none_return() {
// CHECK-LABEL: func.func @none_return() {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: return
func @none_return() -> !torch.none {
func.func @none_return() -> !torch.none {
%1 = torch.constant.none
return %1 : !torch.none
}
// CHECK-LABEL: func @none_call_return() {
// CHECK-LABEL: func.func @none_call_return() {
// CHECK: call @none_return() : () -> ()
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: "test.use"(%[[NONE]]) : (!torch.none) -> ()
// CHECK: return
func @none_call_return() {
func.func @none_call_return() {
%0 = call @none_return() : () -> !torch.none
"test.use"(%0) : (!torch.none) -> ()
return
}
// CHECK-LABEL: func @tuple_return(
// CHECK-LABEL: func.func @tuple_return(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) {
// CHECK: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor
@ -64,13 +64,13 @@ func @none_call_return() {
// CHECK: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] :
// CHECK-SAME: !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
// CHECK: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor
func @tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>},
func.func @tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>},
%arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple<tensor, tensor> {
%1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor>
return %1 : !torch.tuple<tensor, tensor>
}
// CHECK-LABEL: func @call_tuple_return(
// CHECK-LABEL: func.func @call_tuple_return(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) {
// CHECK: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor
@ -92,7 +92,7 @@ func @tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f
// CHECK: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] :
// CHECK-SAME: !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
// CHECK: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor
func @call_tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>},
func.func @call_tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>},
%arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple<tensor, tensor> {
%0 = call @tuple_return(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tuple<tensor, tensor>
return %0 : !torch.tuple<tensor, tensor>

File diff suppressed because it is too large Load Diff

View File

@ -1,8 +1,8 @@
// RUN: torch-mlir-opt -torch-decompose-complex-ops -split-input-file %s | FileCheck %s
// CHECK-LABEL: func @matmul_no_decompose
// CHECK-LABEL: func.func @matmul_no_decompose
// CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
func @matmul_no_decompose(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
func.func @matmul_no_decompose(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
return %0 : !torch.tensor
}
@ -10,23 +10,23 @@ func @matmul_no_decompose(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.
// -----
// CHECK-LABEL: func @matmul_decompose_2d
// CHECK-LABEL: func.func @matmul_decompose_2d
// CHECK: torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.tensor
func @matmul_decompose_2d(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.tensor {
func.func @matmul_decompose_2d(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.tensor {
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.tensor
return %0 : !torch.tensor
}
// -----
// CHECK-LABEL: func @matmul_decompose_3d(
// CHECK-LABEL: func.func @matmul_decompose_3d(
// CHECK: torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
func.func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
return %0 : !torch.tensor
}
// -----
// CHECK-LABEL: func @torch.aten.softmax.int(
// CHECK-LABEL: func.func @torch.aten.softmax.int(
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>,
// CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor<[2,3],f32> {
// CHECK: %[[DTYPE:.*]] = torch.constant.none
@ -45,7 +45,7 @@ func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vten
// CHECK: %[[SOFTMAX:.*]] = torch.aten.div.Tensor %[[EXP]], %[[SUM]] : !torch.tensor<[2,3],f32>, !torch.tensor<[?,?],f32> -> !torch.tensor<[2,3],f32>
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<[2,3],f32> to !torch.tensor<[2,3],f32>
// CHECK: return %[[RET]] : !torch.tensor<[2,3],f32>
func @torch.aten.softmax.int(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -> !torch.tensor<[2,3],f32> {
func.func @torch.aten.softmax.int(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -> !torch.tensor<[2,3],f32> {
%dtype = torch.constant.none
%ret = torch.aten.softmax.int %t, %dim, %dtype: !torch.tensor<[2,3],f32>, !torch.int, !torch.none -> !torch.tensor<[2,3],f32>
return %ret : !torch.tensor<[2,3],f32>
@ -53,7 +53,7 @@ func @torch.aten.softmax.int(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) ->
// -----
// CHECK-LABEL: func @torch.aten.softmax.int$cst_dim(
// CHECK-LABEL: func.func @torch.aten.softmax.int$cst_dim(
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>) -> !torch.tensor<[2,3],f32> {
// CHECK: %[[DTYPE:.*]] = torch.constant.none
// CHECK: %[[DIM:.*]] = torch.constant.int 1
@ -72,7 +72,7 @@ func @torch.aten.softmax.int(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) ->
// CHECK: %[[SOFTMAX:.*]] = torch.aten.div.Tensor %[[EXP]], %[[SUM]] : !torch.tensor<[2,3],f32>, !torch.tensor<[2,1],f32> -> !torch.tensor<[2,3],f32>
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<[2,3],f32> to !torch.tensor<[2,3],f32>
// CHECK: return %[[RET]] : !torch.tensor<[2,3],f32>
func @torch.aten.softmax.int$cst_dim(%t: !torch.tensor<[2,3],f32>) -> !torch.tensor<[2,3],f32> {
func.func @torch.aten.softmax.int$cst_dim(%t: !torch.tensor<[2,3],f32>) -> !torch.tensor<[2,3],f32> {
%none = torch.constant.none
%dim = torch.constant.int 1
%ret = torch.aten.softmax.int %t, %dim, %none : !torch.tensor<[2,3],f32>, !torch.int, !torch.none -> !torch.tensor<[2,3],f32>
@ -80,7 +80,7 @@ func @torch.aten.softmax.int$cst_dim(%t: !torch.tensor<[2,3],f32>) -> !torch.ten
}
// -----
// CHECK-LABEL: func @torch.aten.softmax.int$dyn_shape(
// CHECK-LABEL: func.func @torch.aten.softmax.int$dyn_shape(
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> {
// CHECK: %[[DTYPE:.*]] = torch.constant.none
// CHECK: %[[DIM:.*]] = torch.constant.int 1
@ -99,7 +99,7 @@ func @torch.aten.softmax.int$cst_dim(%t: !torch.tensor<[2,3],f32>) -> !torch.ten
// CHECK: %[[SOFTMAX:.*]] = torch.aten.div.Tensor %[[EXP]], %[[SUM]] : !torch.tensor<[?,?],f32>, !torch.tensor<[?,1],f32> -> !torch.tensor<[?,?],f32>
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<[?,?],f32> to !torch.tensor<[?,?],f32>
// CHECK: return %[[RET]] : !torch.tensor<[?,?],f32>
func @torch.aten.softmax.int$dyn_shape(%t: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> {
func.func @torch.aten.softmax.int$dyn_shape(%t: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> {
%none = torch.constant.none
%dim = torch.constant.int 1
%ret = torch.aten.softmax.int %t, %dim, %none : !torch.tensor<[?,?],f32>, !torch.int, !torch.none -> !torch.tensor<[?,?],f32>
@ -107,7 +107,7 @@ func @torch.aten.softmax.int$dyn_shape(%t: !torch.tensor<[?,?],f32>) -> !torch.t
}
// -----
// CHECK-LABEL: func @torch.aten.softmax.int$unknown_shape(
// CHECK-LABEL: func.func @torch.aten.softmax.int$unknown_shape(
// CHECK-SAME: %[[T:.*]]: !torch.tensor<*,f32>) -> !torch.tensor<*,f32> {
// CHECK: %[[DTYPE:.*]] = torch.constant.none
// CHECK: %[[DIM:.*]] = torch.constant.int 1
@ -126,7 +126,7 @@ func @torch.aten.softmax.int$dyn_shape(%t: !torch.tensor<[?,?],f32>) -> !torch.t
// CHECK: %[[SOFTMAX:.*]] = torch.aten.div.Tensor %[[EXP]], %[[SUM]] : !torch.tensor<*,f32>, !torch.tensor<*,f32> -> !torch.tensor<*,f32>
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<*,f32> to !torch.tensor<*,f32>
// CHECK: return %[[RET]] : !torch.tensor<*,f32>
func @torch.aten.softmax.int$unknown_shape(%t: !torch.tensor<*,f32>) -> !torch.tensor<*,f32> {
func.func @torch.aten.softmax.int$unknown_shape(%t: !torch.tensor<*,f32>) -> !torch.tensor<*,f32> {
%none = torch.constant.none
%dim = torch.constant.int 1
%ret = torch.aten.softmax.int %t, %dim, %none : !torch.tensor<*,f32>, !torch.int, !torch.none -> !torch.tensor<*,f32>
@ -134,7 +134,7 @@ func @torch.aten.softmax.int$unknown_shape(%t: !torch.tensor<*,f32>) -> !torch.t
}
// -----
// CHECK-LABEL: func @torch.aten.size(
// CHECK-LABEL: func.func @torch.aten.size(
// CHECK-SAME: %[[T:.*]]: !torch.vtensor<[?,3],f32>) -> !torch.list<int> {
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[T]], %[[CST0]] : !torch.vtensor<[?,3],f32>, !torch.int -> !torch.int
@ -142,13 +142,13 @@ func @torch.aten.softmax.int$unknown_shape(%t: !torch.tensor<*,f32>) -> !torch.t
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[T]], %[[CST1]] : !torch.vtensor<[?,3],f32>, !torch.int -> !torch.int
// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: return %[[SIZE]] : !torch.list<int>
func @torch.aten.size(%arg0: !torch.vtensor<[?,3],f32>) -> !torch.list<int> {
func.func @torch.aten.size(%arg0: !torch.vtensor<[?,3],f32>) -> !torch.list<int> {
%0 = torch.aten.size %arg0 : !torch.vtensor<[?,3],f32> -> !torch.list<int>
return %0 : !torch.list<int>
}
// -----
// CHECK-LABEL: func @torch.aten.arange() -> !torch.vtensor<[?],si64> {
// CHECK-LABEL: func.func @torch.aten.arange() -> !torch.vtensor<[?],si64> {
// CHECK: %[[CST5:.*]] = torch.constant.int 5
// CHECK: %[[CSTN:.*]] = torch.constant.none
// CHECK: %[[CST0:.*]] = torch.constant.int 0
@ -156,7 +156,7 @@ func @torch.aten.size(%arg0: !torch.vtensor<[?,3],f32>) -> !torch.list<int> {
// CHECK: %[[RESULT:.*]] = torch.aten.arange.start_step %[[CST0]], %[[CST5]], %[[CST1]], %[[CSTN]], %[[CSTN]], %[[CSTN]], %[[CSTN]] :
// CHECK-SAME: !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],si64>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?],si64>
func @torch.aten.arange() -> !torch.vtensor<[?],si64> {
func.func @torch.aten.arange() -> !torch.vtensor<[?],si64> {
%int5 = torch.constant.int 5
%none = torch.constant.none
%0 = torch.aten.arange %int5, %none, %none, %none, %none : !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],si64>
@ -164,7 +164,7 @@ func @torch.aten.arange() -> !torch.vtensor<[?],si64> {
}
// -----
// CHECK-LABEL: func @torch.aten.arange.start() -> !torch.vtensor<[?],si64> {
// CHECK-LABEL: func.func @torch.aten.arange.start() -> !torch.vtensor<[?],si64> {
// CHECK: %[[CST10:.*]] = torch.constant.int 10
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[CSTN:.*]] = torch.constant.none
@ -172,7 +172,7 @@ func @torch.aten.arange() -> !torch.vtensor<[?],si64> {
// CHECK: %[[RESULT:.*]] = torch.aten.arange.start_step %[[CST0]], %[[CST10]], %[[CST1]], %[[CSTN]], %[[CSTN]], %[[CSTN]], %[[CSTN]] :
// CHECK-SAME: !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],si64>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?],si64>
func @torch.aten.arange.start() -> !torch.vtensor<[?],si64> {
func.func @torch.aten.arange.start() -> !torch.vtensor<[?],si64> {
%int10 = torch.constant.int 10
%int0 = torch.constant.int 0
%none = torch.constant.none
@ -181,14 +181,14 @@ func @torch.aten.arange.start() -> !torch.vtensor<[?],si64> {
}
// -----
// CHECK-LABEL: func @torch.aten.argmax(
// CHECK-LABEL: func.func @torch.aten.argmax(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[1,?],si64> {
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[INP]], %[[CST0]], %[[TRUE]] :
// CHECK-SAME: !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,?],f32>, !torch.vtensor<[1,?],si64>
// CHECK: return %[[IND]] : !torch.vtensor<[1,?],si64>
func @torch.aten.argmax(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[1,?],si64> {
func.func @torch.aten.argmax(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[1,?],si64> {
%int0 = torch.constant.int 0
%true = torch.constant.bool true
%0 = torch.aten.argmax %arg0, %int0, %true : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,?],si64>
@ -196,7 +196,7 @@ func @torch.aten.argmax(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[1,?
}
// -----
// CHECK-LABEL: func @torch.aten.argmax$reduceall(
// CHECK-LABEL: func.func @torch.aten.argmax$reduceall(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],si64> {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
@ -207,7 +207,7 @@ func @torch.aten.argmax(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[1,?
// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[FLATTEN]], %[[CST0]], %[[FALSE]] :
// CHECK-SAME: !torch.vtensor<[?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[],f32>, !torch.vtensor<[],si64>
// CHECK: return %[[IND]] : !torch.vtensor<[],si64>
func @torch.aten.argmax$reduceall(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],si64> {
func.func @torch.aten.argmax$reduceall(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],si64> {
%none = torch.constant.none
%false = torch.constant.bool false
%0 = torch.aten.argmax %arg0, %none, %false : !torch.vtensor<[?,?],f32>, !torch.none, !torch.bool -> !torch.vtensor<[],si64>
@ -215,18 +215,18 @@ func @torch.aten.argmax$reduceall(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vt
}
// -----
// CHECK-LABEL: func @torch.aten.square(
// CHECK-LABEL: func.func @torch.aten.square(
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[SQUARE:.*]] = torch.aten.mul.Tensor %[[INPUT]], %[[INPUT]] :
// CHECK-SAME: !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[SQUARE]] : !torch.vtensor<[?,?,?],f32>
func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
func.func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
%0 = torch.aten.square %arg0 : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.var$unbiased(
// CHECK-LABEL: func.func @torch.aten.var$unbiased(
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
// CHECK: %[[UNBIASED:.*]] = torch.constant.bool true
// CHECK: %[[DTYPE:.*]] = torch.constant.none
@ -242,14 +242,14 @@ func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?
// CHECK: %[[NUM_ELEMENTS_SUB1:.*]] = torch.aten.sub.int %[[SUB_MEAN_SQUARE_NUM_ELEMENTS]], %[[CST1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_SUB1]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: return %[[UNBIASED_VAR]] : !torch.vtensor<[],f32>
func @torch.aten.var$unbiased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
func.func @torch.aten.var$unbiased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
%true = torch.constant.bool true
%0 = torch.aten.var %arg0, %true: !torch.vtensor<[?,?,?],f32>, !torch.bool -> !torch.vtensor<[],f32>
return %0 : !torch.vtensor<[],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.var$biased(
// CHECK-LABEL: func.func @torch.aten.var$biased(
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
// CHECK: %[[UNBIASED:.*]] = torch.constant.bool false
// CHECK: %[[DTYPE:.*]] = torch.constant.none
@ -263,14 +263,14 @@ func @torch.aten.var$unbiased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vten
// CHECK: %[[SUB_MEAN_SQUARE_NUM_ELEMENTS:.*]] = torch.aten.numel %[[SUB_MEAN_SQUARE]] : !torch.vtensor<[?,?,?],f32> -> !torch.int
// CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[SUB_MEAN_SQUARE_NUM_ELEMENTS]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: return %[[BIASED_VAR]] : !torch.vtensor<[],f32>
func @torch.aten.var$biased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
func.func @torch.aten.var$biased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
%false = torch.constant.bool false
%0 = torch.aten.var %arg0, %false: !torch.vtensor<[?,?,?],f32>, !torch.bool -> !torch.vtensor<[],f32>
return %0 : !torch.vtensor<[],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.std$unbiased(
// CHECK-LABEL: func.func @torch.aten.std$unbiased(
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
// CHECK: %[[UNBIASED:.*]] = torch.constant.bool true
// CHECK: %[[DTYPE:.*]] = torch.constant.none
@ -287,14 +287,14 @@ func @torch.aten.var$biased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtenso
// CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_SUB1]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[UNBIASED_STD:.*]] = torch.aten.sqrt %[[UNBIASED_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
// CHECK: return %[[UNBIASED_STD]] : !torch.vtensor<[],f32>
func @torch.aten.std$unbiased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
func.func @torch.aten.std$unbiased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
%true = torch.constant.bool true
%0 = torch.aten.std %arg0, %true: !torch.vtensor<[?,?,?],f32>, !torch.bool -> !torch.vtensor<[],f32>
return %0 : !torch.vtensor<[],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.std$biased(
// CHECK-LABEL: func.func @torch.aten.std$biased(
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
// CHECK: %[[UNBIASED:.*]] = torch.constant.bool false
// CHECK: %[[DTYPE:.*]] = torch.constant.none
@ -309,20 +309,20 @@ func @torch.aten.std$unbiased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vten
// CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[SUB_MEAN_SQUARE_NUM_ELEMENTS]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[BIASED_STD:.*]] = torch.aten.sqrt %[[BIASED_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
// CHECK: return %[[BIASED_STD]] : !torch.vtensor<[],f32>
func @torch.aten.std$biased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
func.func @torch.aten.std$biased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
%false = torch.constant.bool false
%0 = torch.aten.std %arg0, %false: !torch.vtensor<[?,?,?],f32>, !torch.bool -> !torch.vtensor<[],f32>
return %0 : !torch.vtensor<[],f32>
}
// -----
// CHECK-LABEL: func @torch.aten._unsafe_view$static
// CHECK-LABEL: func.func @torch.aten._unsafe_view$static
// CHECK-SAME: (%[[ARG0:.*]]: !torch.vtensor<[1,512,32],f32>)
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct
// CHECK-NOT: torch.aten._unsafe_view
// CHECK-NEXT: %[[RES:.*]] = torch.aten.view %[[ARG0]], %[[LIST]]
// CHECK-NEXT: return
func @torch.aten._unsafe_view$static(%arg0: !torch.vtensor<[1,512,32],f32>) -> !torch.vtensor<[1,2,256,32],f32> {
func.func @torch.aten._unsafe_view$static(%arg0: !torch.vtensor<[1,512,32],f32>) -> !torch.vtensor<[1,2,256,32],f32> {
%c1 = torch.constant.int 1
%c2 = torch.constant.int 2
%c256 = torch.constant.int 256
@ -333,14 +333,14 @@ func @torch.aten._unsafe_view$static(%arg0: !torch.vtensor<[1,512,32],f32>) -> !
}
// -----
// CHECK-LABEL: func @torch.aten._reshape_alias$static
// CHECK-LABEL: func.func @torch.aten._reshape_alias$static
// CHECK-SAME: (%[[ARG0:.*]]: !torch.vtensor<[1],f32>)
// CHECK: %[[LIST1:.*]] = torch.prim.ListConstruct
// CHECK: %[[LIST2:.*]] = torch.prim.ListConstruct
// CHECK-NOT: torch.aten._reshape_alias
// CHECK-NEXT: %[[RES:.*]] = torch.aten.view %[[ARG0]], %[[LIST1]]
// CHECK-NEXT: return
func @torch.aten._reshape_alias$static(%arg0: !torch.vtensor<[1],f32>) -> !torch.vtensor<[12,32],f32> {
func.func @torch.aten._reshape_alias$static(%arg0: !torch.vtensor<[1],f32>) -> !torch.vtensor<[12,32],f32> {
%int1 = torch.constant.int 1
%int32 = torch.constant.int 32
%int12 = torch.constant.int 12
@ -351,13 +351,13 @@ func @torch.aten._reshape_alias$static(%arg0: !torch.vtensor<[1],f32>) -> !torch
}
// -----
// CHECK-LABEL: func @torch.aten._unsafe_view$dynamic
// CHECK-LABEL: func.func @torch.aten._unsafe_view$dynamic
// CHECK-SAME: (%[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>)
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct
// CHECK-NOT: torch.aten._unsafe_view
// CHECK-NEXT: %[[RES:.*]] = torch.aten.view %[[ARG0]], %[[LIST]]
// CHECK-NEXT: return
func @torch.aten._unsafe_view$dynamic(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[512,32],f32> {
func.func @torch.aten._unsafe_view$dynamic(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[512,32],f32> {
%c256 = torch.constant.int 512
%c32 = torch.constant.int 32
%0 = torch.prim.ListConstruct %c256, %c32 : (!torch.int, !torch.int) -> !torch.list<int>
@ -366,7 +366,7 @@ func @torch.aten._unsafe_view$dynamic(%arg0: !torch.vtensor<[?,?,?],f32>) -> !to
}
// -----
// CHECK-LABEL: func @torch.aten._log_softmax(
// CHECK-LABEL: func.func @torch.aten._log_softmax(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
@ -386,7 +386,7 @@ func @torch.aten._unsafe_view$dynamic(%arg0: !torch.vtensor<[?,?,?],f32>) -> !to
// CHECK: %[[SUB1:.*]] = torch.aten.sub.Tensor %[[SUB]], %[[LOG]], %[[FLOAT_1]] : !torch.vtensor<[?,?,?],f32>,
// CHECK-SAME: !torch.vtensor<[1,?,?],f32>, !torch.float -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[SUB1]] : !torch.vtensor<[?,?,?],f32>
func @torch.aten._log_softmax(%arg0: !torch.vtensor<[?,?,?],f32> loc(unknown)) -> !torch.vtensor<[?,?,?],f32> {
func.func @torch.aten._log_softmax(%arg0: !torch.vtensor<[?,?,?],f32> loc(unknown)) -> !torch.vtensor<[?,?,?],f32> {
%int0 = torch.constant.int 0
%false = torch.constant.bool false
%0 = torch.aten._log_softmax %arg0, %int0, %false : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?,?,?],f32>
@ -394,7 +394,7 @@ func @torch.aten._log_softmax(%arg0: !torch.vtensor<[?,?,?],f32> loc(unknown)) -
}
// -----
// CHECK-LABEL: func @torch.aten.bernoulli
// CHECK-LABEL: func.func @torch.aten.bernoulli
// CHECK-SAME: (%[[INP:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT7:.*]] = torch.constant.int 7
@ -424,7 +424,7 @@ func @torch.aten._log_softmax(%arg0: !torch.vtensor<[?,?,?],f32> loc(unknown)) -
// CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func @torch.aten.bernoulli(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
func.func @torch.aten.bernoulli(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
%none = torch.constant.none
%0 = torch.aten.bernoulli %arg0, %none : !torch.vtensor<[?,?,?],f64>, !torch.none -> !torch.vtensor<[?,?,?],f64>
%1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
@ -432,7 +432,7 @@ func @torch.aten.bernoulli(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor
}
// -----
// CHECK-LABEL: func @torch.valsem.aten.bernoulli.float
// CHECK-LABEL: func.func @torch.valsem.aten.bernoulli.float
// CHECK-SAME: (%[[INP:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[PROB:.*]] = torch.constant.float 4.000000e-01
@ -463,7 +463,7 @@ func @torch.aten.bernoulli(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor
// CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func @torch.valsem.aten.bernoulli.float(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
func.func @torch.valsem.aten.bernoulli.float(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
%none = torch.constant.none
%prob = torch.constant.float 4.000000e-01
%0 = torch.valsem.aten.bernoulli.float %arg0, %prob, %none : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
@ -472,7 +472,7 @@ func @torch.valsem.aten.bernoulli.float(%arg0: !torch.vtensor<[?,?,?],f64>) -> !
}
// -----
// CHECK-LABEL: func @torch.valsem.aten.bernoulli.Tensor(
// CHECK-LABEL: func.func @torch.valsem.aten.bernoulli.Tensor(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?,?],f64>,
// CHECK-SAME: %[[PROB:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
@ -502,7 +502,7 @@ func @torch.valsem.aten.bernoulli.float(%arg0: !torch.vtensor<[?,?,?],f64>) -> !
// CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func @torch.valsem.aten.bernoulli.Tensor(%arg0: !torch.vtensor<[?,?,?],f64>, %arg1: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
func.func @torch.valsem.aten.bernoulli.Tensor(%arg0: !torch.vtensor<[?,?,?],f64>, %arg1: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
%none = torch.constant.none
%0 = torch.valsem.aten.bernoulli.Tensor %arg0, %arg1, %none : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64>, !torch.none -> !torch.vtensor<[?,?,?],f64>
%1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
@ -510,7 +510,7 @@ func @torch.valsem.aten.bernoulli.Tensor(%arg0: !torch.vtensor<[?,?,?],f64>, %ar
}
// -----
// CHECK-LABEL: func @torch.aten.rand_like(
// CHECK-LABEL: func.func @torch.aten.rand_like(
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
// CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[NONE_0:.*]] = torch.constant.none
@ -528,7 +528,7 @@ func @torch.valsem.aten.bernoulli.Tensor(%arg0: !torch.vtensor<[?,?,?],f64>, %ar
// CHECK: %[[UNIFORM:.*]] = torch.valsem.aten.uniform %[[EMPTY]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_1]] : !torch.vtensor<[?,?,?],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[UNIFORM]] : !torch.vtensor<[?,?,?],f32> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func @torch.aten.rand_like(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
func.func @torch.aten.rand_like(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
%int6 = torch.constant.int 6
%none = torch.constant.none
%0 = torch.aten.rand_like %arg0, %int6, %none, %none, %none, %none : !torch.vtensor<[?,?,?],f64>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?,?],f32>
@ -537,7 +537,7 @@ func @torch.aten.rand_like(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor
}
// -----
// CHECK-LABEL: func @torch.aten.select.int(
// CHECK-LABEL: func.func @torch.aten.select.int(
// CHECK-SAME: %[[T:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?],si64> {
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[CST1:.*]] = torch.constant.int 1
@ -547,14 +547,14 @@ func @torch.aten.rand_like(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor
// CHECK: %[[SELECT:.*]] = torch.aten.squeeze.dim %[[SLICE]], %[[CST0]] :
// CHECK-SAME: !torch.vtensor<[1,?],si64>, !torch.int -> !torch.vtensor<[?],si64>
// CHECK: return %[[SELECT]] : !torch.vtensor<[?],si64>
func @torch.aten.select.int(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?],si64> {
func.func @torch.aten.select.int(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?],si64> {
%int0 = torch.constant.int 0
%0 = torch.aten.select.int %arg0, %int0, %int0 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[?],si64>
return %0 : !torch.vtensor<[?],si64>
}
// -----
// CHECK-LABEL: func @torch.aten.hardsigmoid(
// CHECK-LABEL: func.func @torch.aten.hardsigmoid(
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[CST2:.*]] = torch.constant.int 3
@ -576,13 +576,13 @@ func @torch.aten.select.int(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor
// CHECK: %[[RET:.*]] = torch.aten.maximum %[[CST0_TENSOR]], %[[MIN]] : !torch.vtensor<[],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RET]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func @torch.aten.hardsigmoid(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
func.func @torch.aten.hardsigmoid(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.hardsigmoid %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.hardswish(
// CHECK-LABEL: func.func @torch.aten.hardswish(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[INT3:.*]] = torch.constant.int 3
@ -599,13 +599,13 @@ func @torch.aten.hardsigmoid(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor
// CHECK: %[[DIV:.*]] = torch.aten.div.Scalar %[[MIN]], %[[INT6]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: %[[MUL:.*]] = torch.aten.mul.Tensor %[[DIV]], %[[INP]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[MUL]] : !torch.vtensor<[?,?],f32>
func @torch.aten.hardswish(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
func.func @torch.aten.hardswish(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.hardswish %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.hardtanh(
// CHECK-LABEL: func.func @torch.aten.hardtanh(
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?],f32>,
// CHECK-SAME: %[[MIN_VAL:.*]]: !torch.float,
// CHECK-SAME: %[[MAX_VAL:.*]]: !torch.float) -> !torch.vtensor<[?],f32> {
@ -622,13 +622,13 @@ func @torch.aten.hardswish(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
// CHECK: %[[MAX_TENSOR:.*]] = torch.valsem.aten.fill.Scalar %[[VAL_10]], %[[MAX_VAL]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK: %[[RET:.*]] = torch.aten.minimum %[[MAX_TENSOR]], %[[MIN]] : !torch.vtensor<[],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?],f32>
// CHECK: return %[[RET]] : !torch.vtensor<[?],f32>
func @torch.aten.hardtanh(%arg0: !torch.vtensor<[?],f32>, %min: !torch.float, %max: !torch.float) -> !torch.vtensor<[?],f32> {
func.func @torch.aten.hardtanh(%arg0: !torch.vtensor<[?],f32>, %min: !torch.float, %max: !torch.float) -> !torch.vtensor<[?],f32> {
%0 = torch.aten.hardtanh %arg0, %min, %max : !torch.vtensor<[?],f32>, !torch.float, !torch.float -> !torch.vtensor<[?],f32>
return %0 : !torch.vtensor<[?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.new_zeros
// CHECK-LABEL: func.func @torch.aten.new_zeros
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[2,3],f32> {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT2:.*]] = torch.constant.int 2
@ -638,7 +638,7 @@ func @torch.aten.hardtanh(%arg0: !torch.vtensor<[?],f32>, %min: !torch.float, %m
// CHECK: %[[RES:.*]] = torch.aten.zeros %[[SIZE]], %[[INT6]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[2,3],f32>
// CHECK: }
func @torch.aten.new_zeros(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[2,3],f32> {
func.func @torch.aten.new_zeros(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[2,3],f32> {
%none = torch.constant.none
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
@ -648,7 +648,7 @@ func @torch.aten.new_zeros(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
}
// -----
// CHECK-LABEL: func @torch.aten.new_ones
// CHECK-LABEL: func.func @torch.aten.new_ones
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[3,4],si64> {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT3:.*]] = torch.constant.int 3
@ -658,7 +658,7 @@ func @torch.aten.new_zeros(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
// CHECK: %[[RES:.*]] = torch.aten.ones %[[SIZE]], %[[INT4_0]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4],si64>
// CHECK: return %[[RES]] : !torch.vtensor<[3,4],si64>
// CHECK: }
func @torch.aten.new_ones(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[3,4],si64> {
func.func @torch.aten.new_ones(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[3,4],si64> {
%none = torch.constant.none
%int3 = torch.constant.int 3
%int4 = torch.constant.int 4
@ -668,18 +668,18 @@ func @torch.aten.new_ones(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[
}
// -----
// CHECK-LABEL: func @torch.aten.silu(
// CHECK-LABEL: func.func @torch.aten.silu(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor {
// CHECK: %[[SIGMOID:.*]] = torch.aten.sigmoid %[[INP]] : !torch.vtensor<[?,?],f32> -> !torch.vtensor
// CHECK: %[[MUL:.*]] = torch.aten.mul.Tensor %[[SIGMOID]], %[[INP]] : !torch.vtensor, !torch.vtensor<[?,?],f32> -> !torch.vtensor
// CHECK: return %[[MUL]] : !torch.vtensor
func @torch.aten.silu(%arg0: !torch.vtensor<[?,?],f32> loc(unknown)) -> !torch.vtensor {
func.func @torch.aten.silu(%arg0: !torch.vtensor<[?,?],f32> loc(unknown)) -> !torch.vtensor {
%0 = torch.aten.silu %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor
return %0 : !torch.vtensor
}
// -----
// CHECK-LABEL: func @torch.aten.full
// CHECK-LABEL: func.func @torch.aten.full
// CHECK-SAME: () -> !torch.vtensor<[2,3],f32> {
// CHECK: %[[FLOAT5:.*]] = torch.constant.float 5.000000e+00
// CHECK: %[[INT3:.*]] = torch.constant.int 3
@ -690,7 +690,7 @@ func @torch.aten.silu(%arg0: !torch.vtensor<[?,?],f32> loc(unknown)) -> !torch.v
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[MEM_FORMAT]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
// CHECK: %[[RES:.*]] = torch.valsem.aten.fill.Scalar %[[EMPTY]], %[[FLOAT5]] : !torch.vtensor<[2,3],f32>, !torch.float -> !torch.vtensor<[2,3],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[2,3],f32>
func @torch.aten.full() -> !torch.vtensor<[2,3],f32> {
func.func @torch.aten.full() -> !torch.vtensor<[2,3],f32> {
%float5.000000e00 = torch.constant.float 5.000000e+00
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
@ -701,7 +701,7 @@ func @torch.aten.full() -> !torch.vtensor<[2,3],f32> {
}
// -----
// CHECK-LABEL: func @torch.aten.full_like(
// CHECK-LABEL: func.func @torch.aten.full_like(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[INT5:.*]] = torch.constant.int 5
// CHECK: %[[NONE:.*]] = torch.constant.none
@ -713,7 +713,7 @@ func @torch.aten.full() -> !torch.vtensor<[2,3],f32> {
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?],f32>
// CHECK: %[[RES:.*]] = torch.valsem.aten.fill.Scalar %[[EMPTY]], %[[INT5]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[?,?],f32>
func @torch.aten.full_like(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
func.func @torch.aten.full_like(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%int5 = torch.constant.int 5
%none = torch.constant.none
%0 = torch.aten.full_like %arg0, %int5, %none, %none, %none, %none, %none : !torch.vtensor<[?,?],f32>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?],f32>
@ -721,7 +721,7 @@ func @torch.aten.full_like(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
}
// -----
// CHECK-LABEL: func @torch.aten.index_put(
// CHECK-LABEL: func.func @torch.aten.index_put(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?],f32>, %[[INDEX:.*]]: !torch.vtensor<[?],si64>,
// CHECK-SAME: %[[VALUES:.*]]: !torch.vtensor<[?],f32>,
// CHECK-SAME: %[[ACCUM:.*]]: !torch.bool) -> !torch.vtensor<[?],f32> {
@ -729,14 +729,14 @@ func @torch.aten.full_like(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[RES:.*]] = torch.valsem.aten.index_put_impl %[[INP]], %[[INDICES]], %[[VALUES]], %[[ACCUM]], %[[FALSE]] : !torch.vtensor<[?],f32>, !torch.list<vtensor>, !torch.vtensor<[?],f32>, !torch.bool, !torch.bool -> !torch.vtensor<[?],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[?],f32>
func @torch.aten.index_put(%input: !torch.vtensor<[?],f32>, %index: !torch.vtensor<[?],si64>, %values: !torch.vtensor<[?],f32>, %accumulate : !torch.bool) -> !torch.vtensor<[?],f32> {
func.func @torch.aten.index_put(%input: !torch.vtensor<[?],f32>, %index: !torch.vtensor<[?],si64>, %values: !torch.vtensor<[?],f32>, %accumulate : !torch.bool) -> !torch.vtensor<[?],f32> {
%indices = torch.prim.ListConstruct %index : (!torch.vtensor<[?],si64>) -> !torch.list<vtensor>
%0 = torch.aten.index_put %input, %indices, %values, %accumulate : !torch.vtensor<[?],f32>, !torch.list<vtensor>, !torch.vtensor<[?],f32>, !torch.bool -> !torch.vtensor<[?],f32>
return %0 : !torch.vtensor<[?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.expand_as(
// CHECK-LABEL: func.func @torch.aten.expand_as(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,1,1],f32>, %[[OTHER:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[OTHER]], %[[INT0]] : !torch.vtensor<[?,?,?],f32>, !torch.int -> !torch.int
@ -747,13 +747,13 @@ func @torch.aten.index_put(%input: !torch.vtensor<[?],f32>, %index: !torch.vtens
// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]], %[[DIM2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[RES:.*]] = torch.aten.broadcast_to %[[INP]], %[[SIZE]] : !torch.vtensor<[?,1,1],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[?,?,?],f32>
func @torch.aten.expand_as(%arg0: !torch.vtensor<[?,1,1],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
func.func @torch.aten.expand_as(%arg0: !torch.vtensor<[?,1,1],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
%0 = torch.aten.expand_as %arg0, %arg1 : !torch.vtensor<[?,1,1],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten._to_copy(
// CHECK-LABEL: func.func @torch.aten._to_copy(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
@ -767,7 +767,7 @@ func @torch.aten.expand_as(%arg0: !torch.vtensor<[?,1,1],f32>, %arg1: !torch.vte
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[RES:.*]] = torch.valsem.aten.copy %[[EMPTY]], %[[INP]], %[[FALSE]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32>, !torch.bool -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[?,?,?],f32>
func @torch.aten._to_copy(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
func.func @torch.aten._to_copy(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
%false = torch.constant.bool false
%none = torch.constant.none
%0 = torch.aten._to_copy %arg0, %none, %none, %none, %none, %false, %none : !torch.vtensor<[?,?,?],f32>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32>
@ -775,12 +775,12 @@ func @torch.aten._to_copy(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<
}
// -----
// CHECK-LABEL: func @torch.aten.dropout$eval(
// CHECK-LABEL: func.func @torch.aten.dropout$eval(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[PROB:.*]] = torch.constant.float 1.000000e-01
// CHECK: %[[TRAIN:.*]] = torch.constant.bool false
// CHECK: return %[[INP:.*]] : !torch.vtensor<[?,?],f32>
func @torch.aten.dropout$eval(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
func.func @torch.aten.dropout$eval(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%float1.000000e-01 = torch.constant.float 1.000000e-01
%false = torch.constant.bool false
%0 = torch.aten.dropout %arg0, %float1.000000e-01, %false : !torch.vtensor<[?,?],f32>, !torch.float, !torch.bool -> !torch.vtensor<[?,?],f32>
@ -788,8 +788,8 @@ func @torch.aten.dropout$eval(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtenso
}
// -----
// CHECK-LABEL: func @torch.aten.dropout$train(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK-LABEL: func.func @torch.aten.dropout$train(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[PROB:.*]] = torch.constant.float 3.000000e-01
// CHECK: %[[TRAIN:.*]] = torch.constant.bool true
// CHECK: %[[NONE:.*]] = torch.constant.none
@ -821,7 +821,7 @@ func @torch.aten.dropout$eval(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtenso
// CHECK: %[[MASK_INP:.*]] = torch.aten.mul.Tensor %[[BOOL_MASK]], %[[INP]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: %[[OUT:.*]] = torch.aten.div.Scalar %[[MASK_INP]], %[[ONEMINUSP]] : !torch.vtensor<[?,?],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[OUT]] : !torch.vtensor<[?,?],f32>
func @torch.aten.dropout$train(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
func.func @torch.aten.dropout$train(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%float3.000000e-01 = torch.constant.float 3.000000e-01
%true = torch.constant.bool true
%0 = torch.aten.dropout %arg0, %float3.000000e-01, %true : !torch.vtensor<[?,?],f32>, !torch.float, !torch.bool -> !torch.vtensor<[?,?],f32>
@ -829,18 +829,18 @@ func @torch.aten.dropout$train(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtens
}
// -----
// CHECK-LABEL: func @torch.valsem.aten.zero(
// CHECK-LABEL: func.func @torch.valsem.aten.zero(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[ZERO:.*]] = torch.constant.int 0
// CHECK: %[[OUT:.*]] = torch.valsem.aten.fill.Scalar %[[INP]], %[[ZERO]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[OUT]] : !torch.vtensor<[?,?],f32>
func @torch.valsem.aten.zero(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
func.func @torch.valsem.aten.zero(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.valsem.aten.zero %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.new_empty
// CHECK-LABEL: func.func @torch.aten.new_empty
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[2,3],f32> {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT2:.*]] = torch.constant.int 2
@ -850,7 +850,7 @@ func @torch.valsem.aten.zero(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor
// CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[RES:.*]] = torch.aten.empty.memory_format %[[SIZE]], %[[INT6]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE_0]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[2,3],f32>
func @torch.aten.new_empty(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[2,3],f32> {
func.func @torch.aten.new_empty(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[2,3],f32> {
%none = torch.constant.none
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
@ -860,7 +860,7 @@ func @torch.aten.new_empty(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
}
// -----
// CHECK-LABEL: func @torch.aten.where.Scalar(
// CHECK-LABEL: func.func @torch.aten.where.Scalar(
// CHECK-SAME: %[[COND:.*]]: !torch.vtensor<[?,?,?],i1>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[CST8:.*]] = torch.constant.float 8.000000e+00
// CHECK: %[[CST4:.*]] = torch.constant.float 4.000000e+00
@ -874,7 +874,7 @@ func @torch.aten.new_empty(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
// CHECK: %[[FILL_OTHER:.*]] = torch.valsem.aten.fill.Scalar %[[ALLOC2]], %[[CST8]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK: %[[OUT:.*]] = torch.aten.where.self %[[COND]], %[[FILL_SELF]], %[[FILL_OTHER]] : !torch.vtensor<[?,?,?],i1>, !torch.vtensor<[],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[OUT]] : !torch.vtensor<[?,?,?],f32>
func @torch.aten.where.Scalar(%arg0: !torch.vtensor<[?,?,?],i1>) -> !torch.vtensor<[?,?,?],f32> {
func.func @torch.aten.where.Scalar(%arg0: !torch.vtensor<[?,?,?],i1>) -> !torch.vtensor<[?,?,?],f32> {
%cst8 = torch.constant.float 8.000000e+00
%cst4 = torch.constant.float 4.000000e+00
%0 = torch.aten.where.Scalar %arg0, %cst4, %cst8 : !torch.vtensor<[?,?,?],i1>, !torch.float, !torch.float -> !torch.vtensor<[?,?,?],f32>
@ -882,7 +882,7 @@ func @torch.aten.where.Scalar(%arg0: !torch.vtensor<[?,?,?],i1>) -> !torch.vtens
}
// -----
// CHECK-LABEL: func @torch.aten.where.ScalarSelf(
// CHECK-LABEL: func.func @torch.aten.where.ScalarSelf(
// CHECK-SAME: %[[COND:.*]]: !torch.vtensor<[?,?,?],i1>, %[[OTHER:.*]]: !torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?,?],f64> {
// CHECK: %[[CST:.*]] = torch.constant.float 4.000000e+00
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
@ -891,14 +891,14 @@ func @torch.aten.where.Scalar(%arg0: !torch.vtensor<[?,?,?],i1>) -> !torch.vtens
// CHECK: %[[FILL:.*]] = torch.valsem.aten.fill.Scalar %[[ALLOC]], %[[CST]] : !torch.vtensor<[],f64>, !torch.float -> !torch.vtensor<[],f64>
// CHECK: %[[OUT:.*]] = torch.aten.where.self %[[COND]], %[[FILL]], %[[OTHER]] : !torch.vtensor<[?,?,?],i1>, !torch.vtensor<[],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?,?],f64>
// CHECK: return %[[OUT]] : !torch.vtensor<[?,?,?],f64>
func @torch.aten.where.ScalarSelf(%arg0: !torch.vtensor<[?,?,?],i1>, %arg1: !torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?,?],f64> {
func.func @torch.aten.where.ScalarSelf(%arg0: !torch.vtensor<[?,?,?],i1>, %arg1: !torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?,?],f64> {
%cst = torch.constant.float 4.000000e+00
%0 = torch.aten.where.ScalarSelf %arg0, %cst, %arg1 : !torch.vtensor<[?,?,?],i1>, !torch.float, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?,?],f64>
return %0 : !torch.vtensor<[?,?,?],f64>
}
// -----
// CHECK-LABEL: func @torch.aten.where.ScalarOther(
// CHECK-LABEL: func.func @torch.aten.where.ScalarOther(
// CHECK-SAME: %[[COND:.*]]: !torch.vtensor<[?,?,?],i1>, %[[SELF:.*]]: !torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?,?],f64> {
// CHECK: %[[CST:.*]] = torch.constant.float 4.000000e+00
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
@ -907,21 +907,21 @@ func @torch.aten.where.ScalarSelf(%arg0: !torch.vtensor<[?,?,?],i1>, %arg1: !tor
// CHECK: %[[FILL:.*]] = torch.valsem.aten.fill.Scalar %[[ALLOC]], %[[CST]] : !torch.vtensor<[],f64>, !torch.float -> !torch.vtensor<[],f64>
// CHECK: %[[OUT:.*]] = torch.aten.where.self %[[COND]], %[[SELF]], %[[FILL]] : !torch.vtensor<[?,?,?],i1>, !torch.vtensor<[?,?],f64>, !torch.vtensor<[],f64> -> !torch.vtensor<[?,?,?],f64>
// CHECK: return %[[OUT]] : !torch.vtensor<[?,?,?],f64>
func @torch.aten.where.ScalarOther(%arg0: !torch.vtensor<[?,?,?],i1>, %arg1: !torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?,?],f64> {
func.func @torch.aten.where.ScalarOther(%arg0: !torch.vtensor<[?,?,?],i1>, %arg1: !torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?,?],f64> {
%cst = torch.constant.float 4.000000e+00
%0 = torch.aten.where.ScalarOther %arg0, %arg1, %cst : !torch.vtensor<[?,?,?],i1>, !torch.vtensor<[?,?],f64>, !torch.float -> !torch.vtensor<[?,?,?],f64>
return %0 : !torch.vtensor<[?,?,?],f64>
}
// -----
// CHECK-LABEL: func @torch.aten.pad
// CHECK-LABEL: func.func @torch.aten.pad
// CHECK-SAME: (%[[SELF:.*]]: !torch.vtensor<[?,?,?],f64>, %[[VALUE:.*]]: !torch.float) -> !torch.vtensor<[?,?,?],f64> {
// CHECK-NOT: torch.aten.pad
// CHECK: %[[STRING:.*]] = torch.constant.str "constant"
// CHECK-NEXT: %[[LIST:.*]] = torch.prim.ListConstruct
// CHECK-NEXT: %[[PAD_ND:.*]] = torch.aten.constant_pad_nd %[[SELF]], %[[LIST]], %[[VALUE]]
// CHECK-NEXT: return %[[PAD_ND]]
func @torch.aten.pad(%arg0: !torch.vtensor<[?,?,?],f64>, %arg1: !torch.float) -> !torch.vtensor<[?,?,?],f64> {
func.func @torch.aten.pad(%arg0: !torch.vtensor<[?,?,?],f64>, %arg1: !torch.float) -> !torch.vtensor<[?,?,?],f64> {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
@ -933,7 +933,7 @@ func @torch.aten.pad(%arg0: !torch.vtensor<[?,?,?],f64>, %arg1: !torch.float) ->
}
// -----
// CHECK-LABEL: func @torch.aten.to.dtype_layout(
// CHECK-LABEL: func.func @torch.aten.to.dtype_layout(
// CHECK-SAME: %[[SELF:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f64> {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
@ -941,7 +941,7 @@ func @torch.aten.pad(%arg0: !torch.vtensor<[?,?,?],f64>, %arg1: !torch.float) ->
// CHECK: %[[CST7:.*]] = torch.constant.int 7
// CHECK: %[[OUT:.*]] = torch.aten.to.dtype %[[SELF]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f64>
// CHECK: return %[[OUT]] : !torch.vtensor<[?,?],f64>
func @torch.aten.to.dtype_layout(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f64> {
func.func @torch.aten.to.dtype_layout(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f64> {
%none = torch.constant.none
%false = torch.constant.bool false
%int0 = torch.constant.int 0

View File

@ -1,11 +1,11 @@
// RUN: torch-mlir-opt -torch-drop-shape-calculations -split-input-file %s | FileCheck %s
// CHECK-LABEL: func @basic(
// CHECK-LABEL: func.func @basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,?],unk>) -> !torch.vtensor {
// CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG]] : !torch.vtensor<[2,?],unk> -> !torch.vtensor<[2,?],unk>
// CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[TANH]] : !torch.vtensor<[2,?],unk> to !torch.vtensor
// CHECK: return %[[ERASED]] : !torch.vtensor
func @basic(%arg0: !torch.vtensor<[2,?],unk>) -> !torch.vtensor {
func.func @basic(%arg0: !torch.vtensor<[2,?],unk>) -> !torch.vtensor {
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%0 = torch.shape.calculate {

View File

@ -16,8 +16,8 @@ torch.global_slot "private" @mutated : !torch.tensor {
torch.global_slot.init %0 : !torch.tensor
}
// CHECK-LABEL: func @forward() -> (!torch.tensor, !torch.tensor, !torch.tensor) {
func @forward() -> (!torch.tensor, !torch.tensor, !torch.tensor) {
// CHECK-LABEL: func.func @forward() -> (!torch.tensor, !torch.tensor, !torch.tensor) {
func.func @forward() -> (!torch.tensor, !torch.tensor, !torch.tensor) {
// Inlined.
// CHECK: %[[READONLY:.*]] = torch.tensor.literal(dense<0.000000e+00> : tensor<1xf32>) : !torch.tensor
%0 = torch.global_slot.get @readonly : !torch.tensor

View File

@ -1,20 +1,20 @@
// RUN: torch-mlir-opt -split-input-file -allow-unregistered-dialect %s -torch-maximize-value-semantics | FileCheck %s
// CHECK-LABEL: func @torch.copy.tensor$basic(
// CHECK-LABEL: func.func @torch.copy.tensor$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
// CHECK: return %[[ARG0]], %[[ARG0]] : !torch.vtensor, !torch.vtensor
func @torch.copy.tensor$basic(%arg0: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
func.func @torch.copy.tensor$basic(%arg0: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
%0 = torch.copy.to_tensor %arg0 : !torch.tensor
%1 = torch.copy.to_vtensor %0 : !torch.vtensor
%2 = torch.copy.to_vtensor %0 : !torch.vtensor
return %1, %2 : !torch.vtensor, !torch.vtensor
}
// CHECK-LABEL: func @one_mutation_in_a_block(
// CHECK-LABEL: func.func @one_mutation_in_a_block(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
// CHECK: return %[[ARG0]], %[[ARG1]] : !torch.vtensor, !torch.vtensor
func @one_mutation_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
func.func @one_mutation_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
%0 = torch.copy.to_tensor %arg0 : !torch.tensor
%equal_to_arg0 = torch.copy.to_vtensor %0 : !torch.vtensor
torch.overwrite.tensor.contents %arg1 overwrites %0 : !torch.vtensor, !torch.tensor
@ -22,11 +22,11 @@ func @one_mutation_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> (
return %equal_to_arg0, %equal_to_arg1 : !torch.vtensor, !torch.vtensor
}
// CHECK-LABEL: func @multiple_mutations_in_a_block(
// CHECK-LABEL: func.func @multiple_mutations_in_a_block(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor, %[[ARG1:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor) {
// CHECK: return %[[ARG0]], %[[ARG1]], %[[ARG1]], %[[ARG2]] : !torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor
func @multiple_mutations_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor, %arg2: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor) {
func.func @multiple_mutations_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor, %arg2: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor) {
// The mutable tensor we are overwriting.
%tensor = torch.copy.to_tensor %arg0 : !torch.tensor
@ -45,12 +45,12 @@ func @multiple_mutations_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor
return %equal_to_arg0, %equal_to_arg1, %equal_to_arg1_again, %equal_to_arg2 : !torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor
}
// CHECK-LABEL: func @mutation_followed_by_view_like_ops(
// CHECK-LABEL: func.func @mutation_followed_by_view_like_ops(
// CHECK-SAME: %[[VALUE_T:.*]]: !torch.vtensor, %[[OVERWRITER:.*]]: !torch.vtensor, %[[INT_LIST:.*]]: !torch.list<int>) -> !torch.vtensor {
// CHECK: %[[VIEW:.*]] = torch.aten.view %[[OVERWRITER]], %[[INT_LIST]] : !torch.vtensor, !torch.list<int> -> !torch.vtensor
// CHECK: %[[RESULT:.*]] = torch.aten.permute %[[VIEW]], %[[INT_LIST]] : !torch.vtensor, !torch.list<int> -> !torch.vtensor
// CHECK: return %[[RESULT]] : !torch.vtensor
func @mutation_followed_by_view_like_ops(%value_t: !torch.vtensor, %overwriter: !torch.vtensor, %int_list: !torch.list<int>) -> !torch.vtensor {
func.func @mutation_followed_by_view_like_ops(%value_t: !torch.vtensor, %overwriter: !torch.vtensor, %int_list: !torch.list<int>) -> !torch.vtensor {
%t = torch.copy.to_tensor %value_t : !torch.tensor
torch.overwrite.tensor.contents %overwriter overwrites %t : !torch.vtensor, !torch.tensor
%view = torch.aten.view %t, %int_list : !torch.tensor, !torch.list<int> -> !torch.tensor
@ -59,10 +59,10 @@ func @mutation_followed_by_view_like_ops(%value_t: !torch.vtensor, %overwriter:
return %value_result : !torch.vtensor
}
// CHECK-LABEL: func @mutation_of_view_like_op_result(
// CHECK-LABEL: func.func @mutation_of_view_like_op_result(
// CHECK-SAME: %[[VALUE_T:.*]]: !torch.vtensor, %[[OVERWRITER:.*]]: !torch.vtensor, %[[INT_LIST:.*]]: !torch.list<int>) -> !torch.vtensor {
// CHECK: return %[[OVERWRITER]] : !torch.vtensor
func @mutation_of_view_like_op_result(%value_t: !torch.vtensor, %overwriter: !torch.vtensor, %int_list: !torch.list<int>) -> !torch.vtensor {
func.func @mutation_of_view_like_op_result(%value_t: !torch.vtensor, %overwriter: !torch.vtensor, %int_list: !torch.list<int>) -> !torch.vtensor {
%t = torch.copy.to_tensor %value_t : !torch.tensor
%view = torch.aten.view %t, %int_list : !torch.tensor, !torch.list<int> -> !torch.tensor
torch.overwrite.tensor.contents %overwriter overwrites %view : !torch.vtensor, !torch.tensor
@ -70,20 +70,20 @@ func @mutation_of_view_like_op_result(%value_t: !torch.vtensor, %overwriter: !to
return %result : !torch.vtensor
}
// CHECK-LABEL: func @value_tensor_used_after_copy_was_mutated(
// CHECK-LABEL: func.func @value_tensor_used_after_copy_was_mutated(
// CHECK-SAME: %[[VALUE_T:.*]]: !torch.vtensor,
// CHECK-SAME: %[[OVERWRITER:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
// CHECK: return %[[VALUE_T]], %[[OVERWRITER]] : !torch.vtensor, !torch.vtensor
func @value_tensor_used_after_copy_was_mutated(%value_t: !torch.vtensor, %overwriter: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
func.func @value_tensor_used_after_copy_was_mutated(%value_t: !torch.vtensor, %overwriter: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
%t = torch.copy.to_tensor %value_t : !torch.tensor
torch.overwrite.tensor.contents %overwriter overwrites %t : !torch.vtensor, !torch.tensor
%value_mutated_t = torch.copy.to_vtensor %t : !torch.vtensor
return %value_t, %value_mutated_t : !torch.vtensor, !torch.vtensor
}
// CHECK-LABEL: func @unmodeled_mutation(
// CHECK-LABEL: func.func @unmodeled_mutation(
// CHECK: torch.overwrite.tensor.contents
func @unmodeled_mutation(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
func.func @unmodeled_mutation(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
%0 = torch.copy.to_tensor %arg0 : !torch.tensor
torch.overwrite.tensor.contents %arg1 overwrites %0 : !torch.vtensor, !torch.tensor
"some.op"(%0) : (!torch.tensor) -> ()
@ -92,9 +92,9 @@ func @unmodeled_mutation(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch
}
// We don't yet handle nontrivial cases involving control flow.
// CHECK-LABEL: func @unimplemented_control_flow(
// CHECK-LABEL: func.func @unimplemented_control_flow(
// CHECK: torch.copy.to_vtensor
func @unimplemented_control_flow(%arg0: !torch.vtensor, %arg1: !torch.vtensor, %cond: !torch.bool) -> (!torch.vtensor, !torch.vtensor) {
func.func @unimplemented_control_flow(%arg0: !torch.vtensor, %arg1: !torch.vtensor, %cond: !torch.bool) -> (!torch.vtensor, !torch.vtensor) {
%tensor = torch.copy.to_tensor %arg0 : !torch.tensor
%equal_to_arg0 = torch.copy.to_vtensor %tensor : !torch.vtensor
torch.prim.If %cond -> () {
@ -107,33 +107,33 @@ func @unimplemented_control_flow(%arg0: !torch.vtensor, %arg1: !torch.vtensor, %
return %equal_to_arg0, %equal_to_arg1 : !torch.vtensor, !torch.vtensor
}
// CHECK-LABEL: func @non_value_tensor_returned(
// CHECK-LABEL: func.func @non_value_tensor_returned(
// CHECK-SAME: %[[VALUE_T:.*]]: !torch.vtensor) -> !torch.tensor {
// CHECK: %[[T:.*]] = torch.copy.to_tensor %[[VALUE_T]] : !torch.tensor
// CHECK: return %[[T]] : !torch.tensor
func @non_value_tensor_returned(%value_t: !torch.vtensor) -> !torch.tensor {
func.func @non_value_tensor_returned(%value_t: !torch.vtensor) -> !torch.tensor {
%t = torch.copy.to_tensor %value_t : !torch.tensor
return %t : !torch.tensor
}
// CHECK-LABEL: func @non_value_tensor_returned$with_overwrite(
// CHECK-LABEL: func.func @non_value_tensor_returned$with_overwrite(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %{{.*}}: !torch.vtensor) -> !torch.tensor {
// CHECK: %[[RESULT:.*]] = torch.copy.to_tensor %[[ARG0]] : !torch.tensor
// CHECK: return %[[RESULT]] : !torch.tensor
func @non_value_tensor_returned$with_overwrite(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.tensor {
func.func @non_value_tensor_returned$with_overwrite(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.tensor {
%2 = torch.copy.to_tensor %arg1 : !torch.tensor
torch.overwrite.tensor.contents %arg0 overwrites %2 : !torch.vtensor, !torch.tensor
return %2 : !torch.tensor
}
// CHECK-LABEL: func @non_value_tensor_returned$return_from_multiple_slices(
// CHECK-LABEL: func.func @non_value_tensor_returned$return_from_multiple_slices(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> (!torch.tensor, !torch.vtensor, !torch.tensor) {
// CHECK: %[[NON_VALUE_TENSOR0:.*]] = torch.copy.to_tensor %[[ARG0]] : !torch.tensor
// CHECK: %[[NON_VALUE_TENSOR1:.*]] = torch.copy.to_tensor %[[ARG1]] : !torch.tensor
// CHECK: return %[[NON_VALUE_TENSOR0]], %[[ARG0]], %[[NON_VALUE_TENSOR1]] : !torch.tensor, !torch.vtensor, !torch.tensor
func @non_value_tensor_returned$return_from_multiple_slices(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> (!torch.tensor, !torch.vtensor, !torch.tensor) {
func.func @non_value_tensor_returned$return_from_multiple_slices(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> (!torch.tensor, !torch.vtensor, !torch.tensor) {
%0 = torch.copy.to_tensor %arg0 : !torch.tensor
// Make a vtensor copy and return that, just to have a load-bearing use.
// This test mainly checks the rewriting of the non-value tensor returns
@ -143,12 +143,12 @@ func @non_value_tensor_returned$return_from_multiple_slices(%arg0: !torch.vtenso
return %0, %1, %2 : !torch.tensor, !torch.vtensor, !torch.tensor
}
// CHECK-LABEL: func @viewlike$basic_unsqueeze(
// CHECK-LABEL: func.func @viewlike$basic_unsqueeze(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[UNSQUEEZE:.*]] = torch.aten.unsqueeze %[[ARG]], %[[INT0]] : !torch.vtensor, !torch.int -> !torch.vtensor
// CHECK: return %[[UNSQUEEZE]] : !torch.vtensor
func @viewlike$basic_unsqueeze(%arg0: !torch.vtensor) -> !torch.vtensor {
func.func @viewlike$basic_unsqueeze(%arg0: !torch.vtensor) -> !torch.vtensor {
%int0 = torch.constant.int 0
%0 = torch.copy.to_tensor %arg0 : !torch.tensor
%1 = torch.aten.unsqueeze %0, %int0 : !torch.tensor, !torch.int -> !torch.tensor
@ -156,13 +156,13 @@ func @viewlike$basic_unsqueeze(%arg0: !torch.vtensor) -> !torch.vtensor {
return %2 : !torch.vtensor
}
// CHECK-LABEL: func @viewlike$basic_flatten(
// CHECK-LABEL: func.func @viewlike$basic_flatten(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INTM1:.*]] = torch.constant.int -1
// CHECK: %[[FLATTEN:.*]] = torch.aten.flatten.using_ints %[[ARG]], %[[INT0]], %[[INTM1]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor
// CHECK: return %[[FLATTEN]] : !torch.vtensor
func @viewlike$basic_flatten(%arg0: !torch.vtensor) -> !torch.vtensor {
func.func @viewlike$basic_flatten(%arg0: !torch.vtensor) -> !torch.vtensor {
%start = torch.constant.int 0
%end = torch.constant.int -1
%0 = torch.copy.to_tensor %arg0 : !torch.tensor
@ -171,13 +171,13 @@ func @viewlike$basic_flatten(%arg0: !torch.vtensor) -> !torch.vtensor {
return %2 : !torch.vtensor
}
// CHECK-LABEL: func @viewlike$transitive(
// CHECK-LABEL: func.func @viewlike$transitive(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[UNSQUEEZE0:.*]] = torch.aten.unsqueeze %[[ARG]], %[[INT0]] : !torch.vtensor, !torch.int -> !torch.vtensor
// CHECK: %[[UNSQUEEZE1:.*]] = torch.aten.unsqueeze %[[UNSQUEEZE0]], %[[INT0]] : !torch.vtensor, !torch.int -> !torch.vtensor
// CHECK: return %[[UNSQUEEZE1]] : !torch.vtensor
func @viewlike$transitive(%arg0: !torch.vtensor) -> !torch.vtensor {
func.func @viewlike$transitive(%arg0: !torch.vtensor) -> !torch.vtensor {
%int0 = torch.constant.int 0
%0 = torch.copy.to_tensor %arg0 : !torch.tensor
%1 = torch.aten.unsqueeze %0, %int0 : !torch.tensor, !torch.int -> !torch.tensor
@ -186,14 +186,14 @@ func @viewlike$transitive(%arg0: !torch.vtensor) -> !torch.vtensor {
return %3 : !torch.vtensor
}
// CHECK-LABEL: func @viewlike$transitive_tree(
// CHECK-LABEL: func.func @viewlike$transitive_tree(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[UNSQUEEZE0:.*]] = torch.aten.unsqueeze %[[ARG]], %[[INT0]] : !torch.vtensor, !torch.int -> !torch.vtensor
// CHECK: %[[RET0:.*]] = torch.aten.unsqueeze %[[UNSQUEEZE0]], %[[INT0]] : !torch.vtensor, !torch.int -> !torch.vtensor
// CHECK: %[[RET1:.*]] = torch.aten.unsqueeze %[[UNSQUEEZE0]], %[[INT0]] : !torch.vtensor, !torch.int -> !torch.vtensor
// CHECK: return %[[RET0]], %[[RET1]] : !torch.vtensor, !torch.vtensor
func @viewlike$transitive_tree(%arg0: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
func.func @viewlike$transitive_tree(%arg0: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
%int0 = torch.constant.int 0
%0 = torch.copy.to_tensor %arg0 : !torch.tensor
// %1 has two users.
@ -208,11 +208,11 @@ func @viewlike$transitive_tree(%arg0: !torch.vtensor) -> (!torch.vtensor, !torch
return %3, %5 : !torch.vtensor, !torch.vtensor
}
// CHECK-LABEL: func @viewlike$unmodeled_op(
// CHECK-LABEL: func.func @viewlike$unmodeled_op(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: %[[UNSQUEEZE:.*]] = torch.aten.unsqueeze {{.*}} : !torch.tensor, !torch.int -> !torch.tensor
// CHECK: "some.op"(%[[UNSQUEEZE]]) : (!torch.tensor) -> ()
func @viewlike$unmodeled_op(%arg0: !torch.vtensor) -> !torch.vtensor {
func.func @viewlike$unmodeled_op(%arg0: !torch.vtensor) -> !torch.vtensor {
%int0 = torch.constant.int 0
%0 = torch.copy.to_tensor %arg0 : !torch.tensor
%1 = torch.aten.unsqueeze %0, %int0 : !torch.tensor, !torch.int -> !torch.tensor
@ -221,23 +221,23 @@ func @viewlike$unmodeled_op(%arg0: !torch.vtensor) -> !torch.vtensor {
return %2 : !torch.vtensor
}
// CHECK-LABEL: func @viewlike$two_inputs_one_copy(
// CHECK-LABEL: func.func @viewlike$two_inputs_one_copy(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: %[[EXPAND_AS:.*]] = torch.aten.expand_as %[[ARG]], %[[ARG]] : !torch.vtensor, !torch.vtensor -> !torch.vtensor
// CHECK: return %[[EXPAND_AS]] : !torch.vtensor
func @viewlike$two_inputs_one_copy(%arg0: !torch.vtensor) -> !torch.vtensor {
func.func @viewlike$two_inputs_one_copy(%arg0: !torch.vtensor) -> !torch.vtensor {
%0 = torch.copy.to_tensor %arg0 : !torch.tensor
%1 = torch.aten.expand_as %0, %0 : !torch.tensor, !torch.tensor -> !torch.tensor
%2 = torch.copy.to_vtensor %1 : !torch.vtensor
return %2 : !torch.vtensor
}
// CHECK-LABEL: func @viewlike$two_inputs_two_copies(
// CHECK-LABEL: func.func @viewlike$two_inputs_two_copies(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: %[[EXPAND_AS:.*]] = torch.aten.expand_as %[[ARG0]], %[[ARG1]] : !torch.vtensor, !torch.vtensor -> !torch.vtensor
// CHECK: return %[[EXPAND_AS]] : !torch.vtensor
func @viewlike$two_inputs_two_copies(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
func.func @viewlike$two_inputs_two_copies(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
%0 = torch.copy.to_tensor %arg0 : !torch.tensor
%1 = torch.copy.to_tensor %arg1 : !torch.tensor
%2 = torch.aten.expand_as %0, %1 : !torch.tensor, !torch.tensor -> !torch.tensor

View File

@ -1,52 +1,52 @@
// RUN: torch-mlir-opt %s | torch-mlir-opt | FileCheck %s
// CHECK-LABEL: func @torch.operator(
func @torch.operator(%arg0: !torch.tensor, %arg1: !torch.tensor) -> !torch.tensor {
// CHECK-LABEL: func.func @torch.operator(
func.func @torch.operator(%arg0: !torch.tensor, %arg1: !torch.tensor) -> !torch.tensor {
// CHECK: torch.operator "ns.unqual.overload"(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tensor
%0 = torch.operator "ns.unqual.overload"(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tensor
return %0 : !torch.tensor
}
func @torch.linear_params.create(%arg0: !torch.tensor, %arg1: !torch.tensor) -> (!torch.LinearParams, !torch.LinearParams) {
func.func @torch.linear_params.create(%arg0: !torch.tensor, %arg1: !torch.tensor) -> (!torch.LinearParams, !torch.LinearParams) {
%with_bias = torch.linear_params.create %arg0, %arg1 : !torch.tensor, !torch.tensor
%without_bias = torch.linear_params.create %arg0 : !torch.tensor
return %with_bias, %without_bias : !torch.LinearParams, !torch.LinearParams
}
// CHECK: @tensor.default() -> !torch.tensor
func private @tensor.default() -> !torch.tensor
func.func private @tensor.default() -> !torch.tensor
// CHECK: @tensor.default_explicit() -> !torch.tensor{{$}}
func private @tensor.default_explicit() -> !torch.tensor<*,unk>
func.func private @tensor.default_explicit() -> !torch.tensor<*,unk>
// CHECK: @tensor.value_semantic() -> !torch.vtensor{{$}}
func private @tensor.value_semantic() -> !torch.vtensor<*,unk>
func.func private @tensor.value_semantic() -> !torch.vtensor<*,unk>
// CHECK: @tensor.dtype() -> !torch.tensor<*,si32>
func private @tensor.dtype() -> !torch.tensor<*,si32>
func.func private @tensor.dtype() -> !torch.tensor<*,si32>
// CHECK: @tensor.ranked() -> !torch.tensor<[?,?,?],unk>
func private @tensor.ranked() -> !torch.tensor<[?,?,?],unk>
func.func private @tensor.ranked() -> !torch.tensor<[?,?,?],unk>
// CHECK: @tensor.some_sizes_known() -> !torch.tensor<[?,2,?,4],unk>
func private @tensor.some_sizes_known() -> !torch.tensor<[?,2,?,4],unk>
func.func private @tensor.some_sizes_known() -> !torch.tensor<[?,2,?,4],unk>
// CHECK: @tensor.fully_determined() -> !torch.vtensor<[1,2,3,4],f32>
func private @tensor.fully_determined() -> !torch.vtensor<[1,2,3,4],f32>
func.func private @tensor.fully_determined() -> !torch.vtensor<[1,2,3,4],f32>
// CHECK: @tuple.empty() -> !torch.tuple<>
func private @tuple.empty() -> !torch.tuple<>
func.func private @tuple.empty() -> !torch.tuple<>
// CHECK: @tuple.one_element() -> !torch.tuple<tensor>
func private @tuple.one_element() -> !torch.tuple<tensor>
func.func private @tuple.one_element() -> !torch.tuple<tensor>
// CHECK: @tuple.two_elements() -> !torch.tuple<tensor, tensor>
func private @tuple.two_elements() -> !torch.tuple<tensor, tensor>
func.func private @tuple.two_elements() -> !torch.tuple<tensor, tensor>
// CHECK: @union.empty() -> !torch.union<>
func private @union.empty() -> !torch.union<>
func.func private @union.empty() -> !torch.union<>
// CHECK: @union.one_element() -> !torch.union<tensor>
func private @union.one_element() -> !torch.union<tensor>
func.func private @union.one_element() -> !torch.union<tensor>
// CHECK: @union.two_elements() -> !torch.union<tensor, tensor>
func private @union.two_elements() -> !torch.union<tensor, tensor>
func.func private @union.two_elements() -> !torch.union<tensor, tensor>
// CHECK: @dict() -> !torch.dict<str, tensor>
func private @dict() -> !torch.dict<str, tensor>
func.func private @dict() -> !torch.dict<str, tensor>
// CHECK-LABEL: func @torch.tensor.literal() {
func @torch.tensor.literal() {
// CHECK-LABEL: func.func @torch.tensor.literal() {
func.func @torch.tensor.literal() {
// CHECK: torch.tensor.literal(dense<4.200000e+01> : tensor<3x2xf32>) : !torch.tensor
%0 = torch.tensor.literal(dense<42.0> : tensor<3x2xf32>) : !torch.tensor
// CHECK: torch.tensor.literal(dense<4.200000e+01> : tensor<3x2xf32>) : !torch.tensor<[3,2],f32>
@ -54,19 +54,19 @@ func @torch.tensor.literal() {
return
}
// CHECK-LABEL: func @torch.vtensor.literal() {
func @torch.vtensor.literal() {
// CHECK-LABEL: func.func @torch.vtensor.literal() {
func.func @torch.vtensor.literal() {
// CHECK: torch.vtensor.literal(dense<4.200000e+01> : tensor<3x2xf32>) : !torch.vtensor<[3,2],f32>
%0 = torch.vtensor.literal(dense<42.0> : tensor<3x2xf32>) : !torch.vtensor<[3,2],f32>
return
}
func @derefine(%arg0: !torch.tensor) -> !torch.optional<tensor> {
func.func @derefine(%arg0: !torch.tensor) -> !torch.optional<tensor> {
%0 = torch.derefine %arg0 : !torch.tensor to !torch.optional<tensor>
return %0 : !torch.optional<tensor>
}
func @torch.prim.If(%arg0: !torch.bool, %arg1: !torch.int) -> !torch.int {
func.func @torch.prim.If(%arg0: !torch.bool, %arg1: !torch.int) -> !torch.int {
%0 = torch.prim.If %arg0 -> (!torch.int) {
%1 = torch.aten.add.int %arg1, %arg1 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %1 : !torch.int
@ -103,7 +103,7 @@ func @torch.prim.If(%arg0: !torch.bool, %arg1: !torch.int) -> !torch.int {
%none = torch.constant.none
// CHECK: %str = torch.constant.str "some str"
%str = torch.constant.str "some str"
func private @f(%arg0: !torch.nn.Module<"test">) {
func.func private @f(%arg0: !torch.nn.Module<"test">) {
return
}
@ -131,7 +131,7 @@ torch.nn_module {
} : !torch.nn.Module<"test">
func @shape_calculations(%arg0: !torch.vtensor) -> !torch.vtensor {
func.func @shape_calculations(%arg0: !torch.vtensor) -> !torch.vtensor {
%0 = torch.shape.calculate {
%0 = torch.aten.tanh %arg0 : !torch.vtensor -> !torch.vtensor
torch.shape.calculate.yield %0 : !torch.vtensor
@ -142,7 +142,7 @@ func @shape_calculations(%arg0: !torch.vtensor) -> !torch.vtensor {
return %0 : !torch.vtensor
}
func @number_type_subtypes(%arg0: !torch.tensor, %arg1: !torch.list<int>, %arg2: !torch.union<float, int>) {
func.func @number_type_subtypes(%arg0: !torch.tensor, %arg1: !torch.list<int>, %arg2: !torch.union<float, int>) {
%0 = torch.aten.constant_pad_nd %arg0, %arg1, %arg2 : !torch.tensor, !torch.list<int>, !torch.union<float, int> -> !torch.tensor
return
}

View File

@ -6,23 +6,23 @@ torch.class_type @c {
}
// CHECK-LABEL: func private @test_call_method(
// CHECK-LABEL: func.func private @test_call_method(
// CHECK-SAME: %[[RECEIVER:.*]]: !torch.nn.Module<"c">,
// CHECK-SAME: %[[F:.*]]: !torch.float) -> !torch.float {
// CHECK: %[[RET:.*]] = call @test_call_method(%[[RECEIVER]], %[[F]]) : (!torch.nn.Module<"c">, !torch.float) -> !torch.float
// CHECK: return %[[RET]] : !torch.float
func private @test_call_method(%arg0: !torch.nn.Module<"c">, %arg1: !torch.float) -> !torch.float {
func.func private @test_call_method(%arg0: !torch.nn.Module<"c">, %arg1: !torch.float) -> !torch.float {
%0 = torch.prim.CallMethod %arg0["test_call_method"] (%arg1) : !torch.nn.Module<"c">, (!torch.float) -> !torch.float
return %0 : !torch.float
}
// CHECK-LABEL: func private @test_call_indirect(
// CHECK-LABEL: func.func private @test_call_indirect(
// CHECK-SAME: %[[RECEIVER:.*]]: !torch.nn.Module<"c">,
// CHECK-SAME: %[[F:.*]]: !torch.float) -> !torch.float {
// Ensure no func.constant.
// CHECK-NEXT: %[[VAL_2:.*]] = call @test_call_method(%[[RECEIVER]], %[[F]]) : (!torch.nn.Module<"c">, !torch.float) -> !torch.float
// CHECK-NEXT: return %[[VAL_2]] : !torch.float
func private @test_call_indirect(%arg0: !torch.nn.Module<"c">, %arg1: !torch.float) -> !torch.float {
func.func private @test_call_indirect(%arg0: !torch.nn.Module<"c">, %arg1: !torch.float) -> !torch.float {
%0 = constant @test_call_method : (!torch.nn.Module<"c">, !torch.float) -> !torch.float
%1 = call_indirect %0(%arg0, %arg1) : (!torch.nn.Module<"c">, !torch.float) -> !torch.float
return %1 : !torch.float

View File

@ -2,7 +2,7 @@
// -----
func @convert_to_value_semantic_tensors_list( %list: !torch.list<tensor>) -> !torch.tensor {
func.func @convert_to_value_semantic_tensors_list( %list: !torch.list<tensor>) -> !torch.tensor {
%int1 = torch.constant.int 1
// expected-error@+1 {{failed to legalize operation 'torch.aten.cat' that was explicitly marked illegal}}
%ret = torch.aten.cat %list, %int1 : !torch.list<tensor>, !torch.int -> !torch.tensor
@ -11,7 +11,7 @@ func @convert_to_value_semantic_tensors_list( %list: !torch.list<tensor>) -> !to
// -----
func @convert_to_value_semantic_tensors_optional(%tensor_optional: !torch.optional<tensor>,
func.func @convert_to_value_semantic_tensors_optional(%tensor_optional: !torch.optional<tensor>,
%t: !torch.tensor,
%training: !torch.bool,
%cudnn_enable: !torch.bool,

View File

@ -1,17 +1,17 @@
// RUN: torch-mlir-opt -torch-reduce-op-variants %s | FileCheck %s
// CHECK-LABEL: func @convert_to_value_semantic_tensors(
// CHECK-LABEL: func.func @convert_to_value_semantic_tensors(
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
// CHECK: %[[OPERAND_TENSOR:.*]] = torch.copy.to_vtensor %[[ARG]] : !torch.vtensor<[],f32>
// CHECK: %[[RESULT_TENSOR:.*]] = torch.aten.tanh %[[OPERAND_TENSOR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[RESULT_TENSOR]] : !torch.tensor<[],f32>
// CHECK: return %[[RET]] : !torch.tensor<[],f32>
func @convert_to_value_semantic_tensors(%arg0: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
func.func @convert_to_value_semantic_tensors(%arg0: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
%0 = torch.aten.tanh %arg0 : !torch.tensor<[],f32> -> !torch.tensor<[],f32>
return %0 : !torch.tensor<[],f32>
}
// CHECK-LABEL: func @convert_to_value_semantic_tensors_list(
// CHECK-LABEL: func.func @convert_to_value_semantic_tensors_list(
// CHECK-SAME: %[[VT0:.*]]: !torch.vtensor, %[[VT1:.*]]: !torch.vtensor,
// CHECK-SAME: %[[VT2:.*]]: !torch.vtensor) -> !torch.tensor {
// CHECK: %[[T0:.*]] = torch.copy.to_tensor %[[VT0]] : !torch.tensor
@ -30,7 +30,7 @@ func @convert_to_value_semantic_tensors(%arg0: !torch.tensor<[],f32>) -> !torch.
// CHECK-SAME: !torch.list<vtensor>, !torch.int -> !torch.vtensor
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
// CHECK: return %[[RET]] : !torch.tensor
func @convert_to_value_semantic_tensors_list(%vt0: !torch.vtensor, %vt1: !torch.vtensor, %vt2: !torch.vtensor) -> !torch.tensor {
func.func @convert_to_value_semantic_tensors_list(%vt0: !torch.vtensor, %vt1: !torch.vtensor, %vt2: !torch.vtensor) -> !torch.tensor {
%t0 = torch.copy.to_tensor %vt0 : !torch.tensor
%t1 = torch.copy.to_tensor %vt1 : !torch.tensor
%t2 = torch.copy.to_tensor %vt2 : !torch.tensor
@ -40,7 +40,7 @@ func @convert_to_value_semantic_tensors_list(%vt0: !torch.vtensor, %vt1: !torch.
return %ret : !torch.tensor
}
// CHECK-LABEL: func @convert_to_value_semantic_tensors_optional(
// CHECK-LABEL: func.func @convert_to_value_semantic_tensors_optional(
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor, %[[FLOAT_TENSOR:.*]]: !torch.tensor<[4],f32>,
// CHECK-SAME: %[[TRAINING:.*]]: !torch.bool, %[[CUDNN_ENABLE:.*]]: !torch.bool,
// CHECK-SAME: %[[FLOAT:.*]]: !torch.float) -> !torch.tensor {
@ -67,7 +67,7 @@ func @convert_to_value_semantic_tensors_list(%vt0: !torch.vtensor, %vt1: !torch.
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
// CHECK: return %[[RET]] : !torch.tensor
// CHECK: }
func @convert_to_value_semantic_tensors_optional(%t: !torch.tensor,
func.func @convert_to_value_semantic_tensors_optional(%t: !torch.tensor,
%ft: !torch.tensor<[4],f32>,
%training: !torch.bool,
%cudnn_enable: !torch.bool,
@ -83,7 +83,7 @@ func @convert_to_value_semantic_tensors_optional(%t: !torch.tensor,
return %ret: !torch.tensor
}
// CHECK-LABEL: func @reduce_trailing_underscore_inplace_variant(
// CHECK-LABEL: func.func @reduce_trailing_underscore_inplace_variant(
// CHECK-SAME: %[[ARG0:.*]]: !torch.tensor<[2,2],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) {
// CHECK: %[[C1:.*]] = torch.constant.int 1
@ -97,23 +97,23 @@ func @convert_to_value_semantic_tensors_optional(%t: !torch.tensor,
// CHECK: %[[TENSOR_AGAIN:.*]] = torch.copy.to_vtensor %[[ARRAY_RESULT]] : !torch.vtensor<[2,2],f32>
// CHECK: torch.overwrite.tensor.contents %[[TENSOR_AGAIN]] overwrites %[[ARG0]] : !torch.vtensor<[2,2],f32>, !torch.tensor<[2,2],f32>
// CHECK: return %[[ARG0]], %[[ARG0]] : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>
func @reduce_trailing_underscore_inplace_variant(%arg0: !torch.tensor<[2,2],f32>, %arg1: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) {
func.func @reduce_trailing_underscore_inplace_variant(%arg0: !torch.tensor<[2,2],f32>, %arg1: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) {
%c1 = torch.constant.int 1
%0 = torch.aten.add_.Tensor %arg0, %arg1, %c1 : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>, !torch.int -> !torch.tensor<[2,2],f32>
return %0, %arg0 : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>
}
// CHECK-LABEL: func @torch.tensor.literal() -> !torch.tensor {
// CHECK-LABEL: func.func @torch.tensor.literal() -> !torch.tensor {
// CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<0.000000e+00> : tensor<7xf32>) : !torch.vtensor<[7],f32>
// CHECK: %[[SIZES_ERASED:.*]] = torch.tensor_static_info_cast %[[VTENSOR]] : !torch.vtensor<[7],f32> to !torch.vtensor
// CHECK: %[[TENSOR:.*]] = torch.copy.to_tensor %[[SIZES_ERASED]] : !torch.tensor
// CHECK: return %[[TENSOR]] : !torch.tensor
func @torch.tensor.literal() -> !torch.tensor {
func.func @torch.tensor.literal() -> !torch.tensor {
%0 = torch.tensor.literal(dense<0.0> : tensor<7xf32>) : !torch.tensor
return %0 : !torch.tensor
}
// CHECK-LABEL: func @convert_to_value_semantic_tensors_optional_list(
// CHECK-LABEL: func.func @convert_to_value_semantic_tensors_optional_list(
// CHECK-SAME: %[[SELF:.*]]: !torch.tensor<[5],f32>,
// CHECK-SAME: %[[INDICES:.*]]: !torch.tensor<[2,3],si64>) -> !torch.tensor {
// CHECK: %[[INDICES_OPTIONAL_LIST:.*]] = torch.prim.ListConstruct %[[INDICES]] :
@ -124,13 +124,13 @@ func @torch.tensor.literal() -> !torch.tensor {
// CHECK: %[[VRET:.*]] = torch.aten.index.Tensor %[[SELF_VTENSOR]], %[[INDICES_LIST]] : !torch.vtensor<[5],f32>, !torch.list<vtensor<[2,3],si64>> -> !torch.vtensor
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
// CHECK: return %[[RET]] : !torch.tensor
func @convert_to_value_semantic_tensors_optional_list(%self: !torch.tensor<[5],f32>, %indices: !torch.tensor<[2,3],si64>) -> !torch.tensor {
func.func @convert_to_value_semantic_tensors_optional_list(%self: !torch.tensor<[5],f32>, %indices: !torch.tensor<[2,3],si64>) -> !torch.tensor {
%tensor_optional_list = torch.prim.ListConstruct %indices : (!torch.tensor<[2,3],si64>) -> !torch.list<optional<tensor<[2,3],si64>>>
%ret = torch.aten.index.Tensor %self, %tensor_optional_list : !torch.tensor<[5],f32>, !torch.list<optional<tensor<[2,3],si64>>> -> !torch.tensor
return %ret : !torch.tensor
}
// CHECK-LABEL: func @torch.aten.uniform_(
// CHECK-LABEL: func.func @torch.aten.uniform_(
// CHECK-SAME: %[[T:.*]]: !torch.tensor, %[[MIN:.*]]: !torch.float, %[[MAX:.*]]: !torch.float,
// CHECK-SAME: %[[GENERATOR:.*]]: !torch.none) -> !torch.tensor {
// CHECK: %[[T_VTENSOR:.*]] = torch.copy.to_vtensor %[[T]] : !torch.vtensor
@ -140,12 +140,12 @@ func @convert_to_value_semantic_tensors_optional_list(%self: !torch.tensor<[5],f
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
// CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
// CHECK: return %[[T]] : !torch.tensor
func @torch.aten.uniform_(%t: !torch.tensor, %min: !torch.float, %max: !torch.float, %generator: !torch.none) -> !torch.tensor {
func.func @torch.aten.uniform_(%t: !torch.tensor, %min: !torch.float, %max: !torch.float, %generator: !torch.none) -> !torch.tensor {
%ret = torch.aten.uniform_ %t, %min, %max, %generator: !torch.tensor, !torch.float, !torch.float, !torch.none -> !torch.tensor
return %ret : !torch.tensor
}
// CHECK-LABEL: func @torch.aten.bernoulli_.float(
// CHECK-LABEL: func.func @torch.aten.bernoulli_.float(
// CHECK-SAME: %[[T:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: %[[GENERATOR:.*]] = torch.constant.none
// CHECK: %[[P:.*]] = torch.constant.float 5.000000e-01
@ -155,14 +155,14 @@ func @torch.aten.uniform_(%t: !torch.tensor, %min: !torch.float, %max: !torch.fl
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
// CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
// CHECK: return %[[T]] : !torch.tensor
func @torch.aten.bernoulli_.float(%t: !torch.tensor) -> !torch.tensor {
func.func @torch.aten.bernoulli_.float(%t: !torch.tensor) -> !torch.tensor {
%generator = torch.constant.none
%p = torch.constant.float 5.000000e-01
%ret = torch.aten.bernoulli_.float %t, %p, %generator : !torch.tensor, !torch.float, !torch.none -> !torch.tensor
return %ret : !torch.tensor
}
// CHECK-LABEL: func @torch.aten.fill_.Scalar(
// CHECK-LABEL: func.func @torch.aten.fill_.Scalar(
// CHECK-SAME: %[[T:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: %[[VALUE:.*]] = torch.constant.int 1
// CHECK: %[[T_VTENSOR:.*]] = torch.copy.to_vtensor %[[T]] : !torch.vtensor
@ -171,13 +171,13 @@ func @torch.aten.bernoulli_.float(%t: !torch.tensor) -> !torch.tensor {
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
// CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
// CHECK: return %[[T]] : !torch.tensor
func @torch.aten.fill_.Scalar(%t: !torch.tensor) -> !torch.tensor {
func.func @torch.aten.fill_.Scalar(%t: !torch.tensor) -> !torch.tensor {
%value = torch.constant.int 1
%ret = torch.aten.fill_.Scalar %t, %value : !torch.tensor, !torch.int -> !torch.tensor
return %ret : !torch.tensor
}
// CHECK-LABEL: func @torch.aten._index_put_impl_(
// CHECK-LABEL: func.func @torch.aten._index_put_impl_(
// CHECK-SAME: %[[SELF:.*]]: !torch.tensor, %[[INDEX:.*]]: !torch.tensor, %[[VALUES:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
@ -191,7 +191,7 @@ func @torch.aten.fill_.Scalar(%t: !torch.tensor) -> !torch.tensor {
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
// CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[SELF]] : !torch.vtensor, !torch.tensor
// CHECK: return %[[SELF:.*]] : !torch.tensor
func @torch.aten._index_put_impl_(%self: !torch.tensor, %index: !torch.tensor, %values: !torch.tensor) -> !torch.tensor {
func.func @torch.aten._index_put_impl_(%self: !torch.tensor, %index: !torch.tensor, %values: !torch.tensor) -> !torch.tensor {
%true = torch.constant.bool true
%false = torch.constant.bool false
%indicesList = torch.prim.ListConstruct %index : (!torch.tensor) -> !torch.list<optional<tensor>>
@ -199,7 +199,7 @@ func @torch.aten._index_put_impl_(%self: !torch.tensor, %index: !torch.tensor, %
return %ret : !torch.tensor
}
// CHECK-LABEL: func @torch.aten.copy_(
// CHECK-LABEL: func.func @torch.aten.copy_(
// CHECK-SAME: %[[DST:.*]]: !torch.tensor,
// CHECK-SAME: %[[SRC:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
@ -210,7 +210,7 @@ func @torch.aten._index_put_impl_(%self: !torch.tensor, %index: !torch.tensor, %
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
// CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[DST]] : !torch.vtensor, !torch.tensor
// CHECK: return %[[DST]] : !torch.tensor
func @torch.aten.copy_(%dst: !torch.tensor, %src : !torch.tensor) -> !torch.tensor {
func.func @torch.aten.copy_(%dst: !torch.tensor, %src : !torch.tensor) -> !torch.tensor {
%false = torch.constant.bool false
%ret = torch.aten.copy_ %dst, %src, %false : !torch.tensor, !torch.tensor, !torch.bool -> !torch.tensor
return %ret : !torch.tensor

View File

@ -1,34 +1,34 @@
// RUN: torch-mlir-opt -split-input-file -verify-diagnostics %s -torch-refine-public-return | FileCheck %s
// CHECK-LABEL: func @basic(
// CHECK-LABEL: func.func @basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor<[2,3,?],f32> {
// CHECK: return %[[ARG]] : !torch.vtensor<[2,3,?],f32>
func @basic(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
func.func @basic(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
%1 = torch.copy.to_tensor %arg0 : !torch.tensor<[2,3,?],f32>
%2 = torch.tensor_static_info_cast %1 : !torch.tensor<[2,3,?],f32> to !torch.tensor
return %2 : !torch.tensor
}
// CHECK-LABEL: func @multiple_use_non_value_tensor(
// CHECK-LABEL: func.func @multiple_use_non_value_tensor(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: %[[NON_VALUE_TENSOR:.*]] = torch.copy.to_tensor %[[ARG0]] : !torch.tensor
// CHECK: torch.overwrite.tensor.contents %[[ARG1]] overwrites %[[NON_VALUE_TENSOR]] : !torch.vtensor, !torch.tensor
// CHECK: %[[RESULT:.*]] = torch.copy.to_vtensor %[[NON_VALUE_TENSOR]] : !torch.vtensor
// CHECK: return %[[RESULT]] : !torch.vtensor
func @multiple_use_non_value_tensor(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.tensor {
func.func @multiple_use_non_value_tensor(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.tensor {
%0 = torch.copy.to_tensor %arg0 : !torch.tensor
torch.overwrite.tensor.contents %arg1 overwrites %0 : !torch.vtensor, !torch.tensor
return %0 : !torch.tensor
}
// No conversion on private function.
// CHECK-LABEL: func private @basic_private(
// CHECK-LABEL: func.func private @basic_private(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
// CHECK: %[[COPIED:.*]] = torch.copy.to_tensor %[[ARG]] : !torch.tensor<[2,3,?],f32>
// CHECK: %[[CASTED:.*]] = torch.tensor_static_info_cast %[[COPIED]] : !torch.tensor<[2,3,?],f32> to !torch.tensor
// CHECK: return %[[CASTED]] : !torch.tensor
func private @basic_private(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
func.func private @basic_private(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
%1 = torch.copy.to_tensor %arg0 : !torch.tensor<[2,3,?],f32>
%2 = torch.tensor_static_info_cast %1 : !torch.tensor<[2,3,?],f32> to !torch.tensor
return %2 : !torch.tensor
@ -38,11 +38,11 @@ func private @basic_private(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor
// Call to public function.
// expected-error @+1 {{unimplemented}}
func @called(%arg0: tensor<*xf32>) -> tensor<*xf32> {
func.func @called(%arg0: tensor<*xf32>) -> tensor<*xf32> {
return %arg0 : tensor<*xf32>
}
func private @caller(%arg0: tensor<*xf32>) -> tensor<*xf32> {
func.func private @caller(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = call @called(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
@ -51,7 +51,7 @@ func private @caller(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// Multiple returns.
// expected-error @+1 {{unimplemented}}
func @called(%arg0: tensor<*xf32>) -> tensor<*xf32> {
func.func @called(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%ctrue = arith.constant true
cf.cond_br %ctrue, ^bb1, ^bb2
^bb1:

View File

@ -2,7 +2,7 @@
// -----
// CHECK-LABEL: func @prim.if$branch_merge_type_tensor(
// CHECK-LABEL: func.func @prim.if$branch_merge_type_tensor(
// CHECK-SAME: %[[PRED:.*]]: !torch.bool,
// CHECK-SAME: %[[T1:.*]]: !torch.tensor,
// CHECK-SAME: %[[T2:.*]]: !torch.tensor) -> !torch.bool {
@ -18,7 +18,7 @@
// CHECK: %[[RET:.*]] = torch.aten.__isnot__ %[[REFINED]], %[[NONE]] : !torch.tensor, !torch.none -> !torch.bool
// CHECK: return %[[RET]] : !torch.bool
func @prim.if$branch_merge_type_tensor(%pred: !torch.bool, %t0: !torch.tensor, %t1: !torch.tensor) -> !torch.bool {
func.func @prim.if$branch_merge_type_tensor(%pred: !torch.bool, %t0: !torch.tensor, %t1: !torch.tensor) -> !torch.bool {
%res = torch.prim.If %pred -> (!torch.optional<tensor>) {
%optional0 = torch.derefine %t0: !torch.tensor to !torch.optional<tensor>
torch.prim.If.yield %optional0: !torch.optional<tensor>
@ -33,7 +33,7 @@ func @prim.if$branch_merge_type_tensor(%pred: !torch.bool, %t0: !torch.tensor, %
// -----
// CHECK-LABEL: func @prim.if$branch_merge_type_optional(
// CHECK-LABEL: func.func @prim.if$branch_merge_type_optional(
// CHECK-SAME: %[[PRED:.*]]: !torch.bool,
// CHECK-SAME: %[[T:.*]]: !torch.tensor) -> !torch.optional<tensor> {
// CHECK: %[[MERGED:.*]] = torch.prim.If %[[PRED]] -> (!torch.optional<tensor>) {
@ -46,7 +46,7 @@ func @prim.if$branch_merge_type_tensor(%pred: !torch.bool, %t0: !torch.tensor, %
// CHECK: }
// CHECK: return %[[MERGED:.*]] : !torch.optional<tensor>
func @prim.if$branch_merge_type_optional(%pred: !torch.bool, %t1: !torch.tensor) -> !torch.optional<tensor> {
func.func @prim.if$branch_merge_type_optional(%pred: !torch.bool, %t1: !torch.tensor) -> !torch.optional<tensor> {
%res = torch.prim.If %pred -> (!torch.optional<tensor>) {
%none = torch.constant.none
%optional0 = torch.derefine %none: !torch.none to !torch.optional<tensor>
@ -60,7 +60,7 @@ func @prim.if$branch_merge_type_optional(%pred: !torch.bool, %t1: !torch.tensor)
// -----
// CHECK-LABEL: func @prim.if$refined_type_conflicting(
// CHECK-LABEL: func.func @prim.if$refined_type_conflicting(
// CHECK-SAME: %[[NONE:.*]]: !torch.none) -> !torch.tensor {
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<tensor>
// CHECK: %[[NOT_NONE:.*]] = torch.aten.__isnot__ %[[NONE]], %[[NONE]] : !torch.none, !torch.none -> !torch.bool
@ -73,7 +73,7 @@ func @prim.if$branch_merge_type_optional(%pred: !torch.bool, %t1: !torch.tensor)
// CHECK: }
// CHECK: return %[[PRED:.*]] : !torch.tensor
func @prim.if$refined_type_conflicting(%none: !torch.none) -> !torch.tensor {
func.func @prim.if$refined_type_conflicting(%none: !torch.none) -> !torch.tensor {
%optional = torch.derefine %none: !torch.none to !torch.optional<tensor>
%pred = torch.aten.__isnot__ %optional, %none : !torch.optional<tensor>, !torch.none -> !torch.bool
%res = torch.prim.If %pred -> (!torch.tensor) {
@ -88,7 +88,7 @@ func @prim.if$refined_type_conflicting(%none: !torch.none) -> !torch.tensor {
// -----
// CHECK-LABEL: func @prim.loop$region_arg_to_internal(
// CHECK-LABEL: func.func @prim.loop$region_arg_to_internal(
// CHECK-SAME: %[[ARG_NONE:.*]]: !torch.none) -> !torch.optional<tensor> {
// CHECK: %[[INT10:.*]] = torch.constant.int 10
// CHECK: %[[INDV:.*]] = torch.constant.int 0
@ -105,7 +105,7 @@ func @prim.if$refined_type_conflicting(%none: !torch.none) -> !torch.tensor {
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<tensor>
// CHECK: return %[[OPTIONAL]] : !torch.optional<tensor>
func @prim.loop$region_arg_to_internal(%none: !torch.none) -> !torch.optional<tensor> {
func.func @prim.loop$region_arg_to_internal(%none: !torch.none) -> !torch.optional<tensor> {
%int10 = torch.constant.int 10
%int0 = torch.constant.int 0
%true = torch.constant.bool true
@ -120,11 +120,11 @@ func @prim.loop$region_arg_to_internal(%none: !torch.none) -> !torch.optional<te
// -----
// CHECK-LABEL: func @f
// CHECK-LABEL: func.func @f
// CHECK: %[[ATEN:.*]] = torch.aten.tanh %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ATEN]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func @f(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
func.func @f(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
%cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor
cf.br ^bb1(%cast: !torch.vtensor)
^bb1(%arg1: !torch.vtensor):
@ -134,16 +134,16 @@ func @f(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
// -----
// CHECK-LABEL: func @f
// CHECK: func private @callee
// CHECK-LABEL: func.func @f
// CHECK: func.func private @callee
// CHECK-NEXT: torch.aten.tanh %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32>
func @f() {
func.func @f() {
builtin.module {
func private @callee(%arg0: !torch.vtensor) {
func.func private @callee(%arg0: !torch.vtensor) {
%1 = torch.aten.tanh %arg0 : !torch.vtensor -> !torch.vtensor
return
}
func @caller(%arg0: !torch.vtensor<*,f32>) {
func.func @caller(%arg0: !torch.vtensor<*,f32>) {
%cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor
call @callee(%cast) : (!torch.vtensor) -> ()
return

View File

@ -4,7 +4,7 @@
// function (i.e. new code called from visitOperation).
// -----
// CHECK-LABEL: func @aten.arange.start$int64_dtype(
// CHECK-LABEL: func.func @aten.arange.start$int64_dtype(
// CHECK-SAME: %[[START:.*]]: !torch.int,
// CHECK-SAME: %[[END:.*]]: !torch.int) -> !torch.vtensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
@ -14,14 +14,14 @@
// CHECK-SAME: -> !torch.vtensor<*,si64>
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.vtensor<*,si64> to !torch.vtensor
// CHECK: return %[[RET]] : !torch.vtensor
func @aten.arange.start$int64_dtype(%start: !torch.int, %end: !torch.int) -> !torch.vtensor {
func.func @aten.arange.start$int64_dtype(%start: !torch.int, %end: !torch.int) -> !torch.vtensor {
%none = torch.constant.none
%ret = torch.aten.arange.start %start, %end, %none, %none, %none, %none: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor
return %ret : !torch.vtensor
}
// -----
// CHECK-LABEL: func @aten.arange.start$float32_dtype(
// CHECK-LABEL: func.func @aten.arange.start$float32_dtype(
// CHECK-SAME: %[[START:.*]]: !torch.float,
// CHECK-SAME: %[[END:.*]]: !torch.int) -> !torch.vtensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
@ -31,14 +31,14 @@ func @aten.arange.start$int64_dtype(%start: !torch.int, %end: !torch.int) -> !to
// CHECK-SAME: -> !torch.vtensor<*,f32>
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: return %[[RET]] : !torch.vtensor
func @aten.arange.start$float32_dtype(%start: !torch.float, %end: !torch.int) -> !torch.vtensor {
func.func @aten.arange.start$float32_dtype(%start: !torch.float, %end: !torch.int) -> !torch.vtensor {
%none = torch.constant.none
%ret = torch.aten.arange.start %start, %end, %none, %none, %none, %none: !torch.float, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor
return %ret : !torch.vtensor
}
// -----
// CHECK-LABEL: func @aten.arange.start$specified_dtype(
// CHECK-LABEL: func.func @aten.arange.start$specified_dtype(
// CHECK-SAME: %[[END:.*]]: !torch.int) -> !torch.vtensor {
// CHECK: %[[CST6:.*]] = torch.constant.int 6
// CHECK: %[[NONE:.*]] = torch.constant.none
@ -48,7 +48,7 @@ func @aten.arange.start$float32_dtype(%start: !torch.float, %end: !torch.int) ->
// CHECK-SAME: -> !torch.vtensor<*,f32>
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: return %[[RET]] : !torch.vtensor
func @aten.arange.start$specified_dtype(%end: !torch.int) -> !torch.vtensor {
func.func @aten.arange.start$specified_dtype(%end: !torch.int) -> !torch.vtensor {
%int6 = torch.constant.int 6
%none = torch.constant.none
%ret = torch.aten.arange %end, %int6, %none, %none, %none: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor
@ -56,20 +56,20 @@ func @aten.arange.start$specified_dtype(%end: !torch.int) -> !torch.vtensor {
}
// -----
// CHECK-LABEL: func @torch.aten.linear(
// CHECK-LABEL: func.func @torch.aten.linear(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[5,3],f32>,
// CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor<[5],f32>) -> !torch.vtensor {
// CHECK: %[[LINEAR:.*]] = torch.aten.linear %[[ARG0]], %[[ARG1]], %[[ARG2]] : !torch.vtensor<[?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<*,f32>
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[LINEAR]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: return %[[RESULT]] : !torch.vtensor
func @torch.aten.linear(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor {
func.func @torch.aten.linear(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor {
%1 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor
return %1 : !torch.vtensor
}
// -----
// CHECK-LABEL: func @aten.sum.dim_IntList(
// CHECK-LABEL: func.func @aten.sum.dim_IntList(
// CHECK-SAME: %[[T:.*]]: !torch.vtensor<*,si64>) -> !torch.vtensor {
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
@ -82,7 +82,7 @@ func @torch.aten.linear(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.vtensor<
// CHECK-SAME: -> !torch.vtensor<*,si64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<*,si64> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func @aten.sum.dim_IntList(%t: !torch.vtensor<*,si64>) -> !torch.vtensor {
func.func @aten.sum.dim_IntList(%t: !torch.vtensor<*,si64>) -> !torch.vtensor {
%false = torch.constant.bool false
%none = torch.constant.none
%int0 = torch.constant.int 0
@ -93,14 +93,14 @@ func @aten.sum.dim_IntList(%t: !torch.vtensor<*,si64>) -> !torch.vtensor {
}
// -----
// CHECK-LABEL: func @aten.any.dim(
// CHECK-LABEL: func.func @aten.any.dim(
// CHECK-SAME: %[[T:.*]]: !torch.vtensor<*,i1>) -> !torch.vtensor {
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[INT_NEG1:.*]] = torch.constant.int -1
// CHECK: %[[RET:.*]] = torch.aten.any.dim %[[T]], %[[INT_NEG1]], %[[FALSE]] : !torch.vtensor<*,i1>, !torch.int, !torch.bool -> !torch.vtensor<*,i1>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<*,i1> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func @aten.any.dim(%t: !torch.vtensor<*,i1>) -> !torch.vtensor {
func.func @aten.any.dim(%t: !torch.vtensor<*,i1>) -> !torch.vtensor {
%false = torch.constant.bool false
%int-1 = torch.constant.int -1
%ret = torch.aten.any.dim %t, %int-1, %false : !torch.vtensor<*,i1>, !torch.int, !torch.bool -> !torch.vtensor
@ -108,18 +108,18 @@ func @aten.any.dim(%t: !torch.vtensor<*,i1>) -> !torch.vtensor {
}
// -----
// CHECK-LABEL: func @aten.any(
// CHECK-LABEL: func.func @aten.any(
// CHECK-SAME: %[[T:.*]]: !torch.vtensor<*,i1>) -> !torch.vtensor {
// CHECK: %[[RET:.*]] = torch.aten.any %[[T]] : !torch.vtensor<*,i1> -> !torch.vtensor<*,i1>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<*,i1> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func @aten.any(%t: !torch.vtensor<*,i1>) -> !torch.vtensor {
func.func @aten.any(%t: !torch.vtensor<*,i1>) -> !torch.vtensor {
%ret = torch.aten.any %t: !torch.vtensor<*,i1> -> !torch.vtensor
return %ret : !torch.vtensor
}
// -----
// CHECK-LABEL: func @torch.aten.zeros(
// CHECK-LABEL: func.func @torch.aten.zeros(
// CHECK-SAME: %[[DIM0:.*]]: !torch.int) -> !torch.tensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT2:.*]] = torch.constant.int 2
@ -127,7 +127,7 @@ func @aten.any(%t: !torch.vtensor<*,i1>) -> !torch.vtensor {
// CHECK: %[[ZEROS:.*]] = torch.aten.zeros %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ZEROS]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.zeros(%dim0: !torch.int) -> !torch.tensor {
func.func @torch.aten.zeros(%dim0: !torch.int) -> !torch.tensor {
%none = torch.constant.none
%int2 = torch.constant.int 2
%sizesList = torch.prim.ListConstruct %dim0, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
@ -136,19 +136,19 @@ func @torch.aten.zeros(%dim0: !torch.int) -> !torch.tensor {
}
// -----
// CHECK-LABEL: func @torch.aten.type_as(
// CHECK-LABEL: func.func @torch.aten.type_as(
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?],si64>,
// CHECK-SAME: %[[OTHER:.*]]: !torch.tensor<[?,2],f32>) -> !torch.tensor {
// CHECK: %[[RET:.*]] = torch.aten.type_as %[[INPUT]], %[[OTHER]] : !torch.tensor<[?],si64>, !torch.tensor<[?,2],f32> -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.type_as(%self: !torch.tensor<[?], si64>, %other: !torch.tensor<[?,2],f32>) -> !torch.tensor {
func.func @torch.aten.type_as(%self: !torch.tensor<[?], si64>, %other: !torch.tensor<[?,2],f32>) -> !torch.tensor {
%ret = torch.aten.type_as %self, %other : !torch.tensor<[?], si64>, !torch.tensor<[?,2],f32> -> !torch.tensor
return %ret: !torch.tensor
}
// -----
// CHECK-LABEL: func @torch.aten.cat(
// CHECK-LABEL: func.func @torch.aten.cat(
// CHECK-SAME: %[[T1:.*]]: !torch.tensor<[?,1,4],f32>,
// CHECK-SAME: %[[T2:.*]]: !torch.tensor<[2,3,4],f32>) -> !torch.tensor {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
@ -156,7 +156,7 @@ func @torch.aten.type_as(%self: !torch.tensor<[?], si64>, %other: !torch.tensor<
// CHECK: %[[RET:.*]] = torch.aten.cat %[[TENSORS]], %[[INT1]] : !torch.list<tensor>, !torch.int -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.cat(%t0: !torch.tensor<[?,1,4], f32>, %t1: !torch.tensor<[2,3,4], f32>) -> !torch.tensor {
func.func @torch.aten.cat(%t0: !torch.tensor<[?,1,4], f32>, %t1: !torch.tensor<[2,3,4], f32>) -> !torch.tensor {
%int1 = torch.constant.int 1
%tensorList = torch.prim.ListConstruct %t0, %t1: (!torch.tensor<[?,1,4], f32>, !torch.tensor<[2,3,4], f32>) -> !torch.list<tensor>
%ret = torch.aten.cat %tensorList, %int1 : !torch.list<tensor>, !torch.int -> !torch.tensor
@ -164,29 +164,29 @@ func @torch.aten.cat(%t0: !torch.tensor<[?,1,4], f32>, %t1: !torch.tensor<[2,3,4
}
// -----
// CHECK-LABEL: func @torch.aten._shape_as_tensor(
// CHECK-LABEL: func.func @torch.aten._shape_as_tensor(
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?,1,4],f32>) -> !torch.tensor {
// CHECK: %[[RET:.*]] = torch.aten._shape_as_tensor %[[INPUT]] : !torch.tensor<[?,1,4],f32> -> !torch.tensor<*,si64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten._shape_as_tensor(%input: !torch.tensor<[?,1,4], f32>) -> !torch.tensor {
func.func @torch.aten._shape_as_tensor(%input: !torch.tensor<[?,1,4], f32>) -> !torch.tensor {
%ret= torch.aten._shape_as_tensor %input : !torch.tensor<[?,1,4], f32> -> !torch.tensor
return %ret : !torch.tensor
}
// -----
// CHECK-LABEL: func @torch.aten._shape_as_tensor$unknown_input_shape(
// CHECK-LABEL: func.func @torch.aten._shape_as_tensor$unknown_input_shape(
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: %[[RET:.*]] = torch.aten._shape_as_tensor %[[INPUT]] : !torch.tensor -> !torch.tensor<*,si64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten._shape_as_tensor$unknown_input_shape(%input: !torch.tensor) -> !torch.tensor {
func.func @torch.aten._shape_as_tensor$unknown_input_shape(%input: !torch.tensor) -> !torch.tensor {
%ret= torch.aten._shape_as_tensor %input : !torch.tensor -> !torch.tensor
return %ret : !torch.tensor
}
// -----
// CHECK-LABEL: func @torch.aten.embedding(
// CHECK-LABEL: func.func @torch.aten.embedding(
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[104,512],f32>,
// CHECK-SAME: %[[INDEXES:.*]]: !torch.tensor<[2,3],si64>) -> !torch.tensor {
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
@ -194,7 +194,7 @@ func @torch.aten._shape_as_tensor$unknown_input_shape(%input: !torch.tensor) ->
// CHECK: %[[RET:.*]] = torch.aten.embedding %[[INPUT]], %[[INDEXES]], %[[PADDING_IDX]], %[[FALSE]], %[[FALSE]] : !torch.tensor<[104,512],f32>, !torch.tensor<[2,3],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.embedding(%weight: !torch.tensor<[104,512],f32>, %indices: !torch.tensor<[2,3], si64>) -> !torch.tensor {
func.func @torch.aten.embedding(%weight: !torch.tensor<[104,512],f32>, %indices: !torch.tensor<[2,3], si64>) -> !torch.tensor {
%false = torch.constant.bool false
%int1 = torch.constant.int 1
%ret = torch.aten.embedding %weight, %indices, %int1, %false, %false : !torch.tensor<[104,512],f32>, !torch.tensor<[2,3], si64>, !torch.int, !torch.bool, !torch.bool -> !torch.tensor
@ -202,14 +202,14 @@ func @torch.aten.embedding(%weight: !torch.tensor<[104,512],f32>, %indices: !tor
}
// -----
// CHECK-LABEL: func @torch.aten.tensor.float(
// CHECK-LABEL: func.func @torch.aten.tensor.float(
// CHECK-SAME: %[[t:.*]]: !torch.float) -> !torch.tensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[RET:.*]] = torch.aten.tensor.float %[[t]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.tensor.float(%t: !torch.float) -> !torch.tensor {
func.func @torch.aten.tensor.float(%t: !torch.float) -> !torch.tensor {
%none = torch.constant.none
%false = torch.constant.bool false
%ret = torch.aten.tensor.float %t, %none, %none, %false : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.tensor
@ -217,7 +217,7 @@ func @torch.aten.tensor.float(%t: !torch.float) -> !torch.tensor {
}
// -----
// CHECK-LABEL: func @torch.aten.tensor.float$specified_dtype(
// CHECK-LABEL: func.func @torch.aten.tensor.float$specified_dtype(
// CHECK-SAME: %[[t:.*]]: !torch.float) -> !torch.tensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[CST11:.*]] = torch.constant.int 11
@ -225,7 +225,7 @@ func @torch.aten.tensor.float(%t: !torch.float) -> !torch.tensor {
// CHECK: %[[RET:.*]] = torch.aten.tensor.float %[[t]], %[[CST11]], %[[NONE]], %[[FALSE]] : !torch.float, !torch.int, !torch.none, !torch.bool -> !torch.tensor<*,i1>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,i1> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.tensor.float$specified_dtype(%t: !torch.float) -> !torch.tensor {
func.func @torch.aten.tensor.float$specified_dtype(%t: !torch.float) -> !torch.tensor {
%none = torch.constant.none
%int11 = torch.constant.int 11
%false = torch.constant.bool false
@ -234,59 +234,59 @@ func @torch.aten.tensor.float$specified_dtype(%t: !torch.float) -> !torch.tensor
}
// -----
// CHECK-LABEL: func @torch.aten.softmax.int(
// CHECK-LABEL: func.func @torch.aten.softmax.int(
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>,
// CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor {
// CHECK: %[[DTYPE:.*]] = torch.constant.none
// CHECK: %[[SOFTMAX:.*]] = torch.aten.softmax.int %[[T]], %[[DIM]], %[[DTYPE]] : !torch.tensor<[2,3],f32>, !torch.int, !torch.none -> !torch.tensor<*,f32>
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[RET]] : !torch.tensor
func @torch.aten.softmax.int(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -> !torch.tensor {
func.func @torch.aten.softmax.int(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -> !torch.tensor {
%none = torch.constant.none
%ret = torch.aten.softmax.int %t, %dim, %none : !torch.tensor<[2,3],f32>, !torch.int, !torch.none -> !torch.tensor
return %ret : !torch.tensor
}
// -----
// CHECK-LABEL: func @torch.aten.softmax.int$specified_dtype(
// CHECK-LABEL: func.func @torch.aten.softmax.int$specified_dtype(
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>,
// CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor {
// CHECK: %[[DTYPE:.*]] = torch.constant.int 4
// CHECK: %[[SOFTMAX:.*]] = torch.aten.softmax.int %[[T]], %[[DIM]], %[[DTYPE]] : !torch.tensor<[2,3],f32>, !torch.int, !torch.int -> !torch.tensor<*,si64>
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<*,si64> to !torch.tensor
// CHECK: return %[[RET]] : !torch.tensor
func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -> !torch.tensor {
func.func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -> !torch.tensor {
%int4 = torch.constant.int 4
%ret = torch.aten.softmax.int %t, %dim, %int4: !torch.tensor<[2,3],f32>, !torch.int, !torch.int -> !torch.tensor
return %ret : !torch.tensor
}
// -----
// CHECK-LABEL: func @torch.aten.Matmul.Broadcast.Matrix(
// CHECK-LABEL: func.func @torch.aten.Matmul.Broadcast.Matrix(
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<*,f32>,
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<*,f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.Matmul.Broadcast.Matrix(%arg0: !torch.vtensor<*,f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
func.func @torch.aten.Matmul.Broadcast.Matrix(%arg0: !torch.vtensor<*,f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<*,f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
return %0 : !torch.tensor
}
// -----
// CHECK-LABEL: func @torch.aten.Matmul.Broadcast.Vector(
// CHECK-LABEL: func.func @torch.aten.Matmul.Broadcast.Vector(
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<*,f32>,
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<*,f32>) -> !torch.tensor {
// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<*,f32>, !torch.vtensor<*,f32> -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.Matmul.Broadcast.Vector(%arg0: !torch.vtensor<*,f32>, %arg1: !torch.vtensor<*,f32>) -> !torch.tensor {
func.func @torch.aten.Matmul.Broadcast.Vector(%arg0: !torch.vtensor<*,f32>, %arg1: !torch.vtensor<*,f32>) -> !torch.tensor {
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<*,f32>, !torch.vtensor<*,f32> -> !torch.tensor
return %0 : !torch.tensor
}
// -----
// CHECK-LABEL: func @torch.aten.to.dtype(
// CHECK-LABEL: func.func @torch.aten.to.dtype(
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor
// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype
// CHECK-SAME: %[[ARG]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} :
@ -294,7 +294,7 @@ func @torch.aten.Matmul.Broadcast.Vector(%arg0: !torch.vtensor<*,f32>, %arg1: !t
// CHECK-SAME: -> !torch.tensor<*,si64>
// CHECK-NEXT: %[[RES:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.tensor<*,si64> to !torch.tensor
// CHECK-NEXT: return %[[RES]] : !torch.tensor
func @torch.aten.to.dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor{
func.func @torch.aten.to.dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor{
%none = torch.constant.none
%false = torch.constant.bool false
%int4 = torch.constant.int 4
@ -303,18 +303,18 @@ func @torch.aten.to.dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor{
}
// -----
// CHECK-LABEL: func @torch.prim.NumToTensor.Scalar(
// CHECK-LABEL: func.func @torch.prim.NumToTensor.Scalar(
// CHECK-SAME: %[[SELF:.*]]: !torch.int) -> !torch.tensor {
// CHECK: %[[NTT:.*]] = torch.prim.NumToTensor.Scalar %[[SELF]] : !torch.int -> !torch.tensor<*,si64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[NTT]] : !torch.tensor<*,si64> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.prim.NumToTensor.Scalar(%arg0: !torch.int) -> !torch.tensor {
func.func @torch.prim.NumToTensor.Scalar(%arg0: !torch.int) -> !torch.tensor {
%0 = torch.prim.NumToTensor.Scalar %arg0: !torch.int -> !torch.tensor
return %0: !torch.tensor
}
// -----
// CHECK-LABEL: func @torch.aten.tensor(
// CHECK-LABEL: func.func @torch.aten.tensor(
// CHECK-SAME: %[[DATA:.*]]: !torch.list<list<float>>) -> !torch.tensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
@ -323,7 +323,7 @@ func @torch.prim.NumToTensor.Scalar(%arg0: !torch.int) -> !torch.tensor {
// CHECK-SAME: -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.tensor(%t: !torch.list<list<float>>) -> !torch.tensor {
func.func @torch.aten.tensor(%t: !torch.list<list<float>>) -> !torch.tensor {
%none = torch.constant.none
%false = torch.constant.bool false
%ret = torch.aten.tensor %t, %none, %none, %false : !torch.list<list<float>>, !torch.none, !torch.none, !torch.bool -> !torch.tensor
@ -331,7 +331,7 @@ func @torch.aten.tensor(%t: !torch.list<list<float>>) -> !torch.tensor {
}
// -----
// CHECK-LABEL: func @torch.aten.tensor$specified_dtype(
// CHECK-LABEL: func.func @torch.aten.tensor$specified_dtype(
// CHECK-SAME: %[[DATA:.*]]: !torch.list<list<float>>) -> !torch.tensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT4:.*]] = torch.constant.int 4
@ -339,7 +339,7 @@ func @torch.aten.tensor(%t: !torch.list<list<float>>) -> !torch.tensor {
// CHECK: %[[RET:.*]] = torch.aten.tensor %[[DATA]], %[[INT4]], %[[NONE]], %[[FALSE]] : !torch.list<list<float>>, !torch.int, !torch.none, !torch.bool -> !torch.tensor<*,si64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.tensor$specified_dtype(%t: !torch.list<list<float>>) -> !torch.tensor {
func.func @torch.aten.tensor$specified_dtype(%t: !torch.list<list<float>>) -> !torch.tensor {
%none = torch.constant.none
%int4 = torch.constant.int 4
%false = torch.constant.bool false

View File

@ -7,35 +7,35 @@
// should go in refine-types-ops.mlir.
// -----
// CHECK-LABEL: func @basic(
// CHECK-LABEL: func.func @basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
// CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[TANH]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: return %[[RESULT]] : !torch.vtensor
func @basic(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
func.func @basic(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
%1 = torch.aten.tanh %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
return %1 : !torch.vtensor
}
// -----
// CHECK-LABEL: func @keep_existing_shape_information(
// CHECK-LABEL: func.func @keep_existing_shape_information(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> {
// CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<[2],f32>
// CHECK: return %[[TANH]] : !torch.vtensor<[2],f32>
func @keep_existing_shape_information(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> {
func.func @keep_existing_shape_information(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> {
%1 = torch.aten.tanh %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor<[2], f32>
return %1 : !torch.vtensor<[2],f32>
}
// -----
// CHECK-LABEL: func @propagate_through_multiple_ops(
// CHECK-LABEL: func.func @propagate_through_multiple_ops(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
// CHECK: %[[TANH0:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
// CHECK: %[[TANH1:.*]] = torch.aten.tanh %[[TANH0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
// CHECK: %[[TANH2:.*]] = torch.aten.tanh %[[TANH1]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
// CHECK: %[[TANH3:.*]] = torch.tensor_static_info_cast %[[TANH2]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: return %[[TANH3]] : !torch.vtensor
func @propagate_through_multiple_ops(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
func.func @propagate_through_multiple_ops(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
%1 = torch.aten.tanh %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
%2 = torch.aten.tanh %1 : !torch.vtensor -> !torch.vtensor
%3 = torch.aten.tanh %2 : !torch.vtensor -> !torch.vtensor
@ -45,108 +45,108 @@ func @propagate_through_multiple_ops(%arg0: !torch.vtensor<*,f32>) -> !torch.vte
// -----
// Check rewriting logic in case of mixes of users that do/don't allow type
// refinement.
// CHECK-LABEL: func @mixed_allowing_not_allowing_type_refinement(
// CHECK-LABEL: func.func @mixed_allowing_not_allowing_type_refinement(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> (!torch.vtensor, !torch.vtensor) {
// CHECK: %[[TANH0:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
// CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[TANH0]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: %[[TANH1:.*]] = torch.aten.tanh %[[TANH0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
// CHECK: return %[[ERASED]], %[[ERASED]] : !torch.vtensor, !torch.vtensor
func @mixed_allowing_not_allowing_type_refinement(%arg0: !torch.vtensor<*,f32>) -> (!torch.vtensor, !torch.vtensor) {
func.func @mixed_allowing_not_allowing_type_refinement(%arg0: !torch.vtensor<*,f32>) -> (!torch.vtensor, !torch.vtensor) {
%1 = torch.aten.tanh %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
%3 = torch.aten.tanh %1 : !torch.vtensor -> !torch.vtensor
return %1, %1 : !torch.vtensor, !torch.vtensor
}
// -----
// CHECK-LABEL: func @type_promotion$same_category_different_width(
// CHECK-LABEL: func.func @type_promotion$same_category_different_width(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],si32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],si64>) -> !torch.vtensor<[?],unk> {
// CHECK: %[[ALPHA:.*]] = torch.constant.int 3
// CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[ARG0]], %[[ARG1]], %[[ALPHA]] : !torch.vtensor<[?],si32>, !torch.vtensor<[?],si64>, !torch.int -> !torch.vtensor<[?],si64>
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[ADD]] : !torch.vtensor<[?],si64> to !torch.vtensor<[?],unk>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?],unk>
func @type_promotion$same_category_different_width(%arg0: !torch.vtensor<[?],si32>, %arg1: !torch.vtensor<[?],si64>) -> !torch.vtensor<[?],unk> {
func.func @type_promotion$same_category_different_width(%arg0: !torch.vtensor<[?],si32>, %arg1: !torch.vtensor<[?],si64>) -> !torch.vtensor<[?],unk> {
%int3 = torch.constant.int 3
%0 = torch.aten.add.Tensor %arg0, %arg1, %int3 : !torch.vtensor<[?],si32>, !torch.vtensor<[?],si64>, !torch.int -> !torch.vtensor<[?],unk>
return %0 : !torch.vtensor<[?],unk>
}
// -----
// CHECK-LABEL: func @type_promotion$different_category(
// CHECK-LABEL: func.func @type_promotion$different_category(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],si64>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],unk> {
// CHECK: %[[ALPHA:.*]] = torch.constant.int 3
// CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[ARG0]], %[[ARG1]], %[[ALPHA]] : !torch.vtensor<[?],si64>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32>
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[ADD]] : !torch.vtensor<[?],f32> to !torch.vtensor<[?],unk>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?],unk>
func @type_promotion$different_category(%arg0: !torch.vtensor<[?],si64>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],unk> {
func.func @type_promotion$different_category(%arg0: !torch.vtensor<[?],si64>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],unk> {
%int3 = torch.constant.int 3
%0 = torch.aten.add.Tensor %arg0, %arg1, %int3 : !torch.vtensor<[?],si64>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],unk>
return %0 : !torch.vtensor<[?],unk>
}
// -----
// CHECK-LABEL: func @type_promotion$same_category_zero_rank_wider(
// CHECK-LABEL: func.func @type_promotion$same_category_zero_rank_wider(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],f64>) -> !torch.vtensor<[?],unk> {
// CHECK: %[[ALPHA:.*]] = torch.constant.float 2.300000e+00
// CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[ARG0]], %[[ARG1]], %[[ALPHA]] : !torch.vtensor<[?],f32>, !torch.vtensor<[],f64>, !torch.float -> !torch.vtensor<[?],f32>
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[ADD]] : !torch.vtensor<[?],f32> to !torch.vtensor<[?],unk>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?],unk>
func @type_promotion$same_category_zero_rank_wider(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[],f64>) -> !torch.vtensor<[?],unk> {
func.func @type_promotion$same_category_zero_rank_wider(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[],f64>) -> !torch.vtensor<[?],unk> {
%float2.300000e00 = torch.constant.float 2.300000e+00
%0 = torch.aten.add.Tensor %arg0, %arg1, %float2.300000e00 : !torch.vtensor<[?],f32>, !torch.vtensor<[],f64>, !torch.float -> !torch.vtensor<[?],unk>
return %0 : !torch.vtensor<[?],unk>
}
// -----
// CHECK-LABEL: func @type_promotion$zero_rank_higher_category(
// CHECK-LABEL: func.func @type_promotion$zero_rank_higher_category(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],si64>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],unk> {
// CHECK: %[[ALPHA:.*]] = torch.constant.int 2
// CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[ARG0]], %[[ARG1]], %[[ALPHA]] : !torch.vtensor<[?],si64>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[?],f32>
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[ADD]] : !torch.vtensor<[?],f32> to !torch.vtensor<[?],unk>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?],unk>
func @type_promotion$zero_rank_higher_category(%arg0: !torch.vtensor<[?],si64>, %arg1: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],unk> {
func.func @type_promotion$zero_rank_higher_category(%arg0: !torch.vtensor<[?],si64>, %arg1: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],unk> {
%int2 = torch.constant.int 2
%0 = torch.aten.add.Tensor %arg0, %arg1, %int2 : !torch.vtensor<[?],si64>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[?],unk>
return %0 : !torch.vtensor<[?],unk>
}
// -----
// CHECK-LABEL: func @type_promotion$alpha_wider(
// CHECK-LABEL: func.func @type_promotion$alpha_wider(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],unk> {
// CHECK: %[[ALPHA:.*]] = torch.constant.float 2.300000e+00
// CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[ARG0]], %[[ARG1]], %[[ALPHA]] : !torch.vtensor<[?],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[?],f32>
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[ADD]] : !torch.vtensor<[?],f32> to !torch.vtensor<[?],unk>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?],unk>
func @type_promotion$alpha_wider(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],unk> {
func.func @type_promotion$alpha_wider(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],unk> {
%float2.300000e00 = torch.constant.float 2.300000e+00
%0 = torch.aten.add.Tensor %arg0, %arg1, %float2.300000e00 : !torch.vtensor<[?],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[?],unk>
return %0 : !torch.vtensor<[?],unk>
}
// -----
// CHECK-LABEL: func @type_promotion_scalar_operation(
// CHECK-LABEL: func.func @type_promotion_scalar_operation(
// CHECK-SAME: %[[FLOAT:.*]]: !torch.float,
// CHECK-SAME: %[[INT:.*]]: !torch.int) -> !torch.number {
// CHECK: %[[ADD:.*]] = torch.aten.add %[[FLOAT]], %[[INT]] : !torch.float, !torch.int -> !torch.float
// CHECK: %[[RET:.*]] = torch.derefine %[[ADD]] : !torch.float to !torch.number
// CHECK: return %[[RET]] : !torch.number
func @type_promotion_scalar_operation(%float: !torch.float, %int: !torch.int) -> !torch.number {
func.func @type_promotion_scalar_operation(%float: !torch.float, %int: !torch.int) -> !torch.number {
%ret = torch.aten.add %float, %int : !torch.float, !torch.int -> !torch.number
return %ret : !torch.number
}
// -----
// CHECK-LABEL: func @torch.overwrite.tensor.contents$dynamic_overwrites_static(
// CHECK-LABEL: func.func @torch.overwrite.tensor.contents$dynamic_overwrites_static(
// CHECK-SAME: %[[STATIC:.*]]: !torch.vtensor<[2],f32>,
// CHECK-SAME: %[[DYNAMIC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[DYNAMIC_COPY:.*]] : !torch.vtensor<[?],f32> to !torch.vtensor<*,f32>
// CHECK: %[[CAST2:.*]] = torch.tensor_static_info_cast %[[CAST:.*]] : !torch.vtensor<*,f32> to !torch.vtensor<*,f32>
// CHECK: torch.overwrite.tensor.contents %[[CAST2]] overwrites %[[STATIC_COPY:.*]] : !torch.vtensor<*,f32>, !torch.tensor<*,f32>
func @torch.overwrite.tensor.contents$dynamic_overwrites_static(%static: !torch.vtensor<[2],f32>, %dynamic: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
func.func @torch.overwrite.tensor.contents$dynamic_overwrites_static(%static: !torch.vtensor<[2],f32>, %dynamic: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
%static_no_type = torch.tensor_static_info_cast %static : !torch.vtensor<[2],f32> to !torch.vtensor
%static_copy = torch.copy.to_tensor %static_no_type : !torch.tensor
%dynamic_no_type = torch.tensor_static_info_cast %dynamic : !torch.vtensor<[?],f32> to !torch.vtensor
@ -157,14 +157,14 @@ func @torch.overwrite.tensor.contents$dynamic_overwrites_static(%static: !torch.
}
// -----
// CHECK-LABEL: func @torch.overwrite.tensor.contents$static_overwrites_dynamic(
// CHECK-LABEL: func.func @torch.overwrite.tensor.contents$static_overwrites_dynamic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
// CHECK: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[2],f32> to !torch.vtensor<*,f32>
// CHECK: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor<*,f32>
// CHECK: %[[MUTABLE_COPY:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor<*,f32>
// CHECK: torch.overwrite.tensor.contents %[[ARG0_ERASED]] overwrites %[[MUTABLE_COPY]] : !torch.vtensor<*,f32>, !torch.tensor<*,f32>
func @torch.overwrite.tensor.contents$static_overwrites_dynamic(%static: !torch.vtensor<[2],f32>, %dynamic: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
func.func @torch.overwrite.tensor.contents$static_overwrites_dynamic(%static: !torch.vtensor<[2],f32>, %dynamic: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
%static_no_type = torch.tensor_static_info_cast %static : !torch.vtensor<[2],f32> to !torch.vtensor
%dynamic_no_type = torch.tensor_static_info_cast %dynamic : !torch.vtensor<[?],f32> to !torch.vtensor
%dynamic_copy = torch.copy.to_tensor %dynamic_no_type : !torch.tensor
@ -175,23 +175,23 @@ func @torch.overwrite.tensor.contents$static_overwrites_dynamic(%static: !torch.
}
// -----
// CHECK-LABEL: func @bf16_result_type(
// CHECK-LABEL: func.func @bf16_result_type(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,bf16>) -> !torch.vtensor<[2],bf16> {
// CHECK: %[[SQRT:.*]] = torch.aten.sqrt %[[ARG0]] : !torch.vtensor<*,bf16> -> !torch.vtensor<[2],bf16>
// CHECK: return %[[SQRT]] : !torch.vtensor<[2],bf16>
func @bf16_result_type(%arg0: !torch.vtensor<*,bf16>) -> !torch.vtensor<[2],bf16> {
func.func @bf16_result_type(%arg0: !torch.vtensor<*,bf16>) -> !torch.vtensor<[2],bf16> {
%1 = torch.aten.sqrt %arg0 : !torch.vtensor<*,bf16> -> !torch.vtensor<[2], bf16>
return %1 : !torch.vtensor<[2],bf16>
}
// -----
// CHECK-LABEL: func @propagate_scalar_type(
// CHECK-LABEL: func.func @propagate_scalar_type(
// CHECK-SAME: %[[INT:.*]]: !torch.int) -> !torch.number {
// CHECK: %[[NUM:.*]] = torch.derefine %[[INT]] : !torch.int to !torch.number
// CHECK: %[[ABS:.*]] = torch.prim.abs.Scalar %[[INT]] : !torch.int -> !torch.int
// CHECK: %[[RET:.*]] = torch.derefine %[[ABS]] : !torch.int to !torch.number
// CHECK: return %[[RET]] : !torch.number
func @propagate_scalar_type(%arg0: !torch.int) -> !torch.number {
func.func @propagate_scalar_type(%arg0: !torch.int) -> !torch.number {
%num = torch.derefine %arg0 : !torch.int to !torch.number
%1 = torch.prim.abs.Scalar %num: !torch.number -> !torch.number
return %1 : !torch.number

View File

@ -1,20 +1,20 @@
// RUN: torch-mlir-opt -torch-reify-shape-calculations -split-input-file %s | FileCheck %s
// CHECK: module {
// CHECK: func private @__torch_mlir_shape_fn.aten.tanh(
// CHECK: func.func private @__torch_mlir_shape_fn.aten.tanh(
// CHECK-LABEL: func @basic(
// CHECK-LABEL: func.func @basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: %[[RESULT:.*]] = torch.shape.calculate {
// CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG]] : !torch.vtensor -> !torch.vtensor
// CHECK: torch.shape.calculate.yield %[[TANH]] : !torch.vtensor
// CHECK: } shapes {
// CHECK: %[[SHAPE:.*]] = torch.aten.size %[[ARG]] : !torch.vtensor -> !torch.list<int>
// CHECK: %[[RESULT_SHAPE:.*]] = call @__torch_mlir_shape_fn.aten.tanh(%[[SHAPE]]) : (!torch.list<int>) -> !torch.list<int>
// CHECK: %[[RESULT_SHAPE:.*]] = func.call @__torch_mlir_shape_fn.aten.tanh(%[[SHAPE]]) : (!torch.list<int>) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[RESULT_SHAPE]] : !torch.list<int>
// CHECK: } : !torch.vtensor
// CHECK: return %[[RESULT:.*]] : !torch.vtensor
func @basic(%arg0: !torch.vtensor) -> !torch.vtensor {
func.func @basic(%arg0: !torch.vtensor) -> !torch.vtensor {
%0 = torch.aten.tanh %arg0 : !torch.vtensor -> !torch.vtensor
return %0 : !torch.vtensor
}
@ -22,9 +22,9 @@ func @basic(%arg0: !torch.vtensor) -> !torch.vtensor {
// -----
// CHECK: module {
// CHECK: func private @__torch_mlir_shape_fn.aten.fill.Scalar(
// CHECK: func.func private @__torch_mlir_shape_fn.aten.fill.Scalar(
// CHECK-LABEL: func @valsem_ops(
// CHECK-LABEL: func.func @valsem_ops(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.int) -> !torch.vtensor {
// CHECK: %[[RESULT:.*]] = torch.shape.calculate {
@ -32,11 +32,11 @@ func @basic(%arg0: !torch.vtensor) -> !torch.vtensor {
// CHECK: torch.shape.calculate.yield %[[VALUE]] : !torch.vtensor
// CHECK: } shapes {
// CHECK: %[[SHAPE:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<int>
// CHECK: %[[RESULT_SHAPE:.*]] = call @__torch_mlir_shape_fn.aten.fill.Scalar(%[[SHAPE]], %{{.*}}) : (!torch.list<int>, !torch.float) -> !torch.list<int>
// CHECK: %[[RESULT_SHAPE:.*]] = func.call @__torch_mlir_shape_fn.aten.fill.Scalar(%[[SHAPE]], %{{.*}}) : (!torch.list<int>, !torch.float) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[RESULT_SHAPE]] : !torch.list<int>
// CHECK: } : !torch.vtensor
// CHECK: return %[[RESULT:.*]] : !torch.vtensor
func @valsem_ops(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
func.func @valsem_ops(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
%0 = torch.valsem.aten.fill.Scalar %arg0, %arg1 : !torch.vtensor, !torch.int -> !torch.vtensor
return %0 : !torch.vtensor
}
@ -44,10 +44,10 @@ func @valsem_ops(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
// -----
// CHECK: module {
// CHECK-LABEL: func private @__torch_mlir_shape_fn.aten.uniform(
// CHECK-LABEL: func.func private @__torch_mlir_shape_fn.aten.uniform(
// CHECK-SAME: {{.*}}!torch.any)
// CHECK-LABEL: func @adjust_shape_function_arg$torch.any(
// CHECK-LABEL: func.func @adjust_shape_function_arg$torch.any(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.float) -> !torch.vtensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
@ -57,11 +57,11 @@ func @valsem_ops(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
// CHECK: } shapes {
// CHECK: %[[ARG0_SHAPE:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<int>
// CHECK: %[[ANY:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.any
// CHECK: %[[SHAPE:.*]] = call @__torch_mlir_shape_fn.aten.uniform(%[[ARG0_SHAPE]], %[[ARG1]], %[[ARG1]], %[[ANY]]) : (!torch.list<int>, !torch.float, !torch.float, !torch.any) -> !torch.list<int>
// CHECK: %[[SHAPE:.*]] = func.call @__torch_mlir_shape_fn.aten.uniform(%[[ARG0_SHAPE]], %[[ARG1]], %[[ARG1]], %[[ANY]]) : (!torch.list<int>, !torch.float, !torch.float, !torch.any) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
// CHECK: } : !torch.vtensor
// CHECK: return %[[RESULT:.*]] : !torch.vtensor
func @adjust_shape_function_arg$torch.any(%arg0: !torch.vtensor, %arg1: !torch.float) -> !torch.vtensor {
func.func @adjust_shape_function_arg$torch.any(%arg0: !torch.vtensor, %arg1: !torch.float) -> !torch.vtensor {
%none = torch.constant.none
%0 = torch.valsem.aten.uniform %arg0, %arg1, %arg1, %none : !torch.vtensor, !torch.float, !torch.float, !torch.none -> !torch.vtensor
return %0 : !torch.vtensor
@ -74,9 +74,9 @@ func @adjust_shape_function_arg$torch.any(%arg0: !torch.vtensor, %arg1: !torch.f
// callees of the shape functions.
// CHECK: module {
// CHECK: func private @__torch_mlir_shape_fn.aten.add.Tensor(
// CHECK: func.func private @__torch_mlir_shape_fn.aten.add.Tensor(
// CHECK-LABEL: func @adjust_shape_function_arg$scalar(
// CHECK-LABEL: func.func @adjust_shape_function_arg$scalar(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
@ -87,11 +87,11 @@ func @adjust_shape_function_arg$torch.any(%arg0: !torch.vtensor, %arg1: !torch.f
// CHECK: %[[ARG0_SHAPE:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<int>
// CHECK: %[[ARG1_SHAPE:.*]] = torch.aten.size %[[ARG1]] : !torch.vtensor -> !torch.list<int>
// CHECK: %[[SCALAR_CONVERTED:.*]] = torch.aten.Float.Scalar %[[INT1]] : !torch.int -> !torch.float
// CHECK: %[[RESULT_SHAPE:.*]] = call @__torch_mlir_shape_fn.aten.add.Tensor(%[[ARG0_SHAPE]], %[[ARG1_SHAPE]], %[[SCALAR_CONVERTED]]) : (!torch.list<int>, !torch.list<int>, !torch.float) -> !torch.list<int>
// CHECK: %[[RESULT_SHAPE:.*]] = func.call @__torch_mlir_shape_fn.aten.add.Tensor(%[[ARG0_SHAPE]], %[[ARG1_SHAPE]], %[[SCALAR_CONVERTED]]) : (!torch.list<int>, !torch.list<int>, !torch.float) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[RESULT_SHAPE]] : !torch.list<int>
// CHECK: } : !torch.vtensor
// CHECK: return %[[RESULT:.*]] : !torch.vtensor
func @adjust_shape_function_arg$scalar(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
func.func @adjust_shape_function_arg$scalar(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
%int1 = torch.constant.int 1
%0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor, !torch.vtensor, !torch.int -> !torch.vtensor
return %0 : !torch.vtensor
@ -100,9 +100,9 @@ func @adjust_shape_function_arg$scalar(%arg0: !torch.vtensor, %arg1: !torch.vten
// -----
// CHECK: module {
// CHECK: func private @__torch_mlir_shape_fn.aten.topk(
// CHECK: func.func private @__torch_mlir_shape_fn.aten.topk(
// CHECK-LABEL: func @multiple_results(
// CHECK-LABEL: func.func @multiple_results(
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> (!torch.tensor, !torch.tensor) {
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[INT3:.*]] = torch.constant.int 3
@ -112,13 +112,13 @@ func @adjust_shape_function_arg$scalar(%arg0: !torch.vtensor, %arg1: !torch.vten
// CHECK: torch.shape.calculate.yield %[[TOP_VALUES]], %[[TOPK_INDICES]] : !torch.tensor, !torch.tensor
// CHECK: } shapes {
// CHECK: %[[ARG_SHAPE:.*]] = torch.aten.size %[[ARG]] : !torch.tensor -> !torch.list<int>
// CHECK: %[[TOPK_SHAPE_TUPLE:.*]] = call @__torch_mlir_shape_fn.aten.topk(%[[ARG_SHAPE]], %[[INT3]], %[[INT1]], %[[TRUE]], %[[TRUE]]) : (!torch.list<int>, !torch.int, !torch.int, !torch.bool, !torch.bool) -> !torch.tuple<list<int>, list<int>>
// CHECK: %[[TOPK_SHAPE_TUPLE:.*]] = func.call @__torch_mlir_shape_fn.aten.topk(%[[ARG_SHAPE]], %[[INT3]], %[[INT1]], %[[TRUE]], %[[TRUE]]) : (!torch.list<int>, !torch.int, !torch.int, !torch.bool, !torch.bool) -> !torch.tuple<list<int>, list<int>>
// CHECK: %[[TOPK_SHAPE:.*]]:2 = torch.prim.TupleUnpack %[[TOPK_SHAPE_TUPLE]] : !torch.tuple<list<int>, list<int>> -> !torch.list<int>, !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[TOPK_SHAPE]]#0, %[[TOPK_SHAPE]]#1 : !torch.list<int>, !torch.list<int>
// CHECK: } : !torch.tensor, !torch.tensor
// CHECK: return %[[RESULTS:.*]]#0, %[[RESULTS]]#1 : !torch.tensor, !torch.tensor
func @multiple_results(%arg0: !torch.tensor) -> (!torch.tensor, !torch.tensor) {
func.func @multiple_results(%arg0: !torch.tensor) -> (!torch.tensor, !torch.tensor) {
%true = torch.constant.bool true
%int3 = torch.constant.int 3
%int1 = torch.constant.int 1
@ -128,7 +128,7 @@ func @multiple_results(%arg0: !torch.tensor) -> (!torch.tensor, !torch.tensor) {
// -----
// CHECK-LABEL: func @adjust_shape_function_arg$optional(
// CHECK-LABEL: func.func @adjust_shape_function_arg$optional(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: %[[RESULT:.*]] = torch.shape.calculate {
@ -138,11 +138,11 @@ func @multiple_results(%arg0: !torch.tensor) -> (!torch.tensor, !torch.tensor) {
// CHECK: %[[SHAPE0:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<int>
// CHECK: %[[SHAPE1:.*]] = torch.aten.size %[[ARG1]] : !torch.vtensor -> !torch.list<int>
// CHECK: %[[DEREFINED:.*]] = torch.derefine %{{.*}} : !torch.none to !torch.optional<list<int>>
// CHECK: %[[SHAPE:.*]] = call @__torch_mlir_shape_fn.aten.conv2d(%[[SHAPE0]], %[[SHAPE1]], %[[DEREFINED]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.list<int>
// CHECK: %[[SHAPE:.*]] = func.call @__torch_mlir_shape_fn.aten.conv2d(%[[SHAPE0]], %[[SHAPE1]], %[[DEREFINED]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
// CHECK: } : !torch.vtensor
// CHECK: return %[[RESULT:.*]] : !torch.vtensor
func @adjust_shape_function_arg$optional(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
func.func @adjust_shape_function_arg$optional(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
%int3 = torch.constant.int 3
%int4 = torch.constant.int 4
%int2 = torch.constant.int 2
@ -157,7 +157,7 @@ func @adjust_shape_function_arg$optional(%arg0: !torch.vtensor, %arg1: !torch.vt
// -----
// CHECK-LABEL: func @adjust_shape_function_arg$optional_tensor(
// CHECK-LABEL: func.func @adjust_shape_function_arg$optional_tensor(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
@ -184,11 +184,11 @@ func @adjust_shape_function_arg$optional(%arg0: !torch.vtensor, %arg1: !torch.vt
// CHECK: %[[DEREFINED_NONE1:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<list<int>>
// CHECK: %[[DEREFINED_NONE2:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<list<int>>
// CHECK: %[[DEREFINED_NONE3:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<list<int>>
// CHECK: %[[BN_SHAPE:.*]] = call @__torch_mlir_shape_fn.aten.batch_norm(%[[ARG_SIZE]], %[[DEREFINED_OPTIONAL_SIZE:.*]], %[[DEREFINED_NONE1]], %[[DEREFINED_NONE2]], %[[DEREFINED_NONE3]], %[[FALSE]], %[[C1EM1]], %[[C1EM5]], %[[TRUE]]) : (!torch.list<int>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.bool, !torch.float, !torch.float, !torch.bool) -> !torch.list<int>
// CHECK: %[[BN_SHAPE:.*]] = func.call @__torch_mlir_shape_fn.aten.batch_norm(%[[ARG_SIZE]], %[[DEREFINED_OPTIONAL_SIZE:.*]], %[[DEREFINED_NONE1]], %[[DEREFINED_NONE2]], %[[DEREFINED_NONE3]], %[[FALSE]], %[[C1EM1]], %[[C1EM5]], %[[TRUE]]) : (!torch.list<int>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.bool, !torch.float, !torch.float, !torch.bool) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[BN_SHAPE]] : !torch.list<int>
// CHECK: } : !torch.vtensor
// CHECK: return %[[RESULT:.*]] : !torch.vtensor
func @adjust_shape_function_arg$optional_tensor(%arg0: !torch.vtensor) -> !torch.vtensor {
func.func @adjust_shape_function_arg$optional_tensor(%arg0: !torch.vtensor) -> !torch.vtensor {
%false = torch.constant.bool false
%true = torch.constant.bool true
%float1.000000e-05 = torch.constant.float 1.000000e-05
@ -201,7 +201,7 @@ func @adjust_shape_function_arg$optional_tensor(%arg0: !torch.vtensor) -> !torch
// -----
// CHECK-LABEL: func @adjust_shape_function_arg$list(
// CHECK-LABEL: func.func @adjust_shape_function_arg$list(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[ARG1]] : (!torch.vtensor) -> !torch.list<vtensor>
@ -221,11 +221,11 @@ func @adjust_shape_function_arg$optional_tensor(%arg0: !torch.vtensor) -> !torch
// CHECK: %{{.*}} = torch.aten.append.t %[[ADJUSTED_LIST]], %[[ADJUSTED_ELEMENT]] : !torch.list<optional<list<int>>>, !torch.optional<list<int>> -> !torch.list<optional<list<int>>>
// CHECK: torch.prim.Loop.condition %[[CTRUE]], iter()
// CHECK: } : (!torch.int, !torch.bool) -> ()
// CHECK: %[[RESULT_SHAPE:.*]] = call @__torch_mlir_shape_fn.aten.index.Tensor(%[[ARG0_SHAPE]], %[[ADJUSTED_LIST]]) : (!torch.list<int>, !torch.list<optional<list<int>>>) -> !torch.list<int>
// CHECK: %[[RESULT_SHAPE:.*]] = func.call @__torch_mlir_shape_fn.aten.index.Tensor(%[[ARG0_SHAPE]], %[[ADJUSTED_LIST]]) : (!torch.list<int>, !torch.list<optional<list<int>>>) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[RESULT_SHAPE]] : !torch.list<int>
// CHECK: } : !torch.vtensor
// CHECK: return %[[VAL_15:.*]] : !torch.vtensor
func @adjust_shape_function_arg$list(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
func.func @adjust_shape_function_arg$list(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
%0 = torch.prim.ListConstruct %arg1 : (!torch.vtensor) -> !torch.list<vtensor>
%1 = torch.aten.index.Tensor %arg0, %0 : !torch.vtensor, !torch.list<vtensor> -> !torch.vtensor
return %1 : !torch.vtensor

View File

@ -1,7 +1,7 @@
// RUN: torch-mlir-opt -torch-simplify-shape-calculations -split-input-file %s | FileCheck %s
// CHECK-LABEL: func @refine_shape_calculate_result$basic(
// CHECK-LABEL: func.func @refine_shape_calculate_result$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.int) -> !torch.vtensor {
// CHECK: %[[INT2:.*]] = torch.constant.int 2
@ -14,7 +14,7 @@
// CHECK: } : !torch.vtensor<[2,?],unk>
// CHECK: %[[RESULT_ERASED:.*]] = torch.tensor_static_info_cast %[[RESULT:.*]] : !torch.vtensor<[2,?],unk> to !torch.vtensor
// CHECK: return %[[RESULT_ERASED]] : !torch.vtensor
func @refine_shape_calculate_result$basic(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
func.func @refine_shape_calculate_result$basic(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
%int2 = torch.constant.int 2
%0 = torch.shape.calculate {
torch.shape.calculate.yield %arg0 : !torch.vtensor
@ -25,10 +25,10 @@ func @refine_shape_calculate_result$basic(%arg0: !torch.vtensor, %arg1: !torch.i
return %0 : !torch.vtensor
}
// CHECK-LABEL: func @refine_shape_calculate_result$clobber_one_element(
// CHECK-LABEL: func.func @refine_shape_calculate_result$clobber_one_element(
// CHECK: %[[RESULT_ERASED:.*]] = torch.tensor_static_info_cast %{{.*}} : !torch.vtensor<[?,2],unk> to !torch.vtensor
// CHECK: return %[[RESULT_ERASED]] : !torch.vtensor
func @refine_shape_calculate_result$clobber_one_element(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.vtensor {
func.func @refine_shape_calculate_result$clobber_one_element(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.vtensor {
%int0 = torch.constant.int 0
%int2 = torch.constant.int 2
%0 = torch.shape.calculate {
@ -47,10 +47,10 @@ func @refine_shape_calculate_result$clobber_one_element(%arg0: !torch.vtensor, %
return %0 : !torch.vtensor
}
// CHECK-LABEL: func @refine_shape_calculate_result$clobber_all_elements(
// CHECK-LABEL: func.func @refine_shape_calculate_result$clobber_all_elements(
// CHECK: %[[RESULT_ERASED:.*]] = torch.tensor_static_info_cast %{{.*}} : !torch.vtensor<[?,?],unk> to !torch.vtensor
// CHECK: return %[[RESULT_ERASED]] : !torch.vtensor
func @refine_shape_calculate_result$clobber_all_elements(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.vtensor {
func.func @refine_shape_calculate_result$clobber_all_elements(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.vtensor {
%int0 = torch.constant.int 0
%int2 = torch.constant.int 2
%0 = torch.shape.calculate {
@ -71,10 +71,10 @@ func @refine_shape_calculate_result$clobber_all_elements(%arg0: !torch.vtensor,
}
// Make sure that information previously in the IR is not lost.
// CHECK-LABEL: func @refine_shape_calculate_result$meet_with_existing_information(
// CHECK-LABEL: func.func @refine_shape_calculate_result$meet_with_existing_information(
// CHECK: %[[RESULT_ERASED:.*]] = torch.tensor_static_info_cast %{{.*}} : !torch.vtensor<[2,3],f32> to !torch.vtensor<[?,3],f32>
// CHECK: return %[[RESULT_ERASED]] : !torch.vtensor<[?,3],f32>
func @refine_shape_calculate_result$meet_with_existing_information(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.int) -> !torch.vtensor<[?,3],f32> {
func.func @refine_shape_calculate_result$meet_with_existing_information(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.int) -> !torch.vtensor<[?,3],f32> {
%int0 = torch.constant.int 0
%int2 = torch.constant.int 2
%0 = torch.shape.calculate {
@ -87,9 +87,9 @@ func @refine_shape_calculate_result$meet_with_existing_information(%arg0: !torch
}
// Don't insert static info casts if not needed.
// CHECK-LABEL: func @refine_shape_calculate_result$user_allows_type_refinement(
// CHECK-LABEL: func.func @refine_shape_calculate_result$user_allows_type_refinement(
// CHECK-NOT: torch.tensor_static_info_cast
func @refine_shape_calculate_result$user_allows_type_refinement(%arg0: !torch.vtensor) -> !torch.vtensor {
func.func @refine_shape_calculate_result$user_allows_type_refinement(%arg0: !torch.vtensor) -> !torch.vtensor {
%int2 = torch.constant.int 2
%0 = torch.aten.tanh %arg0 : !torch.vtensor -> !torch.vtensor
%1 = torch.shape.calculate {
@ -102,7 +102,7 @@ func @refine_shape_calculate_result$user_allows_type_refinement(%arg0: !torch.vt
return %2 : !torch.vtensor
}
// CHECK-LABEL: func @fully_unroll_prim_loop$unroll(
// CHECK-LABEL: func.func @fully_unroll_prim_loop$unroll(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.list<int>) -> !torch.vtensor {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
@ -117,7 +117,7 @@ func @refine_shape_calculate_result$user_allows_type_refinement(%arg0: !torch.vt
// CHECK: torch.shape.calculate.yield.shapes %[[ARG1]] : !torch.list<int>
// CHECK: } : !torch.vtensor
// CHECK: return %[[RESULT:.*]] : !torch.vtensor
func @fully_unroll_prim_loop$unroll(%arg0: !torch.vtensor, %arg1: !torch.list<int>) -> !torch.vtensor {
func.func @fully_unroll_prim_loop$unroll(%arg0: !torch.vtensor, %arg1: !torch.list<int>) -> !torch.vtensor {
%true = torch.constant.bool true
%int0 = torch.constant.int 0
%int3 = torch.constant.int 3
@ -134,9 +134,9 @@ func @fully_unroll_prim_loop$unroll(%arg0: !torch.vtensor, %arg1: !torch.list<in
return %0 : !torch.vtensor
}
// CHECK-LABEL: func @fully_unroll_prim_loop$no_unroll(
// CHECK-LABEL: func.func @fully_unroll_prim_loop$no_unroll(
// CHECK: torch.prim.Loop
func @fully_unroll_prim_loop$no_unroll(%arg0: !torch.vtensor, %arg1: !torch.list<int>, %arg2: !torch.int) -> !torch.vtensor {
func.func @fully_unroll_prim_loop$no_unroll(%arg0: !torch.vtensor, %arg1: !torch.list<int>, %arg2: !torch.int) -> !torch.vtensor {
%true = torch.constant.bool true
%int3 = torch.constant.int 3
%0 = torch.shape.calculate {
@ -152,13 +152,13 @@ func @fully_unroll_prim_loop$no_unroll(%arg0: !torch.vtensor, %arg1: !torch.list
return %0 : !torch.vtensor
}
// CHECK-LABEL: func @abstractly_interpret_list_ops$basic(
// CHECK-LABEL: func.func @abstractly_interpret_list_ops$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.int,
// CHECK-SAME: %[[ARG2:.*]]: !torch.int) -> !torch.vtensor {
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[ARG1]], %[[ARG2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
func @abstractly_interpret_list_ops$basic(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.int) -> !torch.vtensor {
func.func @abstractly_interpret_list_ops$basic(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.int) -> !torch.vtensor {
%0 = torch.shape.calculate {
torch.shape.calculate.yield %arg0 : !torch.vtensor
} shapes {
@ -171,10 +171,10 @@ func @abstractly_interpret_list_ops$basic(%arg0: !torch.vtensor, %arg1: !torch.i
}
// Test the different supported mutation ops.
// CHECK-LABEL: func @abstractly_interpret_list_ops$mutation_ops(
// CHECK-LABEL: func.func @abstractly_interpret_list_ops$mutation_ops(
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %int1, %arg1, %arg2, %arg3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
func @abstractly_interpret_list_ops$mutation_ops(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.vtensor {
func.func @abstractly_interpret_list_ops$mutation_ops(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.vtensor {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
@ -192,10 +192,10 @@ func @abstractly_interpret_list_ops$mutation_ops(%arg0: !torch.vtensor, %arg1: !
}
// Test negative indexes with set_item op.
// CHECK-LABEL: func @abstractly_interpret_list_ops$neg_index_set_item(
// CHECK-LABEL: func.func @abstractly_interpret_list_ops$neg_index_set_item(
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %arg1, %arg2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
func @abstractly_interpret_list_ops$neg_index_set_item(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.vtensor {
func.func @abstractly_interpret_list_ops$neg_index_set_item(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.vtensor {
%int1 = torch.constant.int 1
%int-1 = torch.constant.int -1
%int-2 = torch.constant.int -2
@ -211,10 +211,10 @@ func @abstractly_interpret_list_ops$neg_index_set_item(%arg0: !torch.vtensor, %a
}
// Test interspersed mutation and evaluation ops.
// CHECK-LABEL: func @abstractly_interpret_list_ops$mix_mutation_and_evaluation_ops(
// CHECK-LABEL: func.func @abstractly_interpret_list_ops$mix_mutation_and_evaluation_ops(
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %int0, %int1, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
func @abstractly_interpret_list_ops$mix_mutation_and_evaluation_ops(%arg0: !torch.vtensor) -> !torch.vtensor {
func.func @abstractly_interpret_list_ops$mix_mutation_and_evaluation_ops(%arg0: !torch.vtensor) -> !torch.vtensor {
%0 = torch.shape.calculate {
torch.shape.calculate.yield %arg0 : !torch.vtensor
} shapes {
@ -230,10 +230,10 @@ func @abstractly_interpret_list_ops$mix_mutation_and_evaluation_ops(%arg0: !torc
return %0 : !torch.vtensor
}
// CHECK-LABEL: func @abstractly_interpret_list_ops$use_of_alias$not_yet_handled(
// CHECK-LABEL: func.func @abstractly_interpret_list_ops$use_of_alias$not_yet_handled(
// CHECK: torch.aten.append.t
// CHECK: torch.aten.append.t
func @abstractly_interpret_list_ops$use_of_alias$not_yet_handled(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.int) -> !torch.vtensor {
func.func @abstractly_interpret_list_ops$use_of_alias$not_yet_handled(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.int) -> !torch.vtensor {
%0 = torch.shape.calculate {
torch.shape.calculate.yield %arg0 : !torch.vtensor
} shapes {
@ -247,13 +247,13 @@ func @abstractly_interpret_list_ops$use_of_alias$not_yet_handled(%arg0: !torch.v
return %0 : !torch.vtensor
}
// CHECK-LABEL: func @abstractly_interpret_list_ops$readonly_op_in_child_region(
// CHECK-LABEL: func.func @abstractly_interpret_list_ops$readonly_op_in_child_region(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.int) -> !torch.vtensor {
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[INT3]] : (!torch.int) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
func @abstractly_interpret_list_ops$readonly_op_in_child_region(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
func.func @abstractly_interpret_list_ops$readonly_op_in_child_region(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
%true = torch.constant.bool true
%int3 = torch.constant.int 3
%int0 = torch.constant.int 0
@ -276,9 +276,9 @@ func @abstractly_interpret_list_ops$readonly_op_in_child_region(%arg0: !torch.vt
}
// The mutation in the child region prevents us from abstractly interpreting.
// CHECK-LABEL: func @abstractly_interpret_list_ops$mutation_in_child_region(
// CHECK-LABEL: func.func @abstractly_interpret_list_ops$mutation_in_child_region(
// CHECK: torch.aten.append.t
func @abstractly_interpret_list_ops$mutation_in_child_region(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
func.func @abstractly_interpret_list_ops$mutation_in_child_region(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
%true = torch.constant.bool true
%int3 = torch.constant.int 3
%int0 = torch.constant.int 0
@ -300,7 +300,7 @@ func @abstractly_interpret_list_ops$mutation_in_child_region(%arg0: !torch.vtens
return %0 : !torch.vtensor
}
// CHECK-LABEL: func @abstractly_interpret_list_ops$miscompile$list_identity(
// CHECK-LABEL: func.func @abstractly_interpret_list_ops$miscompile$list_identity(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.list<int>,
// CHECK-SAME: %[[ARG2:.*]]: !torch.bool) -> !torch.vtensor {
@ -329,7 +329,7 @@ func @abstractly_interpret_list_ops$mutation_in_child_region(%arg0: !torch.vtens
// CHECK: } : !torch.vtensor<[3,3],unk>
// CHECK: %[[VAL_13:.*]] = torch.tensor_static_info_cast %[[VAL_14:.*]] : !torch.vtensor<[3,3],unk> to !torch.vtensor
// CHECK: return %[[VAL_13]] : !torch.vtensor
func @abstractly_interpret_list_ops$miscompile$list_identity(%arg0: !torch.vtensor, %arg1: !torch.list<int>, %arg2: !torch.bool) -> !torch.vtensor {
func.func @abstractly_interpret_list_ops$miscompile$list_identity(%arg0: !torch.vtensor, %arg1: !torch.list<int>, %arg2: !torch.bool) -> !torch.vtensor {
%true = torch.constant.bool true
%int3 = torch.constant.int 3
%int0 = torch.constant.int 0
@ -373,7 +373,7 @@ func @abstractly_interpret_list_ops$miscompile$list_identity(%arg0: !torch.vtens
// This test should usually not be the one to catch an issue.
// If it does catch an issue then it indicates a more precise unit test that is
// missing.
// CHECK-LABEL: func @basic_integration(
// CHECK-LABEL: func.func @basic_integration(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],unk>) -> !torch.vtensor {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT1:.*]] = torch.constant.int 1
@ -388,7 +388,7 @@ func @abstractly_interpret_list_ops$miscompile$list_identity(%arg0: !torch.vtens
// CHECK: } : !torch.vtensor<[?,?],unk>
// CHECK: %[[RESULT_ERASED:.*]] = torch.tensor_static_info_cast %[[RESULT:.*]] : !torch.vtensor<[?,?],unk> to !torch.vtensor
// CHECK: return %[[RESULT_ERASED]] : !torch.vtensor
func @basic_integration(%arg0: !torch.vtensor<[?,?],unk>) -> !torch.vtensor {
func.func @basic_integration(%arg0: !torch.vtensor<[?,?],unk>) -> !torch.vtensor {
%true = torch.constant.bool true
%0 = torch.shape.calculate {
%1 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?],unk> -> !torch.vtensor

View File

@ -3,10 +3,10 @@
// This test is largely copied from `finalizing-bufferize` upstream, as it
// covers the same scope.
// CHECK-LABEL: func @eliminate_materializations(
// CHECK-LABEL: func.func @eliminate_materializations(
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>) -> tensor<f32> {
// CHECK: return %[[ARG]] : tensor<f32>
func @eliminate_materializations(%arg0: tensor<f32>) -> tensor<f32> {
func.func @eliminate_materializations(%arg0: tensor<f32>) -> tensor<f32> {
%0 = torch_c.from_builtin_tensor %arg0 : tensor<f32> -> !torch.vtensor<[],f32>
%1 = torch_c.to_builtin_tensor %0 : !torch.vtensor<[],f32> -> tensor<f32>
return %1 : tensor<f32>
@ -15,38 +15,38 @@ func @eliminate_materializations(%arg0: tensor<f32>) -> tensor<f32> {
// Do a basic check of other types. Under the hood they all take the same
// code paths as for !torch.vtensor, so we just spot-check them here.
// CHECK-LABEL: func @eliminate_materializations$torch.bool(
// CHECK-LABEL: func.func @eliminate_materializations$torch.bool(
// CHECK-SAME: %[[ARG:.*]]: i1) -> i1 {
// CHECK: return %[[ARG]] : i1
func @eliminate_materializations$torch.bool(%arg0: i1) -> i1 {
func.func @eliminate_materializations$torch.bool(%arg0: i1) -> i1 {
%0 = torch_c.from_i1 %arg0
%1 = torch_c.to_i1 %0
return %1 : i1
}
// CHECK-LABEL: func @eliminate_materializations$torch.int(
// CHECK-LABEL: func.func @eliminate_materializations$torch.int(
// CHECK-SAME: %[[ARG:.*]]: i64) -> i64 {
// CHECK: return %[[ARG]] : i64
func @eliminate_materializations$torch.int(%arg0: i64) -> i64 {
func.func @eliminate_materializations$torch.int(%arg0: i64) -> i64 {
%0 = torch_c.from_i64 %arg0
%1 = torch_c.to_i64 %0
return %1 : i64
}
// CHECK-LABEL: func @eliminate_materializations$torch.float(
// CHECK-LABEL: func.func @eliminate_materializations$torch.float(
// CHECK-SAME: %[[ARG:.*]]: f64) -> f64 {
// CHECK: return %[[ARG]] : f64
func @eliminate_materializations$torch.float(%arg0: f64) -> f64 {
func.func @eliminate_materializations$torch.float(%arg0: f64) -> f64 {
%0 = torch_c.from_f64 %arg0
%1 = torch_c.to_f64 %0
return %1 : f64
}
// CHECK-LABEL: func @eliminate_materializations$torch.Generator(
// CHECK-LABEL: func.func @eliminate_materializations$torch.Generator(
// CHECK-SAME: %[[VAL_0:.*]]: i64) -> i64 {
// CHECK: return %[[VAL_0]] : i64
// CHECK: }
func @eliminate_materializations$torch.Generator(%arg0: i64) -> i64 {
func.func @eliminate_materializations$torch.Generator(%arg0: i64) -> i64 {
%0 = torch_c.i64_to_generator %arg0
%1 = torch_c.generator_to_i64 %0
return %1 : i64
@ -54,7 +54,7 @@ func @eliminate_materializations$torch.Generator(%arg0: i64) -> i64 {
// -----
func @unable_to_convert_lone_buffer_cast() -> tensor<f32> {
func.func @unable_to_convert_lone_buffer_cast() -> tensor<f32> {
// expected-error @+1 {{failed to legalize operation 'test.source'}}
%0 = "test.source"() : () -> !torch.vtensor<[],f32>
%1 = torch_c.to_builtin_tensor %0 : !torch.vtensor<[],f32> -> tensor<f32>
@ -63,7 +63,7 @@ func @unable_to_convert_lone_buffer_cast() -> tensor<f32> {
// -----
func @unable_to_convert_lone_tensor_load(%arg0: tensor<f32>) {
func.func @unable_to_convert_lone_tensor_load(%arg0: tensor<f32>) {
%0 = torch_c.from_builtin_tensor %arg0 : tensor<f32> -> !torch.vtensor<[],f32>
// expected-error @+1 {{failed to legalize operation 'test.sink'}}
"test.sink"(%0) : (!torch.vtensor<[],f32>) -> ()

View File

@ -3,48 +3,48 @@
// This test is largely copied from `func-bufferize` upstream, as it covers
// the same scope.
// CHECK-LABEL: func @identity(
// CHECK-LABEL: func.func @identity(
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>) -> tensor<f32> {
// CHECK: return %[[ARG]] : tensor<f32>
func @identity(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
func.func @identity(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
return %arg0 : !torch.vtensor<[],f32>
}
// CHECK-LABEL: func @block_arguments(
// CHECK-LABEL: func.func @block_arguments(
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>) -> tensor<f32> {
// CHECK: cf.br ^bb1(%[[ARG]] : tensor<f32>)
// CHECK: ^bb1(%[[BBARG:.*]]: tensor<f32>):
// CHECK: return %[[BBARG]] : tensor<f32>
func @block_arguments(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
func.func @block_arguments(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
cf.br ^bb1(%arg0: !torch.vtensor<[],f32>)
^bb1(%bbarg: !torch.vtensor<[],f32>):
return %bbarg : !torch.vtensor<[],f32>
}
// CHECK-LABEL: func private @source() -> tensor<f32>
// CHECK-LABEL: func @call_source() -> tensor<f32> {
// CHECK-LABEL: func.func private @source() -> tensor<f32>
// CHECK-LABEL: func.func @call_source() -> tensor<f32> {
// CHECK: %[[RET:.*]] = call @source() : () -> tensor<f32>
// CHECK: return %[[RET]] : tensor<f32>
func private @source() -> !torch.vtensor<[],f32>
func @call_source() -> !torch.vtensor<[],f32> {
func.func private @source() -> !torch.vtensor<[],f32>
func.func @call_source() -> !torch.vtensor<[],f32> {
%0 = call @source() : () -> !torch.vtensor<[],f32>
return %0 : !torch.vtensor<[],f32>
}
// CHECK-LABEL: func @call_sink(
// CHECK-LABEL: func.func @call_sink(
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>) {
// CHECK: call @sink(%[[ARG]]) : (tensor<f32>) -> ()
// CHECK: return
func private @sink(!torch.vtensor<[],f32>)
func @call_sink(%arg0: !torch.vtensor<[],f32>) {
func.func private @sink(!torch.vtensor<[],f32>)
func.func @call_sink(%arg0: !torch.vtensor<[],f32>) {
call @sink(%arg0) : (!torch.vtensor<[],f32>) -> ()
return
}
// CHECK-LABEL: func @unconverted_op_in_body() -> tensor<f32> {
// CHECK-LABEL: func.func @unconverted_op_in_body() -> tensor<f32> {
// CHECK: %[[TENSOR:.*]] = "test.source"() : () -> !torch.vtensor<[],f32>
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: return %[[BUILTIN_TENSOR]] : tensor<f32>
func @unconverted_op_in_body() -> !torch.vtensor<[],f32> {
func.func @unconverted_op_in_body() -> !torch.vtensor<[],f32> {
%0 = "test.source"() : () -> !torch.vtensor<[],f32>
return %0 : !torch.vtensor<[],f32>
}
@ -53,7 +53,7 @@ func @unconverted_op_in_body() -> !torch.vtensor<[],f32> {
// Because this pass updates block arguments, it needs to also atomically
// update all terminators and issue an error if that is not possible.
func @unable_to_update_terminator(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
func.func @unable_to_update_terminator(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
%0 = arith.constant true
cf.cond_br %0, ^bb1(%arg0: !torch.vtensor<[],f32>), ^bb2(%arg0: !torch.vtensor<[],f32>)
^bb1(%bbarg0: !torch.vtensor<[],f32>):
@ -72,7 +72,7 @@ func @unable_to_update_terminator(%arg0: !torch.vtensor<[],f32>) -> !torch.vtens
// CHECK: while
// CHECK: scf.while
// CHECK: scf.condition
func @bwhile(%arg0: i64, %arg1: i64) -> i64 {
func.func @bwhile(%arg0: i64, %arg1: i64) -> i64 {
%c2_i64 = arith.constant 2 : i64
%0:2 = scf.while (%arg2 = %arg0) : (i64) -> (i64, i64) {
%1 = arith.cmpi slt, %arg2, %arg1 : i64
@ -88,30 +88,30 @@ func @bwhile(%arg0: i64, %arg1: i64) -> i64 {
// Do a basic check of other types. Under the hood they all take the same
// code paths as for !torch.vtensor, so we just spot-check them here.
// CHECK-LABEL: func @identity$torch.bool(
// CHECK-LABEL: func.func @identity$torch.bool(
// CHECK-SAME: %[[ARG:.*]]: i1) -> i1 {
// CHECK: return %[[ARG]] : i1
func @identity$torch.bool(%arg0: !torch.bool) -> !torch.bool {
func.func @identity$torch.bool(%arg0: !torch.bool) -> !torch.bool {
return %arg0 : !torch.bool
}
// CHECK-LABEL: func @identity$torch.int(
// CHECK-LABEL: func.func @identity$torch.int(
// CHECK-SAME: %[[ARG:.*]]: i64) -> i64 {
// CHECK: return %[[ARG]] : i64
func @identity$torch.int(%arg0: !torch.int) -> !torch.int {
func.func @identity$torch.int(%arg0: !torch.int) -> !torch.int {
return %arg0 : !torch.int
}
// CHECK-LABEL: func @identity$torch.float(
// CHECK-LABEL: func.func @identity$torch.float(
// CHECK-SAME: %[[ARG:.*]]: f64) -> f64 {
// CHECK: return %[[ARG]] : f64
func @identity$torch.float(%arg0: !torch.float) -> !torch.float {
func.func @identity$torch.float(%arg0: !torch.float) -> !torch.float {
return %arg0 : !torch.float
}
// CHECK-LABEL: func @identity$torch.Generator(
// CHECK-LABEL: func.func @identity$torch.Generator(
// CHECK-SAME: %[[VAL_0:.*]]: i64) -> i64 {
// CHECK: return %[[VAL_0]] : i64
func @identity$torch.Generator(%arg0: !torch.Generator) -> !torch.Generator {
func.func @identity$torch.Generator(%arg0: !torch.Generator) -> !torch.Generator {
return %arg0 : !torch.Generator
}

View File

@ -1,7 +1,7 @@
// RUN: torch-mlir-opt %s | torch-mlir-opt | FileCheck %s
// CHECK-LABEL: func @builtin_tensor_interop(
func @builtin_tensor_interop(%arg0: tensor<*xf32>, %arg1: tensor<3x?xi8>, %arg2: !torch.vtensor<*,f32>, %arg3: !torch.vtensor<[3,?],si8>) {
// CHECK-LABEL: func.func @builtin_tensor_interop(
func.func @builtin_tensor_interop(%arg0: tensor<*xf32>, %arg1: tensor<3x?xi8>, %arg2: !torch.vtensor<*,f32>, %arg3: !torch.vtensor<[3,?],si8>) {
// CHECK: torch_c.from_builtin_tensor %arg0 : tensor<*xf32> -> !torch.vtensor<*,f32>
%0 = torch_c.from_builtin_tensor %arg0 : tensor<*xf32> -> !torch.vtensor<*,f32>
// CHECK: torch_c.from_builtin_tensor %arg1 : tensor<3x?xi8> -> !torch.vtensor<[3,?],si8>

View File

@ -2,7 +2,7 @@
// -----
func @unknown_rank(%arg0: !torch.vtensor<[],f32>) {
func.func @unknown_rank(%arg0: !torch.vtensor<[],f32>) {
// expected-error@+2 {{unsupported by backend lowering: tensor with unknown rank or dtype}}
// expected-note@+1 {{this is likely due to a missing shape transfer function in shape_lib_gen.py}}
%0 = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<*,f32>
@ -11,7 +11,7 @@ func @unknown_rank(%arg0: !torch.vtensor<[],f32>) {
// -----
func @unknown_dtype(%arg0: !torch.vtensor<[],f32>) {
func.func @unknown_dtype(%arg0: !torch.vtensor<[],f32>) {
// expected-error@+2 {{unsupported by backend lowering: tensor with unknown rank or dtype}}
// expected-note@+1 {{this is likely due to a missing shape transfer function in shape_lib_gen.py}}
%0 = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[],unk>
@ -20,7 +20,7 @@ func @unknown_dtype(%arg0: !torch.vtensor<[],f32>) {
// -----
func @unresolved_operator(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.int) {
func.func @unresolved_operator(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.int) {
// expected-error@+2 {{unsupported by backend lowering: `torch.operator` op}}
// expected-note@+1 {{this is likely due to a missing op that needs to be generated by torch_ods_gen.py}}
torch.operator "aten.mul.Scalar"(%arg0, %arg1) : (!torch.vtensor<[],f32>, !torch.int) -> !torch.vtensor<[],f32>

View File

@ -1,7 +1,7 @@
// RUN: torch-mlir-opt -torch-verify-linalg-on-tensors-backend-contract -split-input-file -verify-diagnostics -allow-unregistered-dialect %s | FileCheck %s
// CHECK: func @mm
func @mm(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: func.func @mm
func.func @mm(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%cst = arith.constant 0.000000e+00 : f32
@ -23,7 +23,7 @@ func @mm(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
// expected-error@+1 {{Module does not conform to the linalg-on-tensors backend contract.}}
module {
func @disallowed() {
func.func @disallowed() {
// expected-error@+1 {{failed to legalize operation 'unknown_dialect.unknown_op'}}
"unknown_dialect.unknown_op"() : () -> ()
return
@ -46,7 +46,7 @@ module {
// expected-error@+1 {{Module does not conform to the linalg-on-tensors backend contract.}}
module {
func @disallowed(%arg0: !torch.tensor) -> !torch.tensor {
func.func @disallowed(%arg0: !torch.tensor) -> !torch.tensor {
// expected-error@+1 {{failed to legalize operation 'func.return'}}
return %arg0 : !torch.tensor
}

View File

@ -1,7 +1,7 @@
// RUN: torch-mlir-opt -torch-verify-tosa-backend-contract -split-input-file -verify-diagnostics -allow-unregistered-dialect %s | FileCheck %s
// CHECK: func @tanh
func @tanh(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: func.func @tanh
func.func @tanh(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = "tosa.tanh"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
@ -12,7 +12,7 @@ func @tanh(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
// expected-error@+1 {{Module does not conform to the TOSA backend contract.}}
module {
func @disallowed() {
func.func @disallowed() {
// expected-error@+1 {{failed to legalize operation 'unknown_dialect.unknown_op'}}
"unknown_dialect.unknown_op"() : () -> ()
return
@ -35,7 +35,7 @@ module {
// expected-error@+1 {{Module does not conform to the TOSA backend contract.}}
module {
func @disallowed(%arg0: !torch.tensor) -> !torch.tensor {
func.func @disallowed(%arg0: !torch.tensor) -> !torch.tensor {
// expected-error@+1 {{failed to legalize operation 'func.return'}}
return %arg0 : !torch.tensor
}

View File

@ -1,7 +1,7 @@
// RUN: torch-mlir-opt %s -refback-insert-rng-globals -split-input-file | FileCheck %s
// CHECK-LABEL: memref.global "private" @global_seed : memref<i64> = dense<0>
// CHECK-LABEL: func @f() -> i64 {
// CHECK-LABEL: func.func @f() -> i64 {
// CHECK: %[[MEMREF:.*]] = memref.get_global @global_seed : memref<i64>
// CHECK: %[[SEED:.*]] = memref.load %[[MEMREF]][] : memref<i64>
// CHECK: %[[MULTIPLIER:.*]] = arith.constant 6364136223846793005 : i64
@ -13,7 +13,7 @@
// CHECK: memref.store %[[NEXT_SEED]], %[[MEMREF]][] : memref<i64>
// CHECK: return %[[NEXT_SEED]] : i64
module {
func @f() -> i64 {
func.func @f() -> i64 {
%seed = torch_c.get_next_seed : () -> i64
return %seed : i64
}

View File

@ -1,43 +1,43 @@
// RUN: torch-mlir-opt %s -refback-munge-calling-conventions -split-input-file | FileCheck %s
// CHECK-LABEL: func @f(
// CHECK-LABEL: func.func @f(
// CHECK-SAME: %[[ARG0:.*]]: memref<*xf32>) attributes {llvm.emit_c_interface} {
// CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xf32> to memref<?xf32>
// CHECK: %[[RESULT:.*]] = memref.cast %[[VAL]] : memref<?xf32> to memref<*xf32>
// CHECK: call @refbackend_consume_func_return_mrf32(%[[RESULT]]) : (memref<*xf32>) -> ()
// CHECK: return
func @f(%arg0: memref<?xf32>) -> memref<?xf32> {
func.func @f(%arg0: memref<?xf32>) -> memref<?xf32> {
return %arg0 : memref<?xf32>
}
// -----
// CHECK-LABEL: func @i(
// CHECK-LABEL: func.func @i(
// CHECK-SAME: %[[ARG0:.*]]: memref<*xi64>) attributes {llvm.emit_c_interface} {
// CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xi64> to memref<?xi64>
// CHECK: %[[RESULT:.*]] = memref.cast %[[VAL]] : memref<?xi64> to memref<*xi64>
// CHECK: call @refbackend_consume_func_return_mri64(%[[RESULT]]) : (memref<*xi64>) -> ()
// CHECK: return
func @i(%arg0: memref<?xi64>) -> memref<?xi64> {
func.func @i(%arg0: memref<?xi64>) -> memref<?xi64> {
return %arg0 : memref<?xi64>
}
// -----
// CHECK-LABEL: func @elemental_type(
// CHECK-LABEL: func.func @elemental_type(
// CHECK-SAME: %[[ARG0:.*]]: memref<*xi64>) attributes {llvm.emit_c_interface} {
// CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xi64> to memref<i64>
// CHECK: %[[RESULT:.*]] = memref.load %[[VAL]][] : memref<i64>
// CHECK: call @refbackend_consume_func_return_i64(%[[RESULT]]) : (i64) -> ()
// CHECK: return
func @elemental_type(%arg0: memref<i64>) -> i64 {
func.func @elemental_type(%arg0: memref<i64>) -> i64 {
%0 = memref.load %arg0[] : memref<i64>
return %0 : i64
}
// -----
// CHECK-LABEL: func @multiple_return_values(
// CHECK-LABEL: func.func @multiple_return_values(
// CHECK-SAME: %[[ARG0:.*]]: memref<*xf32>, %[[ARG1:.*]]: memref<*xf32>,
// CHECK-SAME: %[[ARG2:.*]]: memref<*xf32>) attributes {llvm.emit_c_interface} {
// CHECK: %[[VAL0:.*]] = memref.cast %[[ARG0]] : memref<*xf32> to memref<?xf32>
@ -50,13 +50,13 @@ func @elemental_type(%arg0: memref<i64>) -> i64 {
// CHECK-SAME: : (memref<*xf32>, memref<*xf32>, memref<*xf32>) -> ()
// CHECK: return
func @multiple_return_values(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) -> (memref<?xf32>, memref<?xf32>, memref<?xf32>) {
func.func @multiple_return_values(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) -> (memref<?xf32>, memref<?xf32>, memref<?xf32>) {
return %arg0 ,%arg1, %arg2 : memref<?xf32>, memref<?xf32>, memref<?xf32>
}
// -----
// CHECK-LABEL: func @two_return_values(
// CHECK-LABEL: func.func @two_return_values(
// CHECK-SAME: %[[ARG0:.*]]: memref<*xf32>, %[[ARG1:.*]]: memref<*xi64>)
// CHECK-SAME: attributes {llvm.emit_c_interface} {
// CHECK: %[[VAL0:.*]] = memref.cast %[[ARG0]] : memref<*xf32> to memref<?xf32>
@ -67,6 +67,6 @@ func @multiple_return_values(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2:
// CHECK-SAME: : (memref<*xf32>, memref<*xi64>) -> ()
// CHECK: return
func @two_return_values(%arg0: memref<?xf32>, %arg1: memref<?xi64>) -> (memref<?xf32>, memref<?xi64>) {
func.func @two_return_values(%arg0: memref<?xf32>, %arg1: memref<?xi64>) -> (memref<?xf32>, memref<?xi64>) {
return %arg0 ,%arg1 : memref<?xf32>, memref<?xi64>
}

View File

@ -21,7 +21,7 @@ recursivescriptmodule = torch.jit.script(test_module)
annotator = ClassAnnotator()
class_type = recursivescriptmodule._c._type()
# CHECK: func private @__torch__.TestModule.forward(
# CHECK: func.func private @__torch__.TestModule.forward(
# CHECK-SAME: %arg0: !torch.nn.Module<"__torch__.TestModule">,
# CHECK-SAME: %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?,1024],si8>},
# CHECK-SAME: %arg2: !torch.tensor {torch.type_bound = !torch.vtensor<[],f32>}

View File

@ -13,18 +13,18 @@ mb = ModuleBuilder()
# Interesting test case, where a function calls a method.
# CHECK-LABEL: func private @__torch__.TestModule.forward
# CHECK-LABEL: func.func private @__torch__.TestModule.forward
# CHECK-SAME: (%[[ARG0:.*]]: !torch.nn.Module<"__torch__.TestModule">, %[[ARG1:.*]]: !torch.tensor) -> !torch.none {
# CHECK: %[[F:.*]] = constant @__torch__.calls_method : (!torch.nn.Module<"__torch__.TestModule">, !torch.tensor) -> !torch.none
# CHECK: %[[RET:.*]] = call_indirect %[[F]](%[[ARG0]], %[[ARG1]]) : (!torch.nn.Module<"__torch__.TestModule">, !torch.tensor) -> !torch.none
# CHECK: return %[[RET]] : !torch.none
# CHECK: }
# CHECK-LABEL: func private @__torch__.TestModule.method
# CHECK-LABEL: func.func private @__torch__.TestModule.method
# CHECK-SAME: (%[[ARG0:.*]]: !torch.nn.Module<"__torch__.TestModule">, %[[ARG1:.*]]: !torch.tensor) -> !torch.none {
# CHECK: %[[RET:.*]] = torch.constant.none
# CHECK: return %[[RET]] : !torch.none
# CHECK: }
# CHECK-LABEL: func private @__torch__.calls_method
# CHECK-LABEL: func.func private @__torch__.calls_method
# CHECK-SAME: (%[[ARG0:.*]]: !torch.nn.Module<"__torch__.TestModule">, %[[ARG1:.*]]: !torch.tensor) -> !torch.none {
# CHECK: %[[RET:.*]] = torch.prim.CallMethod %[[ARG0]]["method"] (%[[ARG1]]) : !torch.nn.Module<"__torch__.TestModule">, (!torch.tensor) -> !torch.none
# CHECK: return %[[RET]] : !torch.none

View File

@ -11,13 +11,13 @@ from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
mb = ModuleBuilder()
# CHECK-LABEL: func private @__torch__.TestModule.forward
# CHECK-LABEL: func.func private @__torch__.TestModule.forward
# CHECK-SAME: (%[[ARG0:.*]]: !torch.nn.Module<"__torch__.TestModule">, %[[ARG1:.*]]: !torch.tensor) -> !torch.tensor {
# CHECK: %[[VAL_2:.*]] = constant @__torch__.identity : (!torch.tensor) -> !torch.tensor
# CHECK: %[[VAL_3:.*]] = call_indirect %[[VAL_2]](%[[ARG1]]) : (!torch.tensor) -> !torch.tensor
# CHECK: return %[[VAL_3]] : !torch.tensor
# CHECK: }
# CHECK-LABEL: func private @__torch__.identity
# CHECK-LABEL: func.func private @__torch__.identity
# CHECK-SAME: (%[[ARG:.*]]: !torch.tensor) -> !torch.tensor {
# CHECK: return %[[ARG]] : !torch.tensor
# CHECK: }

View File

@ -17,7 +17,7 @@ class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
# CHECK-LABEL: func private @__torch__.TestModule.forward(
# CHECK-LABEL: func.func private @__torch__.TestModule.forward(
# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module<"__torch__.TestModule">) -> !torch.optional<int> {
# CHECK: %[[NONE:.*]] = torch.constant.none
# CHECK: %[[DEREFINED:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<int>

View File

@ -21,7 +21,7 @@ mb = ModuleBuilder()
# Given how systematic this is, we don't treat the symbol names as opaque (i.e.
# we don't need to capture their names when FileCheck testing).
# CHECK-LABEL: func private @__torch__.TestModule.forward
# CHECK-LABEL: func.func private @__torch__.TestModule.forward
# CHECK-SAME: (%[[SELF:.*]]: !torch.nn.Module<"__torch__.TestModule">, %[[X:.*]]: !torch.tensor) -> !torch.tensor {
# CHECK: return %[[X]] : !torch.tensor
# CHECK: }

View File

@ -17,7 +17,7 @@ class TestModule(torch.nn.Module):
self.t1 = torch.ones(1)
self.t2 = torch.ones(1)
# CHECK-LABEL: func private @__torch__.TestModule.forward(
# CHECK-LABEL: func.func private @__torch__.TestModule.forward(
# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module<"{{.*}}">) -> !torch.none {
def forward(self):
# CHECK: %[[T2:.*]] = torch.prim.GetAttr %[[SELF]]["t2"]
@ -25,7 +25,7 @@ class TestModule(torch.nn.Module):
self.t1 = self.t2
# CHECK: torch.prim.CallMethod %[[SELF]]["callee"] (%{{.*}}, %{{.*}})
self.callee(self.t1, self.t2)
# CHECK-LABEL: func private @__torch__.TestModule.callee(
# CHECK-LABEL: func.func private @__torch__.TestModule.callee(
# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module<"{{.*}}">,
# CHECK-SAME: %[[X:.*]]: !torch.tensor,
# CHECK-SAME: %[[Y:.*]]: !torch.tensor

View File

@ -24,7 +24,7 @@ class TestModule(torch.nn.Module):
self.s1 = Submodule(1)
self.s2 = Submodule(2)
# CHECK-LABEL: func private @{{.*}}TestModule.forward
# CHECK-LABEL: func.func private @{{.*}}TestModule.forward
def forward(self, b: bool):
# Modules with the same class can be selected between.
# CHECK: %[[MOD:.*]] = torch.prim.If

View File

@ -18,7 +18,7 @@ class BasicClass:
def __init__(self, x: int):
self.x = x
# CHECK-LABEL: func @__torch__.prim_CreateObject(
# CHECK-LABEL: func.func @__torch__.prim_CreateObject(
# CHECK-SAME: %[[ARG0:.*]]: !torch.int) -> !torch.nn.Module<"__torch__.BasicClass"> {
# CHECK: %[[OBJECT:.*]] = torch.prim.CreateObject !torch.nn.Module<"__torch__.BasicClass">
# CHECK: %[[NONE:.*]] = torch.prim.CallMethod %[[OBJECT]]["__init__"] (%[[ARG0]]) : !torch.nn.Module<"__torch__.BasicClass">, (!torch.int) -> !torch.none

View File

@ -9,7 +9,7 @@ from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
mb = ModuleBuilder()
# CHECK-LABEL: func @__torch__.add3
# CHECK-LABEL: func.func @__torch__.add3
# Note that line-level debug information for parts unannotated in the Torch
# graph are ascribed to the first op that carries source information. Presently
# this includes naked constants, return and the function itself. This heuristic

View File

@ -12,7 +12,7 @@ from typing import Tuple, Optional, List, NamedTuple, Dict
mb = ModuleBuilder()
# CHECK-LABEL: func @__torch__.dict_literal_empty() -> !torch.dict<str, tensor> {
# CHECK-LABEL: func.func @__torch__.dict_literal_empty() -> !torch.dict<str, tensor> {
# CHECK: %[[DICT:.*]] = torch.prim.DictConstruct keys() values() -> !torch.dict<str, tensor>
# CHECK: return %[[DICT]] : !torch.dict<str, tensor>
@mb.import_function
@ -21,7 +21,7 @@ def dict_literal_empty() -> Dict[str, torch.Tensor]:
return {}
# CHECK-LABEL: func @__torch__.dict_literal(
# CHECK-LABEL: func.func @__torch__.dict_literal(
# CHECK-SAME: %[[K0:.*]]: !torch.str, %[[V0:.*]]: !torch.tensor,
# CHECK-SAME: %[[K1:.*]]: !torch.str, %[[V1:.*]]: !torch.tensor)
# CHECK-SAME: -> !torch.dict<str, optional<tensor>> {

View File

@ -10,7 +10,7 @@ from utils import create_script_function
mb = ModuleBuilder()
# CHECK-LABEL: func @__torch__.refined_block_arg(
# CHECK-LABEL: func.func @__torch__.refined_block_arg(
# CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.tensor {
# CHECK: %[[REFINED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.tensor to !torch.tensor<[1,384],f32>
# CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[REFINED]] : !torch.tensor<[1,384],f32> to !torch.tensor

View File

@ -11,7 +11,7 @@ import typing
mb = ModuleBuilder()
# CHECK-LABEL: func @__torch__.optional_return(
# CHECK-LABEL: func.func @__torch__.optional_return(
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.optional<int> {
# CHECK: %[[RET:.*]] = torch.derefine %[[ARG]] : !torch.int to !torch.optional<int>
# CHECK: return %[[RET]] : !torch.optional<int>
@ -20,14 +20,14 @@ mb = ModuleBuilder()
def optional_return(i: int) -> typing.Optional[int]:
return i
# CHECK-LABEL: func @__torch__.optional_arg(
# CHECK-LABEL: func.func @__torch__.optional_arg(
# CHECK-SAME: %[[ARG:.*]]: !torch.optional<int>) -> !torch.none {
@mb.import_function
@torch.jit.script
def optional_arg(i: typing.Optional[int]) -> None:
return
# CHECK-LABEL: func @__torch__.calls_optional_arg(
# CHECK-LABEL: func.func @__torch__.calls_optional_arg(
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.none {
# CHECK: %[[CALLEE:.*]] = constant @__torch__.optional_arg : (!torch.optional<int>) -> !torch.none
# CHECK: %[[DEREFINED:.*]] = torch.derefine %[[ARG]] : !torch.int to !torch.optional<int>

View File

@ -32,7 +32,7 @@ def prim_If(b: bool, i: int):
else:
return i * i
# CHECK-LABEL: func @__torch__.prim_If_derefine(
# CHECK-LABEL: func.func @__torch__.prim_If_derefine(
# CHECK-SAME: %[[B:.*]]: !torch.bool,
# CHECK-SAME: %[[I:.*]]: !torch.int) -> !torch.optional<int> {
# CHECK: %[[NONE:.*]] = torch.constant.none

View File

@ -9,7 +9,7 @@ from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
mb = ModuleBuilder()
# CHECK-LABEL: func @__torch__.f(
# CHECK-LABEL: func.func @__torch__.f(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> !torch.list<tensor> {
# CHECK: %[[RET:.*]] = torch.prim.ListConstruct %[[T0]], %[[T1]] : (!torch.tensor, !torch.tensor) -> !torch.list<tensor>

View File

@ -11,7 +11,7 @@ import typing
mb = ModuleBuilder()
# CHECK-LABEL: func @__torch__.prim_Loop_forlike(
# CHECK-LABEL: func.func @__torch__.prim_Loop_forlike(
# CHECK-SAME: %[[MAX_ITERATIONS:.*]]: !torch.int) -> !torch.float {
# CHECK: %[[BOOL_TRUE:.*]] = torch.constant.bool true
# CHECK: %[[F_INIT:.*]] = torch.constant.float 0.000000e+00
@ -29,7 +29,7 @@ def prim_Loop_forlike(n: int):
f += i
return f
# CHECK-LABEL: func @__torch__.prim_Loop_whilelike(
# CHECK-LABEL: func.func @__torch__.prim_Loop_whilelike(
# CHECK-SAME: %[[VAL_0:.*]]: !torch.int) -> !torch.float {
# CHECK: %[[F_INIT:.*]] = torch.constant.float 3.200000e+00
# CHECK: %[[MAX_ITERATIONS:.*]] = torch.constant.int 9223372036854775807
@ -49,7 +49,7 @@ def prim_Loop_whilelike(n: int):
f = f * f
return f
# CHECK-LABEL: func @__torch__.prim_Loop_derefine(
# CHECK-LABEL: func.func @__torch__.prim_Loop_derefine(
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.optional<int> {
# CHECK: %[[TRUE:.*]] = torch.constant.bool true
# CHECK: %[[NONE:.*]] = torch.constant.none

View File

@ -15,7 +15,7 @@ import typing
mb = ModuleBuilder()
# CHECK-LABEL: func @__torch__.prim_NumToTensor(
# CHECK-LABEL: func.func @__torch__.prim_NumToTensor(
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tensor {
# CHECK: %[[RET:.*]] = torch.prim.NumToTensor.Scalar %[[ARG]] : !torch.int -> !torch.tensor
# CHECK: return %[[RET]] : !torch.tensor
@ -25,7 +25,7 @@ mb = ModuleBuilder()
def prim_NumToTensor(i: int):
return _to_tensor(i)
# CHECK-LABEL: func @__torch__.prim_Print(
# CHECK-LABEL: func.func @__torch__.prim_Print(
# CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.none {
# CHECK: %[[STR:.*]] = torch.constant.str "x"
# CHECK: torch.prim.Print(%[[STR]], %[[ARG]]) : !torch.str, !torch.tensor
@ -34,7 +34,7 @@ def prim_NumToTensor(i: int):
def prim_Print(x):
print("x", x)
# CHECK-LABEL: func @__torch__.prim_RaiseException() -> !torch.none {
# CHECK-LABEL: func.func @__torch__.prim_RaiseException() -> !torch.none {
# CHECK: %[[ERRORSTR:.*]] = torch.constant.str "Error"
# CHECK: %[[NONE:.*]] = torch.prim.Uninitialized : !torch.none
# CHECK: torch.prim.RaiseException %[[ERRORSTR]]
@ -44,7 +44,7 @@ def prim_Print(x):
def prim_RaiseException():
raise Exception("Error")
# CHECK-LABEL: func @__torch__.prim_unchecked_cast(
# CHECK-LABEL: func.func @__torch__.prim_unchecked_cast(
# CHECK-SAME: %[[ARG:.*]]: !torch.optional<int>) -> !torch.int {
# CHECK: %[[NONE:.*]] = torch.constant.none
# CHECK: %[[C3:.*]] = torch.constant.int 3
@ -63,7 +63,7 @@ def prim_unchecked_cast(i: typing.Optional[int]):
return 3
return i
# CHECK-LABEL: func @__torch__.prim_TupleUnpack(
# CHECK-LABEL: func.func @__torch__.prim_TupleUnpack(
# CHECK-SAME: %[[ARG:.*]]: !torch.tuple<int, int>) -> !torch.int {
# CHECK: %[[RET:.*]]:2 = torch.prim.TupleUnpack %[[ARG]] : !torch.tuple<int, int> -> !torch.int, !torch.int
# CHECK: return %[[RET]]#0 : !torch.int
@ -73,7 +73,7 @@ def prim_TupleUnpack(tup: typing.Tuple[int, int]):
val, _ = tup
return val
# CHECK-LABEL: func @__torch__.prim_TupleIndex(
# CHECK-LABEL: func.func @__torch__.prim_TupleIndex(
# CHECK-SAME: %[[ARG:.*]]: !torch.tuple<tensor, tensor>) -> !torch.tensor {
# CHECK: %[[RET:.*]] = torch.prim.TupleIndex %[[ARG]], %[[IDX:.*]] : !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
# CHECK: return %[[RET]] : !torch.tensor
@ -82,7 +82,7 @@ def prim_TupleUnpack(tup: typing.Tuple[int, int]):
def prim_TupleIndex(tup: typing.Tuple[torch.Tensor, torch.Tensor]):
return tup[0]
# CHECK-LABEL: func @__torch__.prim_ListUnpack(
# CHECK-LABEL: func.func @__torch__.prim_ListUnpack(
# CHECK-SAME: %[[ARG:.*]]: !torch.list<int>) -> !torch.int {
# CHECK: %[[RET:.*]]:3 = torch.prim.ListUnpack %[[ARG]] : !torch.list<int> -> !torch.int, !torch.int
# CHECK: return %[[RET]]#1 : !torch.int
@ -92,7 +92,7 @@ def prim_ListUnpack(l: typing.List[int]):
_, val, _ = l
return val
# CHECK-LABEL: func @__torch__.prim_dtype(
# CHECK-LABEL: func.func @__torch__.prim_dtype(
# CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.int {
# CHECK: %[[RET:.*]] = torch.prim.dtype %[[ARG]] : !torch.tensor -> !torch.int
# CHECK: return %[[RET]] : !torch.int
@ -101,7 +101,7 @@ def prim_ListUnpack(l: typing.List[int]):
def prim_dtype(x):
return x.dtype
# CHECK-LABEL: func @__torch__.prim_layout(
# CHECK-LABEL: func.func @__torch__.prim_layout(
# CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.int {
# CHECK: %[[RET:.*]] = torch.prim.layout %[[ARG]] : !torch.tensor -> !torch.int
# CHECK: return %[[RET]] : !torch.int
@ -110,7 +110,7 @@ def prim_dtype(x):
def prim_layout(x):
return x.layout
# CHECK-LABEL: func @__torch__.prim_device(
# CHECK-LABEL: func.func @__torch__.prim_device(
# CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.Device {
# CHECK: %[[RET:.*]] = torch.prim.device %[[ARG]] : !torch.tensor -> !torch.Device
# CHECK: return %[[RET]] : !torch.Device
@ -119,7 +119,7 @@ def prim_layout(x):
def prim_device(x):
return x.device
# CHECK-LABEL: func @__torch__.prim_min(
# CHECK-LABEL: func.func @__torch__.prim_min(
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tuple<int, int, int> {
# CHECK: %[[SINGLETON:.*]] = torch.prim.ListConstruct %[[ARG]] : (!torch.int) -> !torch.list<int>
# CHECK: %[[MIN1:.*]] = torch.prim.min.self_int %[[SINGLETON]] : !torch.list<int> -> !torch.int
@ -133,7 +133,7 @@ def prim_device(x):
def prim_min(x: int):
return min(x), min(x,x), min(x, x, x)
# CHECK-LABEL: func @__torch__.prim_max(
# CHECK-LABEL: func.func @__torch__.prim_max(
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tuple<int, int, int> {
# CHECK: %[[SINGLETON:.*]] = torch.prim.ListConstruct %[[ARG]] : (!torch.int) -> !torch.list<int>
# CHECK: %[[MAX1:.*]] = torch.prim.max.self_int %[[SINGLETON]] : !torch.list<int> -> !torch.int
@ -147,7 +147,7 @@ def prim_min(x: int):
def prim_max(x: int):
return max(x), max(x,x), max(x, x, x)
# CHECK-LABEL: func @__torch__.prim_Constant_list() -> !torch.list<int> {
# CHECK-LABEL: func.func @__torch__.prim_Constant_list() -> !torch.list<int> {
# CHECK: %[[A:.*]] = torch.constant.int 1
# CHECK: %[[B:.*]] = torch.constant.int 2
# CHECK: %[[C:.*]] = torch.constant.int 3

View File

@ -14,7 +14,7 @@ mb = ModuleBuilder()
NT = NamedTuple('NT', [('f1', Optional[torch.Tensor]),
('f2', Optional[torch.Tensor])])
# CHECK-LABEL: func @__torch__.tuple(
# CHECK-LABEL: func.func @__torch__.tuple(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) ->
# CHECK-SAME: !torch.tuple<tensor, tensor> {
@ -27,7 +27,7 @@ def tuple(t0, t1):
return t0, t1
# CHECK-LABEL: func @__torch__.tuple_optional(
# CHECK-LABEL: func.func @__torch__.tuple_optional(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) ->
# CHECK-SAME: !torch.tuple<optional<tensor>, optional<tensor>> {
@ -44,7 +44,7 @@ def tuple_optional(
return t0, t1
# CHECK-LABEL: func @__torch__.namedtuple_optional(
# CHECK-LABEL: func.func @__torch__.namedtuple_optional(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) ->
# CHECK-SAME: !torch.tuple<optional<tensor>, optional<tensor>> {
@ -59,7 +59,7 @@ def namedtuple_optional(
return NT(t0, t1)
# CHECK-LABEL: func @__torch__.tuple_construct_arg_needs_refinement(
# CHECK-LABEL: func.func @__torch__.tuple_construct_arg_needs_refinement(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> !torch.tuple<tensor, tensor> {
# CHECK: %[[T0_REFINED:.*]] = torch.tensor_static_info_cast %[[T1]] : !torch.tensor to !torch.tensor<[4],f32>

View File

@ -11,7 +11,7 @@ from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
mb = ModuleBuilder()
# CHECK-LABEL: func @__torch__.f(
# CHECK-LABEL: func.func @__torch__.f(
# CHECK-SAME: %{{.*}}: !torch.union<float, int>) -> !torch.none {
@mb.import_function