mlir: bump llvm tag to 5380e3 (#856)

In addition to updating the llvm-project submodule, this patch also:

1. updates shape functions and tests so that `func` and `call`
   operations refer to the `func` dialect
2. avoid duplicate registration of dialects
pull/861/head snapshot-20220516.455
Ashay Rane 2022-05-16 12:54:35 -07:00 committed by GitHub
parent cfc1a6515c
commit bb52a460cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
64 changed files with 1368 additions and 1369 deletions

View File

@ -1,7 +1,7 @@
// RUN: torch-mlir-dialects-opt -split-input-file -tm-tensor-bufferize %s | FileCheck %s // RUN: torch-mlir-dialects-opt -split-input-file -tm-tensor-bufferize %s | FileCheck %s
// ----- // -----
// CHECK-LABEL: func @scan_1d_inclusive( // CHECK-LABEL: func.func @scan_1d_inclusive(
// CHECK-SAME: %[[IN_TENSOR:.*]]: tensor<128xi32>, %[[OUT_TENSOR:.*]]: tensor<128xi32>, // CHECK-SAME: %[[IN_TENSOR:.*]]: tensor<128xi32>, %[[OUT_TENSOR:.*]]: tensor<128xi32>,
// CHECK-SAME: %[[ACC_TENSOR:.*]]: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) { // CHECK-SAME: %[[ACC_TENSOR:.*]]: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
// CHECK: %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : memref<128xi32> // CHECK: %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : memref<128xi32>
@ -16,7 +16,7 @@
// CHECK: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32> // CHECK: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32>
// CHECK: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref<i32> // CHECK: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref<i32>
// CHECK: return %[[OUT_TENSOR_NEW]], %[[ACC_TENSOR_NEW]] : tensor<128xi32>, tensor<i32> // CHECK: return %[[OUT_TENSOR_NEW]], %[[ACC_TENSOR_NEW]] : tensor<128xi32>, tensor<i32>
func @scan_1d_inclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) { func.func @scan_1d_inclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
%ret_out, %ret_acc = tm_tensor.scan dimension(0) inclusive(true) %ret_out, %ret_acc = tm_tensor.scan dimension(0) inclusive(true)
ins(%in : tensor<128xi32>) outs(%out, %acc: tensor<128xi32>, tensor<i32>) { ins(%in : tensor<128xi32>) outs(%out, %acc: tensor<128xi32>, tensor<i32>) {
^bb0(%arg0 : i32, %arg1 : i32): ^bb0(%arg0 : i32, %arg1 : i32):
@ -27,7 +27,7 @@ func @scan_1d_inclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tenso
} }
// ----- // -----
// CHECK-LABEL: func @scan_1d_exclusive( // CHECK-LABEL: func.func @scan_1d_exclusive(
// CHECK-SAME: %[[IN_TENSOR:.*]]: tensor<128xi32>, %[[OUT_TENSOR:.*]]: tensor<128xi32>, // CHECK-SAME: %[[IN_TENSOR:.*]]: tensor<128xi32>, %[[OUT_TENSOR:.*]]: tensor<128xi32>,
// CHECK-SAME: %[[ACC_TENSOR:.*]]: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) { // CHECK-SAME: %[[ACC_TENSOR:.*]]: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
// CHECK: %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : memref<128xi32> // CHECK: %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : memref<128xi32>
@ -44,7 +44,7 @@ func @scan_1d_inclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tenso
// CHECK: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32> // CHECK: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32>
// CHECK: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref<i32> // CHECK: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref<i32>
// CHECK: return %[[OUT_TENSOR_NEW]], %[[ACC_TENSOR_NEW]] : tensor<128xi32>, tensor<i32> // CHECK: return %[[OUT_TENSOR_NEW]], %[[ACC_TENSOR_NEW]] : tensor<128xi32>, tensor<i32>
func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) { func.func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tensor<i32>) -> (tensor<128xi32>, tensor<i32>) {
%ret_out, %ret_acc = tm_tensor.scan dimension(0) inclusive(false) %ret_out, %ret_acc = tm_tensor.scan dimension(0) inclusive(false)
ins(%in : tensor<128xi32>) outs(%out, %acc: tensor<128xi32>, tensor<i32>) { ins(%in : tensor<128xi32>) outs(%out, %acc: tensor<128xi32>, tensor<i32>) {
^bb0(%arg0 : i32, %arg1 : i32): ^bb0(%arg0 : i32, %arg1 : i32):
@ -55,7 +55,7 @@ func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tenso
} }
// ----- // -----
// CHECK-LABEL: func @scatter_update_scalar_1D( // CHECK-LABEL: func.func @scatter_update_scalar_1D(
// CHECK-SAME: %[[ORIG_TENSOR:.*]]: tensor<8xi32>, // CHECK-SAME: %[[ORIG_TENSOR:.*]]: tensor<8xi32>,
// CHECK-SAME: %[[INDICES_TENSOR:.*]]: tensor<3x1xi32>, // CHECK-SAME: %[[INDICES_TENSOR:.*]]: tensor<3x1xi32>,
// CHECK-SAME: %[[UPDATES_TENSOR:.*]]: tensor<3xi32>) -> tensor<8xi32> { // CHECK-SAME: %[[UPDATES_TENSOR:.*]]: tensor<3xi32>) -> tensor<8xi32> {
@ -71,7 +71,7 @@ func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tenso
// CHECK: } // CHECK: }
// CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32> // CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32>
// CHECK: return %[[OUT_TENSOR]] : tensor<8xi32> // CHECK: return %[[OUT_TENSOR]] : tensor<8xi32>
func @scatter_update_scalar_1D( func.func @scatter_update_scalar_1D(
%original: tensor<8xi32>, %indices: tensor<3x1xi32>, %original: tensor<8xi32>, %indices: tensor<3x1xi32>,
%updates: tensor<3xi32>) -> tensor<8xi32> { %updates: tensor<3xi32>) -> tensor<8xi32> {
%0 = tm_tensor.scatter unique_indices(true) %0 = tm_tensor.scatter unique_indices(true)
@ -83,7 +83,7 @@ func @scatter_update_scalar_1D(
return %0 : tensor<8xi32> return %0 : tensor<8xi32>
} }
// CHECK-LABEL: func @scatter_add_scalar_1D( // CHECK-LABEL: func.func @scatter_add_scalar_1D(
// CHECK-SAME: %[[ORIG_TENSOR:.*]]: tensor<8xi32>, // CHECK-SAME: %[[ORIG_TENSOR:.*]]: tensor<8xi32>,
// CHECK-SAME: %[[INDICES_TENSOR:.*]]: tensor<3x1xi32>, // CHECK-SAME: %[[INDICES_TENSOR:.*]]: tensor<3x1xi32>,
// CHECK-SAME: %[[UPDATES_TENSOR:.*]]: tensor<3xi32>) -> tensor<8xi32> { // CHECK-SAME: %[[UPDATES_TENSOR:.*]]: tensor<3xi32>) -> tensor<8xi32> {
@ -101,7 +101,7 @@ func @scatter_update_scalar_1D(
// CHECK: } // CHECK: }
// CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32> // CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32>
// CHECK: return %[[OUT_TENSOR]] : tensor<8xi32> // CHECK: return %[[OUT_TENSOR]] : tensor<8xi32>
func @scatter_add_scalar_1D( func.func @scatter_add_scalar_1D(
%original: tensor<8xi32>, %indices: tensor<3x1xi32>, %original: tensor<8xi32>, %indices: tensor<3x1xi32>,
%updates: tensor<3xi32>) -> tensor<8xi32> { %updates: tensor<3xi32>) -> tensor<8xi32> {
%0 = tm_tensor.scatter unique_indices(true) %0 = tm_tensor.scatter unique_indices(true)

View File

@ -1,7 +1,7 @@
// RUN: torch-mlir-dialects-opt -canonicalize -split-input-file %s | FileCheck %s // RUN: torch-mlir-dialects-opt -canonicalize -split-input-file %s | FileCheck %s
// CHECK-LABEL: func @tensor.cast( // CHECK-LABEL: func.func @tensor.cast(
func @tensor.cast(%arg0: tensor<128xi32>) -> tensor<128xi32> { func.func @tensor.cast(%arg0: tensor<128xi32>) -> tensor<128xi32> {
%init = linalg.init_tensor [128] : tensor<128xi32> %init = linalg.init_tensor [128] : tensor<128xi32>
%c0 = linalg.init_tensor [] : tensor<i32> %c0 = linalg.init_tensor [] : tensor<i32>

View File

@ -1,6 +1,6 @@
// RUN: torch-mlir-dialects-opt -split-input-file -tm-tensor-to-loops %s | FileCheck %s // RUN: torch-mlir-dialects-opt -split-input-file -tm-tensor-to-loops %s | FileCheck %s
func @scan_1d_inclusive(%0: memref<128xi32>, %1: memref<128xi32>) { func.func @scan_1d_inclusive(%0: memref<128xi32>, %1: memref<128xi32>) {
%c0 = memref.alloc() : memref<i32> %c0 = memref.alloc() : memref<i32>
tm_tensor.scan dimension(0) inclusive(true) tm_tensor.scan dimension(0) inclusive(true)
ins(%0 : memref<128xi32>) outs(%1, %c0 : memref<128xi32>, memref<i32>) { ins(%0 : memref<128xi32>) outs(%1, %c0 : memref<128xi32>, memref<i32>) {
@ -10,7 +10,7 @@ func @scan_1d_inclusive(%0: memref<128xi32>, %1: memref<128xi32>) {
} }
return return
} }
// CHECK-LABEL: func @scan_1d_inclusive // CHECK-LABEL: func.func @scan_1d_inclusive
// CHECK-SAME: %[[BUFI:[a-zA-Z0-9]+]] // CHECK-SAME: %[[BUFI:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[BUFO:[a-zA-Z0-9]+]] // CHECK-SAME: %[[BUFO:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index // CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
@ -33,7 +33,7 @@ func @scan_1d_inclusive(%0: memref<128xi32>, %1: memref<128xi32>) {
// ----- // -----
func @scan_1d_exclusive(%0: memref<128xi32>, %1: memref<128xi32>) { func.func @scan_1d_exclusive(%0: memref<128xi32>, %1: memref<128xi32>) {
%c0 = memref.alloc() : memref<i32> %c0 = memref.alloc() : memref<i32>
tm_tensor.scan dimension(0) inclusive(false) tm_tensor.scan dimension(0) inclusive(false)
ins(%0 : memref<128xi32>) outs(%1, %c0 : memref<128xi32>, memref<i32>) { ins(%0 : memref<128xi32>) outs(%1, %c0 : memref<128xi32>, memref<i32>) {
@ -43,7 +43,7 @@ func @scan_1d_exclusive(%0: memref<128xi32>, %1: memref<128xi32>) {
} }
return return
} }
// CHECK-LABEL: func @scan_1d_exclusive // CHECK-LABEL: func.func @scan_1d_exclusive
// CHECK-SAME: %[[BUFI:[a-zA-Z0-9]+]] // CHECK-SAME: %[[BUFI:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[BUFO:[a-zA-Z0-9]+]] // CHECK-SAME: %[[BUFO:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index // CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
@ -66,7 +66,7 @@ func @scan_1d_exclusive(%0: memref<128xi32>, %1: memref<128xi32>) {
// ----- // -----
func @scan_2d(%0: memref<16x32xi32>, %1: memref<16x32xi32>) { func.func @scan_2d(%0: memref<16x32xi32>, %1: memref<16x32xi32>) {
%t0 = memref.alloc() : memref<32xi32> %t0 = memref.alloc() : memref<32xi32>
tm_tensor.scan dimension(0) inclusive(true) tm_tensor.scan dimension(0) inclusive(true)
ins(%0 : memref<16x32xi32>) outs(%1, %t0 : memref<16x32xi32>, memref<32xi32>) { ins(%0 : memref<16x32xi32>) outs(%1, %t0 : memref<16x32xi32>, memref<32xi32>) {
@ -76,7 +76,7 @@ func @scan_2d(%0: memref<16x32xi32>, %1: memref<16x32xi32>) {
} }
return return
} }
// CHECK-LABEL: func @scan_2d // CHECK-LABEL: func.func @scan_2d
// CHECK-SAME: %[[BUFI:[a-zA-Z0-9]+]] // CHECK-SAME: %[[BUFI:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[BUFO:[a-zA-Z0-9]+]] // CHECK-SAME: %[[BUFO:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index // CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
@ -102,7 +102,7 @@ func @scan_2d(%0: memref<16x32xi32>, %1: memref<16x32xi32>) {
// ----- // -----
func @scatter_update_scalar_1D( func.func @scatter_update_scalar_1D(
%original: memref<8xi32>, %indices: memref<3x1xi32>, %original: memref<8xi32>, %indices: memref<3x1xi32>,
%updates: memref<3xi32>) { %updates: memref<3xi32>) {
tm_tensor.scatter unique_indices(true) tm_tensor.scatter unique_indices(true)
@ -113,7 +113,7 @@ func @scatter_update_scalar_1D(
} }
return return
} }
// CHECK-LABEL: func @scatter_update_scalar_1D // CHECK-LABEL: func.func @scatter_update_scalar_1D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]] // CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]] // CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
@ -128,7 +128,7 @@ func @scatter_update_scalar_1D(
// ----- // -----
func @scatter_add_scalar_2D( func.func @scatter_add_scalar_2D(
%original: memref<4x3xi32>, %indices: memref<3x2xi32>, %original: memref<4x3xi32>, %indices: memref<3x2xi32>,
%updates: memref<3xi32>) { %updates: memref<3xi32>) {
tm_tensor.scatter unique_indices(true) tm_tensor.scatter unique_indices(true)
@ -140,7 +140,7 @@ func @scatter_add_scalar_2D(
} }
return return
} }
// CHECK-LABEL: func @scatter_add_scalar_2D // CHECK-LABEL: func.func @scatter_add_scalar_2D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]] // CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]] // CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
@ -159,7 +159,7 @@ func @scatter_add_scalar_2D(
// ----- // -----
func @scatter_update_slice_2D( func.func @scatter_update_slice_2D(
%original: memref<4x3xi32>, %indices: memref<2x1xi32>, %original: memref<4x3xi32>, %indices: memref<2x1xi32>,
%updates: memref<2x3xi32>) { %updates: memref<2x3xi32>) {
tm_tensor.scatter unique_indices(true) tm_tensor.scatter unique_indices(true)
@ -170,7 +170,7 @@ func @scatter_update_slice_2D(
} }
return return
} }
// CHECK: func @scatter_update_slice_2D // CHECK: func.func @scatter_update_slice_2D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]] // CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]] // CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
@ -189,7 +189,7 @@ func @scatter_update_slice_2D(
// ----- // -----
func @scatter_add_scalar_1D( func.func @scatter_add_scalar_1D(
%original: memref<8xi32>, %indices: memref<3x1xi32>, %original: memref<8xi32>, %indices: memref<3x1xi32>,
%updates: memref<3xi32>) { %updates: memref<3xi32>) {
tm_tensor.scatter unique_indices(true) tm_tensor.scatter unique_indices(true)
@ -201,7 +201,7 @@ func @scatter_add_scalar_1D(
} }
return return
} }
// CHECK-LABEL: func @scatter_add_scalar_1D // CHECK-LABEL: func.func @scatter_add_scalar_1D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]] // CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]] // CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
@ -218,7 +218,7 @@ func @scatter_add_scalar_1D(
// ----- // -----
func @scatter_add_slice_2D( func.func @scatter_add_slice_2D(
%original: memref<4x3xi32>, %indices: memref<2x1xi32>, %original: memref<4x3xi32>, %indices: memref<2x1xi32>,
%updates: memref<2x3xi32>) { %updates: memref<2x3xi32>) {
tm_tensor.scatter unique_indices(true) tm_tensor.scatter unique_indices(true)
@ -230,7 +230,7 @@ func @scatter_add_slice_2D(
} }
return return
} }
// CHECK: func @scatter_add_slice_2D // CHECK: func.func @scatter_add_slice_2D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]] // CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]] // CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
@ -248,7 +248,7 @@ func @scatter_add_slice_2D(
// ----- // -----
func @scatter_update_scalar_dynamic_1D( func.func @scatter_update_scalar_dynamic_1D(
%original: memref<?xi32>, %indices: memref<?x1xi32>, %original: memref<?xi32>, %indices: memref<?x1xi32>,
%updates: memref<?xi32>) { %updates: memref<?xi32>) {
tm_tensor.scatter unique_indices(true) tm_tensor.scatter unique_indices(true)
@ -259,7 +259,7 @@ func @scatter_update_scalar_dynamic_1D(
} }
return return
} }
// CHECK-LABEL: func @scatter_update_scalar_dynamic_1D // CHECK-LABEL: func.func @scatter_update_scalar_dynamic_1D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]] // CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]] // CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
@ -274,7 +274,7 @@ func @scatter_update_scalar_dynamic_1D(
// ----- // -----
func @scatter_add_scalar_dynamic_2D( func.func @scatter_add_scalar_dynamic_2D(
%original: memref<?x?xi32>, %indices: memref<?x2xi32>, %original: memref<?x?xi32>, %indices: memref<?x2xi32>,
%updates: memref<?xi32>) { %updates: memref<?xi32>) {
tm_tensor.scatter unique_indices(true) tm_tensor.scatter unique_indices(true)
@ -286,7 +286,7 @@ func @scatter_add_scalar_dynamic_2D(
} }
return return
} }
// CHECK-LABEL: func @scatter_add_scalar_dynamic_2D // CHECK-LABEL: func.func @scatter_add_scalar_dynamic_2D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]] // CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]] // CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
@ -305,7 +305,7 @@ func @scatter_add_scalar_dynamic_2D(
// ----- // -----
func @scatter_update_slice_dynamic_2D( func.func @scatter_update_slice_dynamic_2D(
%original: memref<?x?xi32>, %indices: memref<?x1xi32>, %original: memref<?x?xi32>, %indices: memref<?x1xi32>,
%updates: memref<?x?xi32>) { %updates: memref<?x?xi32>) {
tm_tensor.scatter unique_indices(true) tm_tensor.scatter unique_indices(true)
@ -316,7 +316,7 @@ func @scatter_update_slice_dynamic_2D(
} }
return return
} }
// CHECK: func @scatter_update_slice_dynamic_2D // CHECK: func.func @scatter_update_slice_dynamic_2D
// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]] // CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]] // CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
@ -333,7 +333,7 @@ func @scatter_update_slice_dynamic_2D(
// ----- // -----
func @scatter_partial_slices(%arg0: memref<2x64x12xf32>, %arg1: memref<2x3xi32>, %arg2: memref<2x1x12xf32>) { func.func @scatter_partial_slices(%arg0: memref<2x64x12xf32>, %arg1: memref<2x3xi32>, %arg2: memref<2x1x12xf32>) {
tm_tensor.scatter tm_tensor.scatter
unique_indices(true) unique_indices(true)
ins(%arg2, %arg1 : memref<2x1x12xf32>, memref<2x3xi32>) ins(%arg2, %arg1 : memref<2x1x12xf32>, memref<2x3xi32>)
@ -344,7 +344,7 @@ func @scatter_partial_slices(%arg0: memref<2x64x12xf32>, %arg1: memref<2x3xi32>,
return return
} }
// CHECK-LABEL: func @scatter_partial_slices // CHECK-LABEL: func.func @scatter_partial_slices
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]

View File

@ -1,6 +1,6 @@
// RUN: torch-mlir-dialects-opt -split-input-file -verify-diagnostics %s // RUN: torch-mlir-dialects-opt -split-input-file -verify-diagnostics %s
func @scatter_mixed_tensor_memref( func.func @scatter_mixed_tensor_memref(
%update : memref<?x?xf32>, %indices : tensor<?x1xi32>, %update : memref<?x?xf32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xf32>) -> tensor<?x?xf32> { %original : tensor<?x?xf32>) -> tensor<?x?xf32> {
// expected-error @+1 {{expected inputs and outputs to be RankedTensorType or scalar}} // expected-error @+1 {{expected inputs and outputs to be RankedTensorType or scalar}}
@ -16,7 +16,7 @@ func @scatter_mixed_tensor_memref(
// ----- // -----
func @scatter_mixed_tensor_memref( func.func @scatter_mixed_tensor_memref(
%update : tensor<?x?xf32>, %indices : memref<?x1xi32>, %update : tensor<?x?xf32>, %indices : memref<?x1xi32>,
%original : tensor<?x?xf32>) -> tensor<?x?xf32> { %original : tensor<?x?xf32>) -> tensor<?x?xf32> {
// expected-error @+1 {{expected inputs and outputs to be RankedTensorType or scalar}} // expected-error @+1 {{expected inputs and outputs to be RankedTensorType or scalar}}
@ -32,7 +32,7 @@ func @scatter_mixed_tensor_memref(
// ----- // -----
func @scatter_extra_outputs( func.func @scatter_extra_outputs(
%update : tensor<?x?xf32>, %indices : tensor<?x1xi32>, %update : tensor<?x?xf32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) { %original : tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
// expected-error @+1 {{expected number of outputs to be same as the number of results}} // expected-error @+1 {{expected number of outputs to be same as the number of results}}
@ -48,7 +48,7 @@ func @scatter_extra_outputs(
// ----- // -----
func @scatter_mixed_tensor_memref( func.func @scatter_mixed_tensor_memref(
%update : tensor<?x?xf32>, %indices : tensor<?x1xi32>, %update : tensor<?x?xf32>, %indices : tensor<?x1xi32>,
%original : memref<?x?xf32>) -> tensor<?x?xf32> { %original : memref<?x?xf32>) -> tensor<?x?xf32> {
// expected-error @+1 {{expected inputs and outputs to be RankedTensorType or scalar}} // expected-error @+1 {{expected inputs and outputs to be RankedTensorType or scalar}}
@ -64,7 +64,7 @@ func @scatter_mixed_tensor_memref(
// ----- // -----
func @scatter_output_type_mismatch( func.func @scatter_output_type_mismatch(
%update : tensor<?x?xf32>, %indices : tensor<?x1xi32>, %update : tensor<?x?xf32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xf32>) -> tensor<4x?xf32> { %original : tensor<?x?xf32>) -> tensor<4x?xf32> {
// expected-error @+1 {{expected type of `outs` operand #0 'tensor<?x?xf32>' to be same as result type 'tensor<4x?xf32>'}} // expected-error @+1 {{expected type of `outs` operand #0 'tensor<?x?xf32>' to be same as result type 'tensor<4x?xf32>'}}
@ -80,7 +80,7 @@ func @scatter_output_type_mismatch(
// ----- // -----
func @scatter_mixed_tensor_memref( func.func @scatter_mixed_tensor_memref(
%update : memref<?x?xf32>, %indices : tensor<?x1xi32>, %update : memref<?x?xf32>, %indices : tensor<?x1xi32>,
%original : memref<?x?xf32>) { %original : memref<?x?xf32>) {
// expected-error @+1 {{expected inputs and outputs to be MemRefType or scalar}} // expected-error @+1 {{expected inputs and outputs to be MemRefType or scalar}}
@ -96,7 +96,7 @@ func @scatter_mixed_tensor_memref(
// ----- // -----
func @scatter_mixed_tensor_memref( func.func @scatter_mixed_tensor_memref(
%update : memref<?x?xf32>, %indices : memref<?x1xi32>, %update : memref<?x?xf32>, %indices : memref<?x1xi32>,
%original : tensor<?x?xf32>) { %original : tensor<?x?xf32>) {
// expected-error @+1 {{expected inputs and outputs to be MemRefType or scalar}} // expected-error @+1 {{expected inputs and outputs to be MemRefType or scalar}}
@ -112,7 +112,7 @@ func @scatter_mixed_tensor_memref(
// ----- // -----
func @scatter_dim_mismatch( func.func @scatter_dim_mismatch(
%update : tensor<?x?xf32>, %indices : tensor<48x1xi32>, %update : tensor<?x?xf32>, %indices : tensor<48x1xi32>,
%original : tensor<?x?xf32>) -> tensor<?x?xf32> { %original : tensor<?x?xf32>) -> tensor<?x?xf32> {
// expected-error @+1 {{mismatch in shape of indices and update value at dim#0}} // expected-error @+1 {{mismatch in shape of indices and update value at dim#0}}
@ -128,7 +128,7 @@ func @scatter_dim_mismatch(
// ----- // -----
func @scatter_dim_mismatch( func.func @scatter_dim_mismatch(
%update : tensor<64x?xf32>, %indices : tensor<48x1xi32>, %update : tensor<64x?xf32>, %indices : tensor<48x1xi32>,
%original : tensor<?x?xf32>) -> tensor<?x?xf32> { %original : tensor<?x?xf32>) -> tensor<?x?xf32> {
// expected-error @+1 {{mismatch in shape of indices and update value at dim#0}} // expected-error @+1 {{mismatch in shape of indices and update value at dim#0}}
@ -144,7 +144,7 @@ func @scatter_dim_mismatch(
// ----- // -----
func @scatter_dim_mismatch( func.func @scatter_dim_mismatch(
%update : tensor<?x?x?x?xf32>, %indices : tensor<?x1xi32>, %update : tensor<?x?x?x?xf32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xf32>) -> tensor<?x?xf32> { %original : tensor<?x?xf32>) -> tensor<?x?xf32> {
// expected-error @+1 {{op update value rank exceeds the rank of the original value}} // expected-error @+1 {{op update value rank exceeds the rank of the original value}}
@ -160,7 +160,7 @@ func @scatter_dim_mismatch(
// ----- // -----
func @scatter_dim_mismatch( func.func @scatter_dim_mismatch(
%update : tensor<?x4xf32>, %indices : tensor<?x1xi32>, %update : tensor<?x4xf32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xf32>) -> tensor<?x?xf32> { %original : tensor<?x?xf32>) -> tensor<?x?xf32> {
// expected-error @+1 {{mismatch in shape of update value dim#1 and original value at dim#1}} // expected-error @+1 {{mismatch in shape of update value dim#1 and original value at dim#1}}
@ -176,7 +176,7 @@ func @scatter_dim_mismatch(
// ----- // -----
func @scatter_region_type_mismatch( func.func @scatter_region_type_mismatch(
%update : tensor<?x?xi32>, %indices : tensor<?x1xi32>, %update : tensor<?x?xi32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi32>) -> tensor<?x?xi32> { %original : tensor<?x?xi32>) -> tensor<?x?xi32> {
// expected-error @+1 {{expected region to have scalar argument of integer or float types}} // expected-error @+1 {{expected region to have scalar argument of integer or float types}}
@ -193,7 +193,7 @@ func @scatter_region_type_mismatch(
// ----- // -----
func @scatter_region_type_mismatch( func.func @scatter_region_type_mismatch(
%update : tensor<?x?xi32>, %indices : tensor<?x1xi32>, %update : tensor<?x?xi32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi32>) -> tensor<?x?xi32> { %original : tensor<?x?xi32>) -> tensor<?x?xi32> {
// expected-error @+1 {{mismatch in argument 0 of region 'i64' and element type of update value 'i32'}} // expected-error @+1 {{mismatch in argument 0 of region 'i64' and element type of update value 'i32'}}
@ -210,7 +210,7 @@ func @scatter_region_type_mismatch(
// ----- // -----
func @scatter_region_type_mismatch( func.func @scatter_region_type_mismatch(
%update : tensor<?x?xi32>, %indices : tensor<?x1xi32>, %update : tensor<?x?xi32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi32>) -> tensor<?x?xi32> { %original : tensor<?x?xi32>) -> tensor<?x?xi32> {
// expected-error @+1 {{mismatch in argument 1 of region 'i64' and element type of original value 'i32'}} // expected-error @+1 {{mismatch in argument 1 of region 'i64' and element type of original value 'i32'}}
@ -227,7 +227,7 @@ func @scatter_region_type_mismatch(
// ----- // -----
func @scatter_region_type_mismatch( func.func @scatter_region_type_mismatch(
%update : tensor<?x?xi32>, %indices : tensor<?x1xi32>, %update : tensor<?x?xi32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi64>) -> tensor<?x?xi64> { %original : tensor<?x?xi64>) -> tensor<?x?xi64> {
// expected-error @+1 {{mismatch in region argument types 'i32' and 'i64'}} // expected-error @+1 {{mismatch in region argument types 'i32' and 'i64'}}
@ -244,7 +244,7 @@ func @scatter_region_type_mismatch(
// ----- // -----
func @scatter_region_type_mismatch( func.func @scatter_region_type_mismatch(
%update : tensor<?x?xi64>, %indices : tensor<?x1xi32>, %update : tensor<?x?xi64>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi64>) -> tensor<?x?xi64> { %original : tensor<?x?xi64>) -> tensor<?x?xi64> {
// expected-error @+1 {{expected region to have two arguments}} // expected-error @+1 {{expected region to have two arguments}}
@ -261,7 +261,7 @@ func @scatter_region_type_mismatch(
// ----- // -----
func @scatter_yield_mismatch( func.func @scatter_yield_mismatch(
%update : tensor<?x?xi64>, %indices : tensor<?x1xi32>, %update : tensor<?x?xi64>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi64>) -> tensor<?x?xi64> { %original : tensor<?x?xi64>) -> tensor<?x?xi64> {
%0 = tm_tensor.scatter unique_indices(true) %0 = tm_tensor.scatter unique_indices(true)
@ -278,7 +278,7 @@ func @scatter_yield_mismatch(
// ----- // -----
func @scatter_yield_mismatch( func.func @scatter_yield_mismatch(
%update : tensor<?x?xi64>, %indices : tensor<?x1xi32>, %update : tensor<?x?xi64>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi64>) -> tensor<?x?xi64> { %original : tensor<?x?xi64>) -> tensor<?x?xi64> {
%0 = tm_tensor.scatter unique_indices(true) %0 = tm_tensor.scatter unique_indices(true)
@ -295,7 +295,7 @@ func @scatter_yield_mismatch(
// ----- // -----
func @scatter_index_depth_dynamic( func.func @scatter_index_depth_dynamic(
%update : tensor<?x?xi64>, %indices : tensor<?x?xi32>, %update : tensor<?x?xi64>, %indices : tensor<?x?xi32>,
%original : tensor<?x?xi64>) -> tensor<?x?xi64> { %original : tensor<?x?xi64>) -> tensor<?x?xi64> {
// expected-error @+1 {{expected index depth is static}} // expected-error @+1 {{expected index depth is static}}
@ -312,7 +312,7 @@ func @scatter_index_depth_dynamic(
// ----- // -----
func @scatter_original_rank_mismatch( func.func @scatter_original_rank_mismatch(
%update : tensor<?xi64>, %indices : tensor<?x1xi32>, %update : tensor<?xi64>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xi64>) -> tensor<?x?xi64> { %original : tensor<?x?xi64>) -> tensor<?x?xi64> {
// expected-error @+1 {{op index depth and update value does not cover rank of original value}} // expected-error @+1 {{op index depth and update value does not cover rank of original value}}

@ -1 +1 @@
Subproject commit e1318078a4e160eb723bcbcfcdcc9a1b618f7067 Subproject commit 5380e30e047bbac9b2cceb69162eb8db1e1a7abf

File diff suppressed because it is too large Load Diff

View File

@ -116,7 +116,6 @@ ModuleBuilder::ModuleBuilder(pybind11::object contextObj)
unknownLoc(mlirLocationUnknownGet(context)) { unknownLoc(mlirLocationUnknownGet(context)) {
// TODO: Rework this once dialect registration C-APIs are in place. // TODO: Rework this once dialect registration C-APIs are in place.
// https://reviews.llvm.org/D88162 // https://reviews.llvm.org/D88162
mlirRegisterAllDialects(context);
torchMlirRegisterAllDialects(context); torchMlirRegisterAllDialects(context);
registerPythonSysStderrDiagnosticHandler(context); registerPythonSysStderrDiagnosticHandler(context);

View File

@ -1,6 +1,6 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s // RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
// CHECK-LABEL: func @torch.aten.mm$basic( // CHECK-LABEL: func.func @torch.aten.mm$basic(
// CHECK-SAME: %[[LHS_VTENSOR:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[LHS_VTENSOR:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[RHS_VTENSOR:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32> { // CHECK-SAME: %[[RHS_VTENSOR:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32> {
// CHECK: %[[LHS:.*]] = torch_c.to_builtin_tensor %[[LHS_VTENSOR]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[LHS:.*]] = torch_c.to_builtin_tensor %[[LHS_VTENSOR]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
@ -22,7 +22,7 @@
// CHECK: %[[CASTED:.*]] = tensor.cast %[[MATMUL]] : tensor<?x?xf32> to tensor<?x2xf32> // CHECK: %[[CASTED:.*]] = tensor.cast %[[MATMUL]] : tensor<?x?xf32> to tensor<?x2xf32>
// CHECK: %[[RESULT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor<?x2xf32> -> !torch.vtensor<[?,2],f32> // CHECK: %[[RESULT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor<?x2xf32> -> !torch.vtensor<[?,2],f32>
// CHECK: return %[[RESULT_VTENSOR]] : !torch.vtensor<[?,2],f32> // CHECK: return %[[RESULT_VTENSOR]] : !torch.vtensor<[?,2],f32>
func @torch.aten.mm$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32> { func.func @torch.aten.mm$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32> {
%0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,2],f32> %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,2],f32>
return %0 : !torch.vtensor<[?,2],f32> return %0 : !torch.vtensor<[?,2],f32>
} }
@ -30,7 +30,7 @@ func @torch.aten.mm$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtenso
// ----- // -----
// If the operands are missing dtype, we cannot lower it. // If the operands are missing dtype, we cannot lower it.
func @torch.aten.mm$no_convert$missing_dtype(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor { func.func @torch.aten.mm$no_convert$missing_dtype(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
// expected-error@+1 {{failed to legalize}} // expected-error@+1 {{failed to legalize}}
%0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor, !torch.vtensor -> !torch.vtensor %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor, !torch.vtensor -> !torch.vtensor
return %0 : !torch.vtensor return %0 : !torch.vtensor
@ -40,7 +40,7 @@ func @torch.aten.mm$no_convert$missing_dtype(%arg0: !torch.vtensor, %arg1: !torc
// Correctly handle the case that operands are statically the wrong rank // Correctly handle the case that operands are statically the wrong rank
// (rank 1 vs rank 2 expected for matmul.) // (rank 1 vs rank 2 expected for matmul.)
func @torch.aten.mm$no_convert$wrong_rank(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.mm$no_convert$wrong_rank(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
// expected-error@+1 {{failed to legalize}} // expected-error@+1 {{failed to legalize}}
%0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?,?],f32> %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32> return %0 : !torch.vtensor<[?,?],f32>
@ -49,7 +49,7 @@ func @torch.aten.mm$no_convert$wrong_rank(%arg0: !torch.vtensor<[?],f32>, %arg1:
// ----- // -----
// If the result is missing dtype, we cannot lower it. // If the result is missing dtype, we cannot lower it.
func @torch.aten.mm$no_convert$result_missing_dtype(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor { func.func @torch.aten.mm$no_convert$result_missing_dtype(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor {
// expected-error@+1 {{failed to legalize}} // expected-error@+1 {{failed to legalize}}
%0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor
return %0 : !torch.vtensor return %0 : !torch.vtensor
@ -57,20 +57,20 @@ func @torch.aten.mm$no_convert$result_missing_dtype(%arg0: !torch.vtensor<[?,?],
// ----- // -----
// CHECK-LABEL: func @torch.aten.Int.Tensor$zero_rank // CHECK-LABEL: func.func @torch.aten.Int.Tensor$zero_rank
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[],si64>) -> !torch.int { // CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[],si64>) -> !torch.int {
// CHECK: %[[I:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],si64> -> tensor<i64> // CHECK: %[[I:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],si64> -> tensor<i64>
// CHECK: %[[EXT:.*]] = tensor.extract %[[I]][] : tensor<i64> // CHECK: %[[EXT:.*]] = tensor.extract %[[I]][] : tensor<i64>
// CHECK: %[[RET:.*]] = torch_c.from_i64 %[[EXT]] // CHECK: %[[RET:.*]] = torch_c.from_i64 %[[EXT]]
// CHECK: return %[[RET]] : !torch.int // CHECK: return %[[RET]] : !torch.int
func @torch.aten.Int.Tensor$zero_rank(%arg0: !torch.vtensor<[],si64>) -> !torch.int { func.func @torch.aten.Int.Tensor$zero_rank(%arg0: !torch.vtensor<[],si64>) -> !torch.int {
%0 = torch.aten.Int.Tensor %arg0 : !torch.vtensor<[],si64> -> !torch.int %0 = torch.aten.Int.Tensor %arg0 : !torch.vtensor<[],si64> -> !torch.int
return %0 : !torch.int return %0 : !torch.int
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.Int.Tensor$non_zero_rank // CHECK-LABEL: func.func @torch.aten.Int.Tensor$non_zero_rank
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.int { // CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.int {
// CHECK: %[[I:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64> // CHECK: %[[I:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
// CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[C0:.*]] = arith.constant 0 : index
@ -88,27 +88,27 @@ func @torch.aten.Int.Tensor$zero_rank(%arg0: !torch.vtensor<[],si64>) -> !torch.
// CHECK: %[[EXT:.*]] = tensor.extract %[[I]][%[[ZERO]], %[[ZERO]]] : tensor<?x?xi64> // CHECK: %[[EXT:.*]] = tensor.extract %[[I]][%[[ZERO]], %[[ZERO]]] : tensor<?x?xi64>
// CHECK: %[[RET:.*]] = torch_c.from_i64 %[[EXT]] // CHECK: %[[RET:.*]] = torch_c.from_i64 %[[EXT]]
// CHECK: return %[[RET]] : !torch.int // CHECK: return %[[RET]] : !torch.int
func @torch.aten.Int.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.int { func.func @torch.aten.Int.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.int {
%0 = torch.aten.Int.Tensor %arg0 : !torch.vtensor<[?,?],si64> -> !torch.int %0 = torch.aten.Int.Tensor %arg0 : !torch.vtensor<[?,?],si64> -> !torch.int
return %0 : !torch.int return %0 : !torch.int
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.Float.Tensor$zero_rank // CHECK-LABEL: func.func @torch.aten.Float.Tensor$zero_rank
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[],f64>) -> !torch.float { // CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[],f64>) -> !torch.float {
// CHECK: %[[F:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f64> -> tensor<f64> // CHECK: %[[F:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f64> -> tensor<f64>
// CHECK: %[[EXT:.*]] = tensor.extract %[[F]][] : tensor<f64> // CHECK: %[[EXT:.*]] = tensor.extract %[[F]][] : tensor<f64>
// CHECK: %[[RET:.*]] = torch_c.from_f64 %[[EXT]] // CHECK: %[[RET:.*]] = torch_c.from_f64 %[[EXT]]
// CHECK: return %[[RET]] : !torch.float // CHECK: return %[[RET]] : !torch.float
func @torch.aten.Float.Tensor$zero_rank(%arg0: !torch.vtensor<[],f64>) -> !torch.float { func.func @torch.aten.Float.Tensor$zero_rank(%arg0: !torch.vtensor<[],f64>) -> !torch.float {
%0 = torch.aten.Float.Tensor %arg0 : !torch.vtensor<[],f64> -> !torch.float %0 = torch.aten.Float.Tensor %arg0 : !torch.vtensor<[],f64> -> !torch.float
return %0 : !torch.float return %0 : !torch.float
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.Float.Tensor$non_zero_rank // CHECK-LABEL: func.func @torch.aten.Float.Tensor$non_zero_rank
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[?,?],f64>) -> !torch.float { // CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[?,?],f64>) -> !torch.float {
// CHECK: %[[F:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f64> -> tensor<?x?xf64> // CHECK: %[[F:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f64> -> tensor<?x?xf64>
// CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[C0:.*]] = arith.constant 0 : index
@ -126,27 +126,27 @@ func @torch.aten.Float.Tensor$zero_rank(%arg0: !torch.vtensor<[],f64>) -> !torch
// CHECK: %[[EXT:.*]] = tensor.extract %[[F]][%[[ZERO]], %[[ZERO]]] : tensor<?x?xf64> // CHECK: %[[EXT:.*]] = tensor.extract %[[F]][%[[ZERO]], %[[ZERO]]] : tensor<?x?xf64>
// CHECK: %[[RET:.*]] = torch_c.from_f64 %[[EXT]] // CHECK: %[[RET:.*]] = torch_c.from_f64 %[[EXT]]
// CHECK: return %[[RET]] : !torch.float // CHECK: return %[[RET]] : !torch.float
func @torch.aten.Float.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],f64>) -> !torch.float { func.func @torch.aten.Float.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],f64>) -> !torch.float {
%0 = torch.aten.Float.Tensor %arg0 : !torch.vtensor<[?,?],f64> -> !torch.float %0 = torch.aten.Float.Tensor %arg0 : !torch.vtensor<[?,?],f64> -> !torch.float
return %0 : !torch.float return %0 : !torch.float
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.Bool.Tensor$zero_rank // CHECK-LABEL: func.func @torch.aten.Bool.Tensor$zero_rank
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[],i1>) -> !torch.bool { // CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[],i1>) -> !torch.bool {
// CHECK: %[[B:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],i1> -> tensor<i1> // CHECK: %[[B:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],i1> -> tensor<i1>
// CHECK: %[[EXT:.*]] = tensor.extract %[[B]][] : tensor<i1> // CHECK: %[[EXT:.*]] = tensor.extract %[[B]][] : tensor<i1>
// CHECK: %[[RES:.*]] = torch_c.from_i1 %[[EXT]] // CHECK: %[[RES:.*]] = torch_c.from_i1 %[[EXT]]
// CHECK: return %[[RES]] : !torch.bool // CHECK: return %[[RES]] : !torch.bool
func @torch.aten.Bool.Tensor$zero_rank(%arg0: !torch.vtensor<[],i1>) -> !torch.bool { func.func @torch.aten.Bool.Tensor$zero_rank(%arg0: !torch.vtensor<[],i1>) -> !torch.bool {
%0 = torch.aten.Bool.Tensor %arg0 : !torch.vtensor<[],i1> -> !torch.bool %0 = torch.aten.Bool.Tensor %arg0 : !torch.vtensor<[],i1> -> !torch.bool
return %0 : !torch.bool return %0 : !torch.bool
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.Bool.Tensor$non_zero_rank // CHECK-LABEL: func.func @torch.aten.Bool.Tensor$non_zero_rank
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.bool { // CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.bool {
// CHECK: %[[B:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1> // CHECK: %[[B:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
// CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[C0:.*]] = arith.constant 0 : index
@ -164,59 +164,59 @@ func @torch.aten.Bool.Tensor$zero_rank(%arg0: !torch.vtensor<[],i1>) -> !torch.b
// CHECK: %[[EXT:.*]] = tensor.extract %[[I]][%[[ZERO]], %[[ZERO]]] : tensor<?x?xi1> // CHECK: %[[EXT:.*]] = tensor.extract %[[I]][%[[ZERO]], %[[ZERO]]] : tensor<?x?xi1>
// CHECK: %[[RET:.*]] = torch_c.from_i1 %[[EXT]] // CHECK: %[[RET:.*]] = torch_c.from_i1 %[[EXT]]
// CHECK: return %[[RET]] : !torch.bool // CHECK: return %[[RET]] : !torch.bool
func @torch.aten.Bool.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],i1>) -> !torch.bool { func.func @torch.aten.Bool.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],i1>) -> !torch.bool {
%0 = torch.aten.Bool.Tensor %arg0 : !torch.vtensor<[?,?],i1> -> !torch.bool %0 = torch.aten.Bool.Tensor %arg0 : !torch.vtensor<[?,?],i1> -> !torch.bool
return %0 : !torch.bool return %0 : !torch.bool
} }
// ----- // -----
// CHECK: func @torch.prim.NumToTensor.Scalar$basic(%[[IN:.*]]: !torch.int) -> !torch.vtensor<[],si64> { // CHECK: func.func @torch.prim.NumToTensor.Scalar$basic(%[[IN:.*]]: !torch.int) -> !torch.vtensor<[],si64> {
// CHECK: %[[INI64:.*]] = torch_c.to_i64 %[[IN]] // CHECK: %[[INI64:.*]] = torch_c.to_i64 %[[IN]]
// CHECK: %[[NEWVEC:.*]] = linalg.init_tensor [] : tensor<i64> // CHECK: %[[NEWVEC:.*]] = linalg.init_tensor [] : tensor<i64>
// CHECK: %[[FILLVEC:.*]] = linalg.fill ins(%[[INI64]] : i64) outs(%[[NEWVEC]] : tensor<i64>) -> tensor<i64> // CHECK: %[[FILLVEC:.*]] = linalg.fill ins(%[[INI64]] : i64) outs(%[[NEWVEC]] : tensor<i64>) -> tensor<i64>
// CHECK: %[[OUTVEC:.*]] = torch_c.from_builtin_tensor %[[FILLVEC]] : tensor<i64> -> !torch.vtensor<[],si64> // CHECK: %[[OUTVEC:.*]] = torch_c.from_builtin_tensor %[[FILLVEC]] : tensor<i64> -> !torch.vtensor<[],si64>
// CHECK: return %[[OUTVEC]] : !torch.vtensor<[],si64> // CHECK: return %[[OUTVEC]] : !torch.vtensor<[],si64>
func @torch.prim.NumToTensor.Scalar$basic(%arg0: !torch.int) -> !torch.vtensor<[],si64> { func.func @torch.prim.NumToTensor.Scalar$basic(%arg0: !torch.int) -> !torch.vtensor<[],si64> {
%0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.int -> !torch.vtensor<[],si64> %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.int -> !torch.vtensor<[],si64>
return %0 : !torch.vtensor<[],si64> return %0 : !torch.vtensor<[],si64>
} }
// ----- // -----
// CHECK-LABEL: func @torch.tensor_static_info_cast$basic( // CHECK-LABEL: func.func @torch.tensor_static_info_cast$basic(
// CHECK-SAME: %[[VALUE_T:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[4],f32> { // CHECK-SAME: %[[VALUE_T:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[4],f32> {
// CHECK: %[[T:.*]] = torch_c.to_builtin_tensor %[[VALUE_T]] : !torch.vtensor<[?],f32> -> tensor<?xf32> // CHECK: %[[T:.*]] = torch_c.to_builtin_tensor %[[VALUE_T]] : !torch.vtensor<[?],f32> -> tensor<?xf32>
// CHECK: %[[T_CAST:.*]] = tensor.cast %[[T]] : tensor<?xf32> to tensor<4xf32> // CHECK: %[[T_CAST:.*]] = tensor.cast %[[T]] : tensor<?xf32> to tensor<4xf32>
// CHECK: %[[VALUE_T_CAST:.*]] = torch_c.from_builtin_tensor %[[T_CAST]] : tensor<4xf32> -> !torch.vtensor<[4],f32> // CHECK: %[[VALUE_T_CAST:.*]] = torch_c.from_builtin_tensor %[[T_CAST]] : tensor<4xf32> -> !torch.vtensor<[4],f32>
// CHECK: return %[[VALUE_T_CAST]] : !torch.vtensor<[4],f32> // CHECK: return %[[VALUE_T_CAST]] : !torch.vtensor<[4],f32>
func @torch.tensor_static_info_cast$basic(%t: !torch.vtensor<[?],f32>) -> !torch.vtensor<[4],f32> { func.func @torch.tensor_static_info_cast$basic(%t: !torch.vtensor<[?],f32>) -> !torch.vtensor<[4],f32> {
%t_cast = torch.tensor_static_info_cast %t : !torch.vtensor<[?],f32> to !torch.vtensor<[4],f32> %t_cast = torch.tensor_static_info_cast %t : !torch.vtensor<[?],f32> to !torch.vtensor<[4],f32>
return %t_cast : !torch.vtensor<[4],f32> return %t_cast : !torch.vtensor<[4],f32>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.neg // CHECK-LABEL: func.func @torch.aten.neg
// CHECK: linalg.generic {{.*}} { // CHECK: linalg.generic {{.*}} {
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: f32, %{{.*}}: f32): // CHECK-NEXT: ^bb0(%[[LHS:.*]]: f32, %{{.*}}: f32):
// CHECK-NEXT: %[[NEG:.*]] = arith.negf %[[LHS]] : f32 // CHECK-NEXT: %[[NEG:.*]] = arith.negf %[[LHS]] : f32
// CHECK-NEXT: linalg.yield %[[NEG]] : f32 // CHECK-NEXT: linalg.yield %[[NEG]] : f32
// CHECK-NEXT: } -> tensor<?x?xf32> // CHECK-NEXT: } -> tensor<?x?xf32>
func @torch.aten.neg(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.neg(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> %0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32> return %0 : !torch.vtensor<[?,?],f32>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.neg.bf16 // CHECK-LABEL: func.func @torch.aten.neg.bf16
// CHECK: linalg.generic {{.*}} { // CHECK: linalg.generic {{.*}} {
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: bf16, %{{.*}}: bf16): // CHECK-NEXT: ^bb0(%[[LHS:.*]]: bf16, %{{.*}}: bf16):
// CHECK-NEXT: %[[NEG:.*]] = arith.negf %[[LHS]] : bf16 // CHECK-NEXT: %[[NEG:.*]] = arith.negf %[[LHS]] : bf16
// CHECK-NEXT: linalg.yield %[[NEG]] : bf16 // CHECK-NEXT: linalg.yield %[[NEG]] : bf16
// CHECK-NEXT: } -> tensor<?x?xbf16> // CHECK-NEXT: } -> tensor<?x?xbf16>
func @torch.aten.neg.bf16(%arg0: !torch.vtensor<[?,?],bf16>) -> !torch.vtensor<[?,?],bf16> { func.func @torch.aten.neg.bf16(%arg0: !torch.vtensor<[?,?],bf16>) -> !torch.vtensor<[?,?],bf16> {
%0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],bf16> -> !torch.vtensor<[?,?],bf16> %0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],bf16> -> !torch.vtensor<[?,?],bf16>
return %0 : !torch.vtensor<[?,?],bf16> return %0 : !torch.vtensor<[?,?],bf16>
} }

View File

@ -1,7 +1,7 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -mlir-print-local-scope -verify-diagnostics | FileCheck %s // RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -mlir-print-local-scope -verify-diagnostics | FileCheck %s
// CHECK-LABEL: func @elementwise$unary( // CHECK-LABEL: func.func @elementwise$unary(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: %[[INIT_TENSOR:.*]] = linalg.init_tensor [] : tensor<f32> // CHECK: %[[INIT_TENSOR:.*]] = linalg.init_tensor [] : tensor<f32>
@ -14,12 +14,12 @@
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor<f32> -> !torch.vtensor<[],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor<f32> -> !torch.vtensor<[],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[],f32>
// CHECK: } // CHECK: }
func @elementwise$unary(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> { func.func @elementwise$unary(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
%0 = torch.aten.tanh %arg0 : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> %0 = torch.aten.tanh %arg0 : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
return %0 : !torch.vtensor<[],f32> return %0 : !torch.vtensor<[],f32>
} }
// CHECK-LABEL: func @elementwise$binary( // CHECK-LABEL: func.func @elementwise$binary(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[BUILTIN_ARG0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[BUILTIN_ARG0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
@ -41,23 +41,23 @@ func @elementwise$unary(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32>
// CHECK: %[[CASTED:.*]] = tensor.cast %[[GENERIC:.*]] : tensor<?x?xf32> to tensor<?x?xf32> // CHECK: %[[CASTED:.*]] = tensor.cast %[[GENERIC:.*]] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func @elementwise$binary(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @elementwise$binary(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?,?],f32> %0 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32> return %0 : !torch.vtensor<[?,?],f32>
} }
// CHECK-LABEL: func @elementwise$ternary( // CHECK-LABEL: func.func @elementwise$ternary(
// CHECK: linalg.generic {indexing_maps = [ // CHECK: linalg.generic {indexing_maps = [
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>, // CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d1, d2)>, // CHECK-SAME: affine_map<(d0, d1, d2) -> (d1, d2)>,
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d2)>, // CHECK-SAME: affine_map<(d0, d1, d2) -> (d2)>,
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>] // CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
func @elementwise$ternary(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>, %arg2: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?],f32> { func.func @elementwise$ternary(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>, %arg2: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?],f32> {
%0 = torch.aten.lerp.Tensor %arg0, %arg1, %arg2 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?,?,?],f32> %0 = torch.aten.lerp.Tensor %arg0, %arg1, %arg2 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?],f32> return %0 : !torch.vtensor<[?,?,?],f32>
} }
// CHECK-LABEL: func @elementwise$with_scalar_capture( // CHECK-LABEL: func.func @elementwise$with_scalar_capture(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?],f32>, // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],f32> { // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],f32> {
// CHECK: %[[C1:.*]] = torch.constant.int 1 // CHECK: %[[C1:.*]] = torch.constant.int 1
@ -69,18 +69,18 @@ func @elementwise$ternary(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vten
// CHECK: %[[RES:.*]] = arith.addf %[[LHS]], %[[SCALED]] : f32 // CHECK: %[[RES:.*]] = arith.addf %[[LHS]], %[[SCALED]] : f32
// CHECK: linalg.yield %[[RES]] : f32 // CHECK: linalg.yield %[[RES]] : f32
// CHECK: } -> tensor<?xf32> // CHECK: } -> tensor<?xf32>
func @elementwise$with_scalar_capture(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],f32> { func.func @elementwise$with_scalar_capture(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],f32> {
%int1 = torch.constant.int 1 %int1 = torch.constant.int 1
%0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[?],f32> %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[?],f32>
return %0 : !torch.vtensor<[?],f32> return %0 : !torch.vtensor<[?],f32>
} }
// CHECK-LABEL: func @elementwise$static_1( // CHECK-LABEL: func.func @elementwise$static_1(
// CHECK: linalg.generic {indexing_maps = [ // CHECK: linalg.generic {indexing_maps = [
// CHECK-SAME: affine_map<(d0) -> (d0)>, // CHECK-SAME: affine_map<(d0) -> (d0)>,
// CHECK-SAME: affine_map<(d0) -> (0)>, // CHECK-SAME: affine_map<(d0) -> (0)>,
// CHECK-SAME: affine_map<(d0) -> (d0)>] // CHECK-SAME: affine_map<(d0) -> (d0)>]
func @elementwise$static_1(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[1],f32>) -> !torch.vtensor<[?],f32> { func.func @elementwise$static_1(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[1],f32>) -> !torch.vtensor<[?],f32> {
%1 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[?],f32> %1 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[?],f32>
return %1 : !torch.vtensor<[?],f32> return %1 : !torch.vtensor<[?],f32>
} }

View File

@ -2,7 +2,7 @@
// ----- // -----
// CHECK-LABEL: func @torch.aten.flatten.using_ints$basic( // CHECK-LABEL: func.func @torch.aten.flatten.using_ints$basic(
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> { // CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2,3,3,5],f32> -> tensor<3x3x2x2x3x3x5xf32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2,3,3,5],f32> -> tensor<3x3x2x2x3x3x5xf32>
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3, 4], [5], [6]] : tensor<3x3x2x2x3x3x5xf32> into tensor<3x3x12x3x5xf32> // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3, 4], [5], [6]] : tensor<3x3x2x2x3x3x5xf32> into tensor<3x3x12x3x5xf32>
@ -10,7 +10,7 @@
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[DYNAMIC]] : tensor<3x3x?x3x5xf32> -> !torch.vtensor<[3,3,?,3,5],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[DYNAMIC]] : tensor<3x3x?x3x5xf32> -> !torch.vtensor<[3,3,?,3,5],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[3,3,?,3,5],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[3,3,?,3,5],f32>
func @torch.aten.flatten.using_ints$basic(%arg0: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> { func.func @torch.aten.flatten.using_ints$basic(%arg0: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
%int2 = torch.constant.int 2 %int2 = torch.constant.int 2
%int4 = torch.constant.int 4 %int4 = torch.constant.int 4
%0 = torch.aten.flatten.using_ints %arg0, %int2, %int4 : !torch.vtensor<[3,3,2,2,3,3,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,3,?,3,5],f32> %0 = torch.aten.flatten.using_ints %arg0, %int2, %int4 : !torch.vtensor<[3,3,2,2,3,3,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,3,?,3,5],f32>
@ -19,7 +19,7 @@ func @torch.aten.flatten.using_ints$basic(%arg0: !torch.vtensor<[3,3,2,2,3,3,5],
// ----- // -----
// CHECK-LABEL: func @torch.aten.flatten.using_ints$basic_negative( // CHECK-LABEL: func.func @torch.aten.flatten.using_ints$basic_negative(
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> { // CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2,3,3,5],f32> -> tensor<3x3x2x2x3x3x5xf32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2,3,3,5],f32> -> tensor<3x3x2x2x3x3x5xf32>
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3, 4], [5], [6]] : tensor<3x3x2x2x3x3x5xf32> into tensor<3x3x12x3x5xf32> // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3, 4], [5], [6]] : tensor<3x3x2x2x3x3x5xf32> into tensor<3x3x12x3x5xf32>
@ -27,7 +27,7 @@ func @torch.aten.flatten.using_ints$basic(%arg0: !torch.vtensor<[3,3,2,2,3,3,5],
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[DYNAMIC]] : tensor<3x3x?x3x5xf32> -> !torch.vtensor<[3,3,?,3,5],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[DYNAMIC]] : tensor<3x3x?x3x5xf32> -> !torch.vtensor<[3,3,?,3,5],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[3,3,?,3,5],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[3,3,?,3,5],f32>
func @torch.aten.flatten.using_ints$basic_negative(%arg0: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> { func.func @torch.aten.flatten.using_ints$basic_negative(%arg0: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
%int-5 = torch.constant.int -5 %int-5 = torch.constant.int -5
%int-3 = torch.constant.int -3 %int-3 = torch.constant.int -3
%0 = torch.aten.flatten.using_ints %arg0, %int-5, %int-3 : !torch.vtensor<[3,3,2,2,3,3,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,3,?,3,5],f32> %0 = torch.aten.flatten.using_ints %arg0, %int-5, %int-3 : !torch.vtensor<[3,3,2,2,3,3,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,3,?,3,5],f32>
@ -36,7 +36,7 @@ func @torch.aten.flatten.using_ints$basic_negative(%arg0: !torch.vtensor<[3,3,2,
// ----- // -----
// CHECK-LABEL: func @torch.aten.flatten.using_ints$flatten_front( // CHECK-LABEL: func.func @torch.aten.flatten.using_ints$flatten_front(
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2],f32> -> tensor<3x3x2x2xf32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2],f32> -> tensor<3x3x2x2xf32>
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1, 2], [3]] : tensor<3x3x2x2xf32> into tensor<18x2xf32> // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1, 2], [3]] : tensor<3x3x2x2xf32> into tensor<18x2xf32>
@ -44,7 +44,7 @@ func @torch.aten.flatten.using_ints$basic_negative(%arg0: !torch.vtensor<[3,3,2,
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[DYNAMIC]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[DYNAMIC]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func @torch.aten.flatten.using_ints$flatten_front(%arg0: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.flatten.using_ints$flatten_front(%arg0: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,?],f32> {
%int0 = torch.constant.int 0 %int0 = torch.constant.int 0
%int2 = torch.constant.int 2 %int2 = torch.constant.int 2
%0 = torch.aten.flatten.using_ints %arg0, %int0, %int2 : !torch.vtensor<[3,3,2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32> %0 = torch.aten.flatten.using_ints %arg0, %int0, %int2 : !torch.vtensor<[3,3,2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
@ -53,7 +53,7 @@ func @torch.aten.flatten.using_ints$flatten_front(%arg0: !torch.vtensor<[3,3,2,2
// ----- // -----
// CHECK-LABEL: func @torch.aten.flatten.using_ints$flatten_back( // CHECK-LABEL: func.func @torch.aten.flatten.using_ints$flatten_back(
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,12],f32> { // CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,12],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2],f32> -> tensor<3x3x2x2xf32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2],f32> -> tensor<3x3x2x2xf32>
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1, 2, 3]] : tensor<3x3x2x2xf32> into tensor<3x12xf32> // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1, 2, 3]] : tensor<3x3x2x2xf32> into tensor<3x12xf32>
@ -61,7 +61,7 @@ func @torch.aten.flatten.using_ints$flatten_front(%arg0: !torch.vtensor<[3,3,2,2
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[DYNAMIC]] : tensor<?x12xf32> -> !torch.vtensor<[?,12],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[DYNAMIC]] : tensor<?x12xf32> -> !torch.vtensor<[?,12],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,12],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,12],f32>
func @torch.aten.flatten.using_ints$flatten_back(%arg0: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,12],f32> { func.func @torch.aten.flatten.using_ints$flatten_back(%arg0: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,12],f32> {
%int1 = torch.constant.int 1 %int1 = torch.constant.int 1
%int-1 = torch.constant.int -1 %int-1 = torch.constant.int -1
%0 = torch.aten.flatten.using_ints %arg0, %int1, %int-1 : !torch.vtensor<[3,3,2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,12],f32> %0 = torch.aten.flatten.using_ints %arg0, %int1, %int-1 : !torch.vtensor<[3,3,2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,12],f32>
@ -70,14 +70,14 @@ func @torch.aten.flatten.using_ints$flatten_back(%arg0: !torch.vtensor<[3,3,2,2]
// ----- // -----
// CHECK-LABEL: func @torch.aten.flatten.using_ints$rank0( // CHECK-LABEL: func.func @torch.aten.flatten.using_ints$rank0(
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { // CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor<f32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: %[[COLLAPSED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32> // CHECK: %[[COLLAPSED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[COLLAPSED]] : tensor<1xf32> -> !torch.vtensor<[1],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[COLLAPSED]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>
func @torch.aten.flatten.using_ints$rank0(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { func.func @torch.aten.flatten.using_ints$rank0(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
%int0 = torch.constant.int 0 %int0 = torch.constant.int 0
%0 = torch.aten.flatten.using_ints %arg0, %int0, %int0 : !torch.vtensor<[],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32> %0 = torch.aten.flatten.using_ints %arg0, %int0, %int0 : !torch.vtensor<[],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
return %0 : !torch.vtensor<[1],f32> return %0 : !torch.vtensor<[1],f32>

View File

@ -3,61 +3,61 @@
// ----- // -----
// CHECK-LABEL: func @torch.aten.unsqueeze$basic( // CHECK-LABEL: func.func @torch.aten.unsqueeze$basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32> // CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<1xf32> -> !torch.vtensor<[1],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>
func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { func.func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
%int0 = torch.constant.int 0 %int0 = torch.constant.int 0
%0 = torch.aten.unsqueeze %arg0, %int0 : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[1],f32> %0 = torch.aten.unsqueeze %arg0, %int0 : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[1],f32>
return %0 : !torch.vtensor<[1],f32> return %0 : !torch.vtensor<[1],f32>
} }
// CHECK-LABEL: func @torch.aten.unsqueeze$basic_negative( // CHECK-LABEL: func.func @torch.aten.unsqueeze$basic_negative(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32> // CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<1xf32> -> !torch.vtensor<[1],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>
func @torch.aten.unsqueeze$basic_negative(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { func.func @torch.aten.unsqueeze$basic_negative(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
%int-1 = torch.constant.int -1 %int-1 = torch.constant.int -1
%0 = torch.aten.unsqueeze %arg0, %int-1 : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[1],f32> %0 = torch.aten.unsqueeze %arg0, %int-1 : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[1],f32>
return %0 : !torch.vtensor<[1],f32> return %0 : !torch.vtensor<[1],f32>
} }
// CHECK-LABEL: func @torch.aten.unsqueeze$higher_rank_front( // CHECK-LABEL: func.func @torch.aten.unsqueeze$higher_rank_front(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1], [2], [3]] : tensor<2x3x4xf32> into tensor<1x2x3x4xf32> // CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1], [2], [3]] : tensor<2x3x4xf32> into tensor<1x2x3x4xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<1x2x3x4xf32> -> !torch.vtensor<[1,2,3,4],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<1x2x3x4xf32> -> !torch.vtensor<[1,2,3,4],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1,2,3,4],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[1,2,3,4],f32>
func @torch.aten.unsqueeze$higher_rank_front(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> { func.func @torch.aten.unsqueeze$higher_rank_front(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> {
%int0 = torch.constant.int 0 %int0 = torch.constant.int 0
%0 = torch.aten.unsqueeze %arg0, %int0 : !torch.vtensor<[2,3,4],f32>, !torch.int -> !torch.vtensor<[1,2,3,4],f32> %0 = torch.aten.unsqueeze %arg0, %int0 : !torch.vtensor<[2,3,4],f32>, !torch.int -> !torch.vtensor<[1,2,3,4],f32>
return %0 : !torch.vtensor<[1,2,3,4],f32> return %0 : !torch.vtensor<[1,2,3,4],f32>
} }
// CHECK-LABEL: func @torch.aten.unsqueeze$higher_rank_back( // CHECK-LABEL: func.func @torch.aten.unsqueeze$higher_rank_back(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,1],f32> { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,1],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] : tensor<2x3x4xf32> into tensor<2x3x4x1xf32> // CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] : tensor<2x3x4xf32> into tensor<2x3x4x1xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<2x3x4x1xf32> -> !torch.vtensor<[2,3,4,1],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<2x3x4x1xf32> -> !torch.vtensor<[2,3,4,1],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,4,1],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,4,1],f32>
func @torch.aten.unsqueeze$higher_rank_back(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,1],f32> { func.func @torch.aten.unsqueeze$higher_rank_back(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,1],f32> {
%int-1 = torch.constant.int -1 %int-1 = torch.constant.int -1
%0 = torch.aten.unsqueeze %arg0, %int-1 : !torch.vtensor<[2,3,4],f32>, !torch.int -> !torch.vtensor<[2,3,4,1],f32> %0 = torch.aten.unsqueeze %arg0, %int-1 : !torch.vtensor<[2,3,4],f32>, !torch.int -> !torch.vtensor<[2,3,4,1],f32>
return %0 : !torch.vtensor<[2,3,4,1],f32> return %0 : !torch.vtensor<[2,3,4,1],f32>
} }
// CHECK-LABEL: func @torch.aten.unsqueeze$higher_rank_middle( // CHECK-LABEL: func.func @torch.aten.unsqueeze$higher_rank_middle(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1,4],f32> { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1,4],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] : tensor<2x3x4xf32> into tensor<2x3x1x4xf32> // CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] : tensor<2x3x4xf32> into tensor<2x3x1x4xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<2x3x1x4xf32> -> !torch.vtensor<[2,3,1,4],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<2x3x1x4xf32> -> !torch.vtensor<[2,3,1,4],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,1,4],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,1,4],f32>
func @torch.aten.unsqueeze$higher_rank_middle(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1,4],f32> { func.func @torch.aten.unsqueeze$higher_rank_middle(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1,4],f32> {
%int2 = torch.constant.int 2 %int2 = torch.constant.int 2
%0 = torch.aten.unsqueeze %arg0, %int2 : !torch.vtensor<[2,3,4],f32>, !torch.int -> !torch.vtensor<[2,3,1,4],f32> %0 = torch.aten.unsqueeze %arg0, %int2 : !torch.vtensor<[2,3,4],f32>, !torch.int -> !torch.vtensor<[2,3,1,4],f32>
return %0 : !torch.vtensor<[2,3,1,4],f32> return %0 : !torch.vtensor<[2,3,1,4],f32>

View File

@ -1,6 +1,6 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-scf | FileCheck %s // RUN: torch-mlir-opt <%s -convert-torch-to-scf | FileCheck %s
// CHECK-LABEL: func @torch.prim.if( // CHECK-LABEL: func.func @torch.prim.if(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.bool) -> !torch.int { // CHECK-SAME: %[[VAL_0:.*]]: !torch.bool) -> !torch.int {
// CHECK: %[[VAL_1:.*]] = torch_c.to_i1 %[[VAL_0]] // CHECK: %[[VAL_1:.*]] = torch_c.to_i1 %[[VAL_0]]
// CHECK: %[[VAL_2:.*]] = torch.constant.int 2 // CHECK: %[[VAL_2:.*]] = torch.constant.int 2
@ -14,7 +14,7 @@
// CHECK: } // CHECK: }
// CHECK: %[[VAL_7:.*]] = torch_c.from_i64 %[[VAL_8:.*]] // CHECK: %[[VAL_7:.*]] = torch_c.from_i64 %[[VAL_8:.*]]
// CHECK: return %[[VAL_7]] : !torch.int // CHECK: return %[[VAL_7]] : !torch.int
func @torch.prim.if(%arg0: !torch.bool) -> !torch.int { func.func @torch.prim.if(%arg0: !torch.bool) -> !torch.int {
%int2 = torch.constant.int 2 %int2 = torch.constant.int 2
%int1 = torch.constant.int 1 %int1 = torch.constant.int 1
%0 = torch.prim.If %arg0 -> (!torch.int) { %0 = torch.prim.If %arg0 -> (!torch.int) {
@ -25,7 +25,7 @@ func @torch.prim.if(%arg0: !torch.bool) -> !torch.int {
return %0 : !torch.int return %0 : !torch.int
} }
// CHECK-LABEL: func @aten.prim.if$nested( // CHECK-LABEL: func.func @aten.prim.if$nested(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.bool, // CHECK-SAME: %[[VAL_0:.*]]: !torch.bool,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.bool) -> !torch.int { // CHECK-SAME: %[[VAL_1:.*]]: !torch.bool) -> !torch.int {
// CHECK: %[[VAL_2:.*]] = torch_c.to_i1 %[[VAL_0]] // CHECK: %[[VAL_2:.*]] = torch_c.to_i1 %[[VAL_0]]
@ -48,7 +48,7 @@ func @torch.prim.if(%arg0: !torch.bool) -> !torch.int {
// CHECK: } // CHECK: }
// CHECK: %[[VAL_13:.*]] = torch_c.from_i64 %[[VAL_14:.*]] // CHECK: %[[VAL_13:.*]] = torch_c.from_i64 %[[VAL_14:.*]]
// CHECK: return %[[VAL_13]] : !torch.int // CHECK: return %[[VAL_13]] : !torch.int
func @aten.prim.if$nested(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.int { func.func @aten.prim.if$nested(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.int {
%int2 = torch.constant.int 2 %int2 = torch.constant.int 2
%int3 = torch.constant.int 3 %int3 = torch.constant.int 3
%int4 = torch.constant.int 4 %int4 = torch.constant.int 4
@ -65,7 +65,7 @@ func @aten.prim.if$nested(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.int
return %0 : !torch.int return %0 : !torch.int
} }
// CHECK-LABEL: func @torch.prim.loop$while // CHECK-LABEL: func.func @torch.prim.loop$while
// CHECK-SAME: (%[[ARG0:.*]]: !torch.int) -> !torch.float { // CHECK-SAME: (%[[ARG0:.*]]: !torch.int) -> !torch.float {
// CHECK: %[[TORCH_FLOAT_VAL:.*]] = torch.constant.float // CHECK: %[[TORCH_FLOAT_VAL:.*]] = torch.constant.float
// CHECK-NEXT: %[[FLOAT_VAL:.*]] = torch_c.to_f64 %[[TORCH_FLOAT_VAL]] // CHECK-NEXT: %[[FLOAT_VAL:.*]] = torch_c.to_f64 %[[TORCH_FLOAT_VAL]]
@ -86,7 +86,7 @@ func @aten.prim.if$nested(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.int
// CHECK-NEXT: } // CHECK-NEXT: }
// CHECK-NEXT: %[[TORCH_LOOP:.*]] = torch_c.from_f64 %[[LOOP]] // CHECK-NEXT: %[[TORCH_LOOP:.*]] = torch_c.from_f64 %[[LOOP]]
// CHECK-NEXT: return %[[TORCH_LOOP]] : !torch.float // CHECK-NEXT: return %[[TORCH_LOOP]] : !torch.float
func @torch.prim.loop$while(%arg0: !torch.int) -> !torch.float { func.func @torch.prim.loop$while(%arg0: !torch.int) -> !torch.float {
%float3.200000e00 = torch.constant.float 3.200000e+00 %float3.200000e00 = torch.constant.float 3.200000e+00
%int9223372036854775807 = torch.constant.int 9223372036854775807 %int9223372036854775807 = torch.constant.int 9223372036854775807
%0 = torch.aten.lt.float_int %float3.200000e00, %arg0 : !torch.float, !torch.int -> !torch.bool %0 = torch.aten.lt.float_int %float3.200000e00, %arg0 : !torch.float, !torch.int -> !torch.bool
@ -99,7 +99,7 @@ func @torch.prim.loop$while(%arg0: !torch.int) -> !torch.float {
return %1 : !torch.float return %1 : !torch.float
} }
// CHECK-LABEL: func @torch.prim.loop$while_with_multiple_values // CHECK-LABEL: func.func @torch.prim.loop$while_with_multiple_values
// CHECK-SAME: () -> (!torch.float, !torch.float) { // CHECK-SAME: () -> (!torch.float, !torch.float) {
// CHECK: %[[TORCH_FLOAT_VAL_0:.*]] = torch.constant.float // CHECK: %[[TORCH_FLOAT_VAL_0:.*]] = torch.constant.float
// CHECK-NEXT: %[[FLOAT_VAL_0:.*]] = torch_c.to_f64 %[[TORCH_FLOAT_VAL_0]] // CHECK-NEXT: %[[FLOAT_VAL_0:.*]] = torch_c.to_f64 %[[TORCH_FLOAT_VAL_0]]
@ -127,7 +127,7 @@ func @torch.prim.loop$while(%arg0: !torch.int) -> !torch.float {
// CHECK-NEXT: %[[TORCH_LOOP_0:.*]] = torch_c.from_f64 %[[LOOP]]#0 // CHECK-NEXT: %[[TORCH_LOOP_0:.*]] = torch_c.from_f64 %[[LOOP]]#0
// CHECK-NEXT: %[[TORCH_LOOP_1:.*]] = torch_c.from_f64 %[[LOOP]]#1 // CHECK-NEXT: %[[TORCH_LOOP_1:.*]] = torch_c.from_f64 %[[LOOP]]#1
// CHECK-NEXT: return %[[TORCH_LOOP_0]], %[[TORCH_LOOP_1]] : !torch.float, !torch.float // CHECK-NEXT: return %[[TORCH_LOOP_0]], %[[TORCH_LOOP_1]] : !torch.float, !torch.float
func @torch.prim.loop$while_with_multiple_values() -> (!torch.float, !torch.float) { func.func @torch.prim.loop$while_with_multiple_values() -> (!torch.float, !torch.float) {
%float3.200000e00 = torch.constant.float 3.200000e+00 %float3.200000e00 = torch.constant.float 3.200000e+00
%int9223372036854775807 = torch.constant.int 9223372036854775807 %int9223372036854775807 = torch.constant.int 9223372036854775807
%float9.0 = torch.constant.float 9.0 %float9.0 = torch.constant.float 9.0
@ -143,7 +143,7 @@ func @torch.prim.loop$while_with_multiple_values() -> (!torch.float, !torch.floa
return %1#0, %1#1 : !torch.float, !torch.float return %1#0, %1#1 : !torch.float, !torch.float
} }
// CHECK-LABEL: func @torch.prim.Loop$for // CHECK-LABEL: func.func @torch.prim.Loop$for
// CHECK-SAME: (%[[TORCH_ARG0:.*]]: !torch.int) -> !torch.float { // CHECK-SAME: (%[[TORCH_ARG0:.*]]: !torch.int) -> !torch.float {
// CHECK: %[[ARG0:.*]] = torch_c.to_i64 %[[TORCH_ARG0]] // CHECK: %[[ARG0:.*]] = torch_c.to_i64 %[[TORCH_ARG0]]
// CHECK-NEXT: %{{.*}} = torch.constant.bool true // CHECK-NEXT: %{{.*}} = torch.constant.bool true
@ -164,7 +164,7 @@ func @torch.prim.loop$while_with_multiple_values() -> (!torch.float, !torch.floa
// CHECK-NEXT: %[[RETURN:.*]] = torch_c.from_f64 %[[LOOP]] // CHECK-NEXT: %[[RETURN:.*]] = torch_c.from_f64 %[[LOOP]]
// CHECK-NEXT: return %[[RETURN]] : !torch.float // CHECK-NEXT: return %[[RETURN]] : !torch.float
// CHECK-NEXT: } // CHECK-NEXT: }
func @torch.prim.Loop$for(%arg0: !torch.int) -> !torch.float { func.func @torch.prim.Loop$for(%arg0: !torch.int) -> !torch.float {
%true = torch.constant.bool true %true = torch.constant.bool true
%float0.000000e00 = torch.constant.float 0.000000e+00 %float0.000000e00 = torch.constant.float 0.000000e+00
%0 = torch.prim.Loop %arg0, %true, init(%float0.000000e00) { %0 = torch.prim.Loop %arg0, %true, init(%float0.000000e00) {
@ -175,7 +175,7 @@ func @torch.prim.Loop$for(%arg0: !torch.int) -> !torch.float {
return %0 : !torch.float return %0 : !torch.float
} }
// CHECK-LABEL: func @torch.prim.Loop$for_with_multiple_results // CHECK-LABEL: func.func @torch.prim.Loop$for_with_multiple_results
// CHECK-SAME: (%[[TORCH_ARG0:.*]]: !torch.int) -> (!torch.float, !torch.float) { // CHECK-SAME: (%[[TORCH_ARG0:.*]]: !torch.int) -> (!torch.float, !torch.float) {
// CHECK: %[[ARG0:.*]] = torch_c.to_i64 %[[TORCH_ARG0]] // CHECK: %[[ARG0:.*]] = torch_c.to_i64 %[[TORCH_ARG0]]
// CHECK-NEXT: %{{.*}} = torch.constant.bool true // CHECK-NEXT: %{{.*}} = torch.constant.bool true
@ -202,7 +202,7 @@ func @torch.prim.Loop$for(%arg0: !torch.int) -> !torch.float {
// CHECK-NEXT: %[[RETURN_1:.*]] = torch_c.from_f64 %[[LOOP]]#1 // CHECK-NEXT: %[[RETURN_1:.*]] = torch_c.from_f64 %[[LOOP]]#1
// CHECK-NEXT: return %[[RETURN_0]], %[[RETURN_1]] : !torch.float, !torch.float // CHECK-NEXT: return %[[RETURN_0]], %[[RETURN_1]] : !torch.float, !torch.float
// CHECK-NEXT: } // CHECK-NEXT: }
func @torch.prim.Loop$for_with_multiple_results(%arg0: !torch.int) -> (!torch.float, !torch.float) { func.func @torch.prim.Loop$for_with_multiple_results(%arg0: !torch.int) -> (!torch.float, !torch.float) {
%true = torch.constant.bool true %true = torch.constant.bool true
%float0.000000e00 = torch.constant.float 0.000000e+00 %float0.000000e00 = torch.constant.float 0.000000e+00
%float9.0 = torch.constant.float 9.0 %float9.0 = torch.constant.float 9.0

View File

@ -1,19 +1,19 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-std | FileCheck %s // RUN: torch-mlir-opt <%s -convert-torch-to-std | FileCheck %s
// CHECK-LABEL: func @torch.aten.dim( // CHECK-LABEL: func.func @torch.aten.dim(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<*,f32>) -> !torch.int { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<*,f32>) -> !torch.int {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<*,f32> -> tensor<*xf32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<*,f32> -> tensor<*xf32>
// CHECK: %[[RANK:.*]] = tensor.rank %[[BUILTIN_TENSOR]] : tensor<*xf32> // CHECK: %[[RANK:.*]] = tensor.rank %[[BUILTIN_TENSOR]] : tensor<*xf32>
// CHECK: %[[RANK_I64:.*]] = arith.index_cast %[[RANK]] : index to i64 // CHECK: %[[RANK_I64:.*]] = arith.index_cast %[[RANK]] : index to i64
// CHECK: %[[RANK_TORCH_INT:.*]] = torch_c.from_i64 %[[RANK_I64]] // CHECK: %[[RANK_TORCH_INT:.*]] = torch_c.from_i64 %[[RANK_I64]]
// CHECK: return %[[RANK_TORCH_INT]] : !torch.int // CHECK: return %[[RANK_TORCH_INT]] : !torch.int
func @torch.aten.dim(%arg0: !torch.vtensor<*,f32>) -> !torch.int { func.func @torch.aten.dim(%arg0: !torch.vtensor<*,f32>) -> !torch.int {
%0 = torch.aten.dim %arg0 : !torch.vtensor<*,f32> -> !torch.int %0 = torch.aten.dim %arg0 : !torch.vtensor<*,f32> -> !torch.int
return %0 : !torch.int return %0 : !torch.int
} }
// CHECK-LABEL: func @torch.runtime.assert( // CHECK-LABEL: func.func @torch.runtime.assert(
// CHECK-SAME: %[[X:.*]]: !torch.int, // CHECK-SAME: %[[X:.*]]: !torch.int,
// CHECK-SAME: %[[Y:.*]]: !torch.int) { // CHECK-SAME: %[[Y:.*]]: !torch.int) {
// CHECK: %[[X_I64:.*]] = torch_c.to_i64 %[[X]] // CHECK: %[[X_I64:.*]] = torch_c.to_i64 %[[X]]
@ -21,13 +21,13 @@ func @torch.aten.dim(%arg0: !torch.vtensor<*,f32>) -> !torch.int {
// CHECK: %[[CMP:.*]] = arith.cmpi ne, %[[X_I64]], %[[Y_I64]] : i64 // CHECK: %[[CMP:.*]] = arith.cmpi ne, %[[X_I64]], %[[Y_I64]] : i64
// CHECK: assert %[[CMP]], "x must not be equal to y" // CHECK: assert %[[CMP]], "x must not be equal to y"
// CHECK: return // CHECK: return
func @torch.runtime.assert(%arg0: !torch.int, %arg1: !torch.int) { func.func @torch.runtime.assert(%arg0: !torch.int, %arg1: !torch.int) {
%0 = torch.aten.ne.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool %0 = torch.aten.ne.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool
torch.runtime.assert %0, "x must not be equal to y" torch.runtime.assert %0, "x must not be equal to y"
return return
} }
// CHECK-LABEL: func @torch.aten.ne.int( // CHECK-LABEL: func.func @torch.aten.ne.int(
// CHECK-SAME: %[[LHS:.*]]: !torch.int, // CHECK-SAME: %[[LHS:.*]]: !torch.int,
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool { // CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {
// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]] // CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
@ -35,12 +35,12 @@ func @torch.runtime.assert(%arg0: !torch.int, %arg1: !torch.int) {
// CHECK: %[[CMP:.*]] = arith.cmpi ne, %[[LHS_I64]], %[[RHS_I64]] : i64 // CHECK: %[[CMP:.*]] = arith.cmpi ne, %[[LHS_I64]], %[[RHS_I64]] : i64
// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]] // CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]]
// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool // CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool
func @torch.aten.ne.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool { func.func @torch.aten.ne.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool {
%0 = torch.aten.ne.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool %0 = torch.aten.ne.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool
return %0 : !torch.bool return %0 : !torch.bool
} }
// CHECK-LABEL: func @torch.aten.eq.int( // CHECK-LABEL: func.func @torch.aten.eq.int(
// CHECK-SAME: %[[LHS:.*]]: !torch.int, // CHECK-SAME: %[[LHS:.*]]: !torch.int,
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool { // CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {
// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]] // CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
@ -48,12 +48,12 @@ func @torch.aten.ne.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool {
// CHECK: %[[CMP:.*]] = arith.cmpi eq, %[[LHS_I64]], %[[RHS_I64]] : i64 // CHECK: %[[CMP:.*]] = arith.cmpi eq, %[[LHS_I64]], %[[RHS_I64]] : i64
// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]] // CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]]
// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool // CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool
func @torch.aten.eq.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool { func.func @torch.aten.eq.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool {
%0 = torch.aten.eq.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool %0 = torch.aten.eq.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool
return %0 : !torch.bool return %0 : !torch.bool
} }
// CHECK-LABEL: func @torch.aten.gt.int( // CHECK-LABEL: func.func @torch.aten.gt.int(
// CHECK-SAME: %[[LHS:.*]]: !torch.int, // CHECK-SAME: %[[LHS:.*]]: !torch.int,
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool { // CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {
// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]] // CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
@ -61,48 +61,48 @@ func @torch.aten.eq.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool {
// CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[LHS_I64]], %[[RHS_I64]] : i64 // CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[LHS_I64]], %[[RHS_I64]] : i64
// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]] // CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]]
// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool // CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool
func @torch.aten.gt.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool { func.func @torch.aten.gt.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool {
%0 = torch.aten.gt.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool %0 = torch.aten.gt.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool
return %0 : !torch.bool return %0 : !torch.bool
} }
// CHECK-LABEL: func @torch.vtensor.literal() -> !torch.vtensor<[],f32> { // CHECK-LABEL: func.func @torch.vtensor.literal() -> !torch.vtensor<[],f32> {
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : tensor<f32> // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VTENSOR:.*]] = torch_c.from_builtin_tensor %[[CST]] : tensor<f32> -> !torch.vtensor<[],f32> // CHECK: %[[VTENSOR:.*]] = torch_c.from_builtin_tensor %[[CST]] : tensor<f32> -> !torch.vtensor<[],f32>
// CHECK: return %[[VTENSOR]] : !torch.vtensor<[],f32> // CHECK: return %[[VTENSOR]] : !torch.vtensor<[],f32>
func @torch.vtensor.literal() -> !torch.vtensor<[],f32> { func.func @torch.vtensor.literal() -> !torch.vtensor<[],f32> {
%0 = torch.vtensor.literal(dense<0.0> : tensor<f32>) : !torch.vtensor<[],f32> %0 = torch.vtensor.literal(dense<0.0> : tensor<f32>) : !torch.vtensor<[],f32>
return %0 : !torch.vtensor<[],f32> return %0 : !torch.vtensor<[],f32>
} }
// CHECK-LABEL: func @torch.constant.bool() -> !torch.bool { // CHECK-LABEL: func.func @torch.constant.bool() -> !torch.bool {
// CHECK: %[[CST:.*]] = arith.constant true // CHECK: %[[CST:.*]] = arith.constant true
// CHECK: %[[BOOL:.*]] = torch_c.from_i1 %[[CST]] // CHECK: %[[BOOL:.*]] = torch_c.from_i1 %[[CST]]
// CHECK: return %[[BOOL]] : !torch.bool // CHECK: return %[[BOOL]] : !torch.bool
func @torch.constant.bool() -> !torch.bool { func.func @torch.constant.bool() -> !torch.bool {
%true = torch.constant.bool true %true = torch.constant.bool true
return %true : !torch.bool return %true : !torch.bool
} }
// CHECK-LABEL: func @torch.constant.float() -> !torch.float { // CHECK-LABEL: func.func @torch.constant.float() -> !torch.float {
// CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f64 // CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f64
// CHECK: %[[FLOAT:.*]] = torch_c.from_f64 %[[CST]] // CHECK: %[[FLOAT:.*]] = torch_c.from_f64 %[[CST]]
// CHECK: return %[[FLOAT]] : !torch.float // CHECK: return %[[FLOAT]] : !torch.float
func @torch.constant.float() -> !torch.float { func.func @torch.constant.float() -> !torch.float {
%float = torch.constant.float 1.000000e+00 %float = torch.constant.float 1.000000e+00
return %float : !torch.float return %float : !torch.float
} }
// CHECK-LABEL: func @torch.constant.int() -> !torch.int { // CHECK-LABEL: func.func @torch.constant.int() -> !torch.int {
// CHECK: %[[CST:.*]] = arith.constant 1 : i64 // CHECK: %[[CST:.*]] = arith.constant 1 : i64
// CHECK: %[[INT:.*]] = torch_c.from_i64 %[[CST]] // CHECK: %[[INT:.*]] = torch_c.from_i64 %[[CST]]
// CHECK: return %[[INT]] : !torch.int // CHECK: return %[[INT]] : !torch.int
func @torch.constant.int() -> !torch.int { func.func @torch.constant.int() -> !torch.int {
%int1 = torch.constant.int 1 %int1 = torch.constant.int 1
return %int1 : !torch.int return %int1 : !torch.int
} }
// CHECK-LABEL: func @torch.aten.add.int( // CHECK-LABEL: func.func @torch.aten.add.int(
// CHECK-SAME: %[[LHS:.*]]: !torch.int, // CHECK-SAME: %[[LHS:.*]]: !torch.int,
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.int { // CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.int {
// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]] // CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
@ -110,12 +110,12 @@ func @torch.constant.int() -> !torch.int {
// CHECK: %[[ADD:.*]] = arith.addi %[[LHS_I64:.*]], [[RHS_I64:.*]] : i64 // CHECK: %[[ADD:.*]] = arith.addi %[[LHS_I64:.*]], [[RHS_I64:.*]] : i64
// CHECK: %[[OUT:.*]] = torch_c.from_i64 %[[INT:.*]] // CHECK: %[[OUT:.*]] = torch_c.from_i64 %[[INT:.*]]
// CHECK: return %[[OUT:.*]] : !torch.int // CHECK: return %[[OUT:.*]] : !torch.int
func @torch.aten.add.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int { func.func @torch.aten.add.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
%0 = torch.aten.add.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int %0 = torch.aten.add.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int
return %0 : !torch.int return %0 : !torch.int
} }
// CHECK-LABEL: func @torch.aten.sub.int( // CHECK-LABEL: func.func @torch.aten.sub.int(
// CHECK-SAME: %[[LHS:.*]]: !torch.int, // CHECK-SAME: %[[LHS:.*]]: !torch.int,
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.int { // CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.int {
// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]] // CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
@ -123,12 +123,12 @@ func @torch.aten.add.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
// CHECK: %[[SUB:.*]] = arith.subi %[[LHS_I64:.*]], [[RHS_I64:.*]] : i64 // CHECK: %[[SUB:.*]] = arith.subi %[[LHS_I64:.*]], [[RHS_I64:.*]] : i64
// CHECK: %[[OUT:.*]] = torch_c.from_i64 %[[INT:.*]] // CHECK: %[[OUT:.*]] = torch_c.from_i64 %[[INT:.*]]
// CHECK: return %[[OUT:.*]] : !torch.int // CHECK: return %[[OUT:.*]] : !torch.int
func @torch.aten.sub.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int { func.func @torch.aten.sub.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
%0 = torch.aten.sub.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int %0 = torch.aten.sub.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int
return %0 : !torch.int return %0 : !torch.int
} }
// CHECK-LABEL: func @torch.aten.sub.float( // CHECK-LABEL: func.func @torch.aten.sub.float(
// CHECK-SAME: %[[LHS:.*]]: !torch.float, // CHECK-SAME: %[[LHS:.*]]: !torch.float,
// CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.float { // CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.float {
// CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]] // CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]]
@ -136,12 +136,12 @@ func @torch.aten.sub.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
// CHECK: %[[SUB:.*]] = arith.subf %[[LHS_F64:.*]], [[RHS_F64:.*]] : f64 // CHECK: %[[SUB:.*]] = arith.subf %[[LHS_F64:.*]], [[RHS_F64:.*]] : f64
// CHECK: %[[OUT:.*]] = torch_c.from_f64 %[[SUB:.*]] // CHECK: %[[OUT:.*]] = torch_c.from_f64 %[[SUB:.*]]
// CHECK: return %[[OUT:.*]] : !torch.float // CHECK: return %[[OUT:.*]] : !torch.float
func @torch.aten.sub.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.float { func.func @torch.aten.sub.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.float {
%0 = torch.aten.sub.float %arg0, %arg1 : !torch.float, !torch.float -> !torch.float %0 = torch.aten.sub.float %arg0, %arg1 : !torch.float, !torch.float -> !torch.float
return %0 : !torch.float return %0 : !torch.float
} }
// CHECK-LABEL: func @torch.aten.mul.int( // CHECK-LABEL: func.func @torch.aten.mul.int(
// CHECK-SAME: %[[LHS:.*]]: !torch.int, // CHECK-SAME: %[[LHS:.*]]: !torch.int,
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.int { // CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.int {
// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]] // CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
@ -149,12 +149,12 @@ func @torch.aten.sub.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.f
// CHECK: %[[MUL:.*]] = arith.muli %[[LHS_I64:.*]], [[RHS_I64:.*]] : i64 // CHECK: %[[MUL:.*]] = arith.muli %[[LHS_I64:.*]], [[RHS_I64:.*]] : i64
// CHECK: %[[OUT:.*]] = torch_c.from_i64 %[[MUL:.*]] // CHECK: %[[OUT:.*]] = torch_c.from_i64 %[[MUL:.*]]
// CHECK: return %[[OUT:.*]] : !torch.int // CHECK: return %[[OUT:.*]] : !torch.int
func @torch.aten.mul.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int { func.func @torch.aten.mul.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
%0 = torch.aten.mul.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int %0 = torch.aten.mul.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int
return %0 : !torch.int return %0 : !torch.int
} }
// CHECK-LABEL: func @torch.aten.div.float( // CHECK-LABEL: func.func @torch.aten.div.float(
// CHECK-SAME: %[[LHS:.*]]: !torch.float, // CHECK-SAME: %[[LHS:.*]]: !torch.float,
// CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.float { // CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.float {
// CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]] // CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]]
@ -162,12 +162,12 @@ func @torch.aten.mul.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {
// CHECK: %[[SUB:.*]] = arith.divf %[[LHS_F64:.*]], [[RHS_F64:.*]] : f64 // CHECK: %[[SUB:.*]] = arith.divf %[[LHS_F64:.*]], [[RHS_F64:.*]] : f64
// CHECK: %[[OUT:.*]] = torch_c.from_f64 %[[SUB:.*]] // CHECK: %[[OUT:.*]] = torch_c.from_f64 %[[SUB:.*]]
// CHECK: return %[[OUT:.*]] : !torch.float // CHECK: return %[[OUT:.*]] : !torch.float
func @torch.aten.div.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.float { func.func @torch.aten.div.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.float {
%0 = torch.aten.div.float %arg0, %arg1 : !torch.float, !torch.float -> !torch.float %0 = torch.aten.div.float %arg0, %arg1 : !torch.float, !torch.float -> !torch.float
return %0 : !torch.float return %0 : !torch.float
} }
// CHECK-LABEL: func @torch.aten.ge.float( // CHECK-LABEL: func.func @torch.aten.ge.float(
// CHECK-SAME: %[[LHS:.*]]: !torch.float, // CHECK-SAME: %[[LHS:.*]]: !torch.float,
// CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.bool { // CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.bool {
// CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]] // CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]]
@ -175,12 +175,12 @@ func @torch.aten.div.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.f
// CHECK: %[[CMP:.*]] = arith.cmpf uge, %[[LHS_F64]], %[[RHS_F64]] : f64 // CHECK: %[[CMP:.*]] = arith.cmpf uge, %[[LHS_F64]], %[[RHS_F64]] : f64
// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]] // CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]]
// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool // CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool
func @torch.aten.ge.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.bool { func.func @torch.aten.ge.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.bool {
%0 = torch.aten.ge.float %arg0, %arg1 : !torch.float, !torch.float -> !torch.bool %0 = torch.aten.ge.float %arg0, %arg1 : !torch.float, !torch.float -> !torch.bool
return %0 : !torch.bool return %0 : !torch.bool
} }
// CHECK-LABEL: func @torch.aten.ge.float_int( // CHECK-LABEL: func.func @torch.aten.ge.float_int(
// CHECK-SAME: %[[LHS:.*]]: !torch.float, // CHECK-SAME: %[[LHS:.*]]: !torch.float,
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool { // CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {
// CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]] // CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]]
@ -189,12 +189,12 @@ func @torch.aten.ge.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.bo
// CHECK: %[[CMP:.*]] = arith.cmpf uge, %[[LHS_F64]], %[[RHS_F64]] : f64 // CHECK: %[[CMP:.*]] = arith.cmpf uge, %[[LHS_F64]], %[[RHS_F64]] : f64
// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]] // CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]]
// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool // CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool
func @torch.aten.ge.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.bool { func.func @torch.aten.ge.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.bool {
%0 = torch.aten.ge.float_int %arg0, %arg1 : !torch.float, !torch.int -> !torch.bool %0 = torch.aten.ge.float_int %arg0, %arg1 : !torch.float, !torch.int -> !torch.bool
return %0 : !torch.bool return %0 : !torch.bool
} }
// CHECK-LABEL: func @torch.aten.ne.float_int( // CHECK-LABEL: func.func @torch.aten.ne.float_int(
// CHECK-SAME: %[[LHS:.*]]: !torch.float, // CHECK-SAME: %[[LHS:.*]]: !torch.float,
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool { // CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {
// CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]] // CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]]
@ -203,24 +203,24 @@ func @torch.aten.ge.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.
// CHECK: %[[CMP:.*]] = arith.cmpf une, %[[LHS_F64]], %[[RHS_F64]] : f64 // CHECK: %[[CMP:.*]] = arith.cmpf une, %[[LHS_F64]], %[[RHS_F64]] : f64
// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]] // CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]]
// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool // CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool
func @torch.aten.ne.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.bool { func.func @torch.aten.ne.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.bool {
%0 = torch.aten.ne.float_int %arg0, %arg1 : !torch.float, !torch.int -> !torch.bool %0 = torch.aten.ne.float_int %arg0, %arg1 : !torch.float, !torch.int -> !torch.bool
return %0 : !torch.bool return %0 : !torch.bool
} }
// CHECK-LABEL: func @torch.aten.ceil.float( // CHECK-LABEL: func.func @torch.aten.ceil.float(
// CHECK-SAME: %[[ARG:.*]]: !torch.float) -> !torch.int { // CHECK-SAME: %[[ARG:.*]]: !torch.float) -> !torch.int {
// CHECK: %[[ARG_F64:.*]] = torch_c.to_f64 %[[ARG]] // CHECK: %[[ARG_F64:.*]] = torch_c.to_f64 %[[ARG]]
// CHECK: %[[CEIL:.*]] = math.ceil %[[ARG_F64]] : f64 // CHECK: %[[CEIL:.*]] = math.ceil %[[ARG_F64]] : f64
// CHECK: %[[CEIL_I64:.*]] = arith.fptosi %[[CEIL]] : f64 to i64 // CHECK: %[[CEIL_I64:.*]] = arith.fptosi %[[CEIL]] : f64 to i64
// CHECK: %[[OUT:.*]] = torch_c.from_i64 %[[CEIL_I64]] // CHECK: %[[OUT:.*]] = torch_c.from_i64 %[[CEIL_I64]]
// CHECK: return %[[OUT]] : !torch.int // CHECK: return %[[OUT]] : !torch.int
func @torch.aten.ceil.float(%arg0: !torch.float) -> !torch.int { func.func @torch.aten.ceil.float(%arg0: !torch.float) -> !torch.int {
%0 = torch.aten.ceil.float %arg0 : !torch.float -> !torch.int %0 = torch.aten.ceil.float %arg0 : !torch.float -> !torch.int
return %0 : !torch.int return %0 : !torch.int
} }
// CHECK-LABEL: func @torch.aten.gt.float_int( // CHECK-LABEL: func.func @torch.aten.gt.float_int(
// CHECK-SAME: %[[LHS:.*]]: !torch.float, // CHECK-SAME: %[[LHS:.*]]: !torch.float,
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool { // CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {
// CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]] // CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]]
@ -229,7 +229,7 @@ func @torch.aten.ceil.float(%arg0: !torch.float) -> !torch.int {
// CHECK: %[[CMP:.*]] = arith.cmpf ugt, %[[LHS_F64]], %[[RHS_F64]] : f64 // CHECK: %[[CMP:.*]] = arith.cmpf ugt, %[[LHS_F64]], %[[RHS_F64]] : f64
// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]] // CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]]
// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool // CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool
func @torch.aten.gt.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.bool { func.func @torch.aten.gt.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.bool {
%0 = torch.aten.gt.float_int %arg0, %arg1 : !torch.float, !torch.int -> !torch.bool %0 = torch.aten.gt.float_int %arg0, %arg1 : !torch.float, !torch.int -> !torch.bool
return %0 : !torch.bool return %0 : !torch.bool
} }

View File

@ -1,38 +1,38 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file -verify-diagnostics | FileCheck %s // RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file -verify-diagnostics | FileCheck %s
// CHECK-LABEL: func @torch.aten.tanh$basic( // CHECK-LABEL: func.func @torch.aten.tanh$basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.tanh"(%[[ARG_BUILTIN]]) : (tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.tanh"(%[[ARG_BUILTIN]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> %0 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32> return %0 : !torch.vtensor<[?,?],f32>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.sigmoid$basic( // CHECK-LABEL: func.func @torch.aten.sigmoid$basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.sigmoid"(%[[ARG_BUILTIN]]) : (tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.sigmoid"(%[[ARG_BUILTIN]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func @torch.aten.sigmoid$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.sigmoid$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.sigmoid %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> %0 = torch.aten.sigmoid %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32> return %0 : !torch.vtensor<[?,?],f32>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.relu$basic( // CHECK-LABEL: func.func @torch.aten.relu$basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.clamp"(%[[ARG_BUILTIN]]) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.clamp"(%[[ARG_BUILTIN]]) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func @torch.aten.relu$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.relu$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.relu %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> %0 = torch.aten.relu %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32> return %0 : !torch.vtensor<[?,?],f32>
} }
@ -40,100 +40,100 @@ func @torch.aten.relu$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<
// ----- // -----
// CHECK-LABEL: func @torch.aten.log$basic( // CHECK-LABEL: func.func @torch.aten.log$basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.log"(%[[ARG_BUILTIN]]) : (tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.log"(%[[ARG_BUILTIN]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func @torch.aten.log$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.log$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.log %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> %0 = torch.aten.log %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32> return %0 : !torch.vtensor<[?,?],f32>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.exp$basic( // CHECK-LABEL: func.func @torch.aten.exp$basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.exp"(%[[ARG_BUILTIN]]) : (tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.exp"(%[[ARG_BUILTIN]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func @torch.aten.exp$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.exp$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.exp %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> %0 = torch.aten.exp %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32> return %0 : !torch.vtensor<[?,?],f32>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.neg$basic( // CHECK-LABEL: func.func @torch.aten.neg$basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.negate"(%[[ARG_BUILTIN]]) : (tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.negate"(%[[ARG_BUILTIN]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func @torch.aten.neg$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.neg$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> %0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32> return %0 : !torch.vtensor<[?,?],f32>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.floor$basic( // CHECK-LABEL: func.func @torch.aten.floor$basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.floor"(%[[ARG_BUILTIN]]) : (tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.floor"(%[[ARG_BUILTIN]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func @torch.aten.floor$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.floor$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.floor %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> %0 = torch.aten.floor %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32> return %0 : !torch.vtensor<[?,?],f32>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.bitwise_not$basic( // CHECK-LABEL: func.func @torch.aten.bitwise_not$basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.bitwise_not"(%[[ARG_BUILTIN]]) : (tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.bitwise_not"(%[[ARG_BUILTIN]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func @torch.aten.bitwise_not$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.bitwise_not$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.bitwise_not %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> %0 = torch.aten.bitwise_not %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32> return %0 : !torch.vtensor<[?,?],f32>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.ceil$basic( // CHECK-LABEL: func.func @torch.aten.ceil$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = "tosa.ceil"(%[[VAL_1]]) : (tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: %[[VAL_2:.*]] = "tosa.ceil"(%[[VAL_1]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
// CHECK: } // CHECK: }
func @torch.aten.ceil$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.ceil$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.ceil %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> %0 = torch.aten.ceil %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32> return %0 : !torch.vtensor<[?,?],f32>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.reciprocal$basic( // CHECK-LABEL: func.func @torch.aten.reciprocal$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = "tosa.reciprocal"(%[[VAL_1]]) : (tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: %[[VAL_2:.*]] = "tosa.reciprocal"(%[[VAL_1]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
// CHECK: } // CHECK: }
func @torch.aten.reciprocal$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.reciprocal$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.reciprocal %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> %0 = torch.aten.reciprocal %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32> return %0 : !torch.vtensor<[?,?],f32>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.add$basic( // CHECK-LABEL: func.func @torch.aten.add$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
@ -145,7 +145,7 @@ func @torch.aten.reciprocal$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vt
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32>
// CHECK: } // CHECK: }
func @torch.aten.add$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> { func.func @torch.aten.add$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
%int1 = torch.constant.int 1 %int1 = torch.constant.int 1
%0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32>, !torch.int -> !torch.vtensor<[?, ?],f32> %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32>, !torch.int -> !torch.vtensor<[?, ?],f32>
return %0 : !torch.vtensor<[?, ?],f32> return %0 : !torch.vtensor<[?, ?],f32>
@ -153,7 +153,7 @@ func @torch.aten.add$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vten
// ----- // -----
// CHECK-LABEL: func @torch.aten.sub$basic( // CHECK-LABEL: func.func @torch.aten.sub$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
@ -165,7 +165,7 @@ func @torch.aten.add$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vten
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32>
// CHECK: } // CHECK: }
func @torch.aten.sub$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> { func.func @torch.aten.sub$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
%int1 = torch.constant.int 1 %int1 = torch.constant.int 1
%0 = torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32>, !torch.int -> !torch.vtensor<[?, ?],f32> %0 = torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32>, !torch.int -> !torch.vtensor<[?, ?],f32>
return %0 : !torch.vtensor<[?, ?],f32> return %0 : !torch.vtensor<[?, ?],f32>
@ -173,7 +173,7 @@ func @torch.aten.sub$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vten
// ----- // -----
// CHECK-LABEL: func @torch.aten.mul$basic( // CHECK-LABEL: func.func @torch.aten.mul$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
@ -181,14 +181,14 @@ func @torch.aten.sub$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vten
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.mul"(%[[ARG0_BUILTIN]], %[[ARG1_BUILTIN]]) {shift = 0 : i32} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.mul"(%[[ARG0_BUILTIN]], %[[ARG1_BUILTIN]]) {shift = 0 : i32} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func @torch.aten.mul$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> { func.func @torch.aten.mul$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
%0 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32> -> !torch.vtensor<[?, ?],f32> %0 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32> -> !torch.vtensor<[?, ?],f32>
return %0 : !torch.vtensor<[?, ?],f32> return %0 : !torch.vtensor<[?, ?],f32>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.div$basic( // CHECK-LABEL: func.func @torch.aten.div$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
@ -197,14 +197,14 @@ func @torch.aten.mul$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vten
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.mul"(%[[ARG0_BUILTIN]], %[[RCP]]) {shift = 0 : i32} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.mul"(%[[ARG0_BUILTIN]], %[[RCP]]) {shift = 0 : i32} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func @torch.aten.div$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> { func.func @torch.aten.div$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
%0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32> -> !torch.vtensor<[?, ?],f32> %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32> -> !torch.vtensor<[?, ?],f32>
return %0 : !torch.vtensor<[?, ?],f32> return %0 : !torch.vtensor<[?, ?],f32>
} }
// ----- // -----
// CHECK-LABEL: func @test_reduce_mean_dim$basic( // CHECK-LABEL: func.func @test_reduce_mean_dim$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32> // CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[ARG1:.*]] = torch.constant.int 0 // CHECK: %[[ARG1:.*]] = torch.constant.int 0
@ -217,7 +217,7 @@ func @torch.aten.div$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vten
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.mul"(%[[RESHAPE_SUM]], %[[CONST]]) {shift = 0 : i32} : (tensor<?x?x?xf32>, tensor<f32>) -> tensor<?x?x?xf32> // CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.mul"(%[[RESHAPE_SUM]], %[[CONST]]) {shift = 0 : i32} : (tensor<?x?x?xf32>, tensor<f32>) -> tensor<?x?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?],f32>
func @test_reduce_mean_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { func.func @test_reduce_mean_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
%dim0 = torch.constant.int 0 %dim0 = torch.constant.int 0
%reducedims = torch.prim.ListConstruct %dim0 : (!torch.int) -> !torch.list<int> %reducedims = torch.prim.ListConstruct %dim0 : (!torch.int) -> !torch.list<int>
%keepdims = torch.constant.bool false %keepdims = torch.constant.bool false
@ -228,7 +228,7 @@ func @test_reduce_mean_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
// ----- // -----
// CHECK-LABEL: func @test_reduce_sum_dims$basic( // CHECK-LABEL: func.func @test_reduce_sum_dims$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32> // CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[ARG1_BUILTIN:.*]] = torch.constant.none // CHECK: %[[ARG1_BUILTIN:.*]] = torch.constant.none
@ -239,7 +239,7 @@ func @test_reduce_mean_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[SUM]]) {new_shape = [-1, -1, -1]} : (tensor<1x?x?x?xf32>) -> tensor<?x?x?xf32> // CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[SUM]]) {new_shape = [-1, -1, -1]} : (tensor<1x?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?],f32>
func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { func.func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
%none = torch.constant.none %none = torch.constant.none
%false = torch.constant.bool false %false = torch.constant.bool false
%int0 = torch.constant.int 0 %int0 = torch.constant.int 0
@ -250,7 +250,7 @@ func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
// ----- // -----
// CHECK-LABEL: func @test_reduce_sum$basic( // CHECK-LABEL: func.func @test_reduce_sum$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1],f32> { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1],f32> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32> // CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[ARG1_BUILTIN:.*]] = torch.constant.none // CHECK: %[[ARG1_BUILTIN:.*]] = torch.constant.none
@ -261,7 +261,7 @@ func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) {new_shape = [1]} : (tensor<1x1x1x1xf32>) -> tensor<1xf32> // CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) {new_shape = [1]} : (tensor<1x1x1x1xf32>) -> tensor<1xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xf32> -> !torch.vtensor<[1],f32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>
func @test_reduce_sum$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1],f32> { func.func @test_reduce_sum$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1],f32> {
%none = torch.constant.none %none = torch.constant.none
%0 = torch.aten.sum %arg0, %none : !torch.vtensor<[?,?,?,?],f32>, !torch.none -> !torch.vtensor<[1],f32> %0 = torch.aten.sum %arg0, %none : !torch.vtensor<[?,?,?,?],f32>, !torch.none -> !torch.vtensor<[1],f32>
return %0 : !torch.vtensor<[1],f32> return %0 : !torch.vtensor<[1],f32>
@ -269,7 +269,7 @@ func @test_reduce_sum$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vten
// ----- // -----
// CHECK-LABEL: func @test_reduce_all$basic( // CHECK-LABEL: func.func @test_reduce_all$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],i1> -> tensor<?x?x?x?xi1> // CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],i1> -> tensor<?x?x?x?xi1>
// CHECK: %[[REDUCE1:.*]] = "tosa.reduce_all"(%[[ARG0_BUILTIN]]) {axis = 0 : i64} : (tensor<?x?x?x?xi1>) -> tensor<1x?x?x?xi1> // CHECK: %[[REDUCE1:.*]] = "tosa.reduce_all"(%[[ARG0_BUILTIN]]) {axis = 0 : i64} : (tensor<?x?x?x?xi1>) -> tensor<1x?x?x?xi1>
@ -279,14 +279,14 @@ func @test_reduce_sum$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vten
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) {new_shape = [1]} : (tensor<1x1x1x1xi1>) -> tensor<1xi1> // CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) {new_shape = [1]} : (tensor<1x1x1x1xi1>) -> tensor<1xi1>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xi1> -> !torch.vtensor<[1],i1> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xi1> -> !torch.vtensor<[1],i1>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],i1> // CHECK: return %[[RESULT]] : !torch.vtensor<[1],i1>
func @test_reduce_all$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> { func.func @test_reduce_all$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> {
%0 = torch.aten.all %arg0 : !torch.vtensor<[?,?,?,?],i1> -> !torch.vtensor<[1],i1> %0 = torch.aten.all %arg0 : !torch.vtensor<[?,?,?,?],i1> -> !torch.vtensor<[1],i1>
return %0 : !torch.vtensor<[1],i1> return %0 : !torch.vtensor<[1],i1>
} }
// ----- // -----
// CHECK-LABEL: func @test_reduce_any_dim$basic( // CHECK-LABEL: func.func @test_reduce_any_dim$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[?,?,?],i1> { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[?,?,?],i1> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],i1> -> tensor<?x?x?x?xi1> // CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],i1> -> tensor<?x?x?x?xi1>
// CHECK: %[[ARG1:.*]] = torch.constant.int 0 // CHECK: %[[ARG1:.*]] = torch.constant.int 0
@ -295,7 +295,7 @@ func @test_reduce_all$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtens
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE]]) {new_shape = [-1, -1, -1]} : (tensor<1x?x?x?xi1>) -> tensor<?x?x?xi1> // CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE]]) {new_shape = [-1, -1, -1]} : (tensor<1x?x?x?xi1>) -> tensor<?x?x?xi1>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?x?xi1> -> !torch.vtensor<[?,?,?],i1> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?x?xi1> -> !torch.vtensor<[?,?,?],i1>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?],i1> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?],i1>
func @test_reduce_any_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[?,?,?],i1> { func.func @test_reduce_any_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[?,?,?],i1> {
%int0 = torch.constant.int 0 %int0 = torch.constant.int 0
%false = torch.constant.bool false %false = torch.constant.bool false
%0 = torch.aten.any.dim %arg0, %int0, %false : !torch.vtensor<[?,?,?,?],i1>, !torch.int, !torch.bool -> !torch.vtensor<[?,?,?],i1> %0 = torch.aten.any.dim %arg0, %int0, %false : !torch.vtensor<[?,?,?,?],i1>, !torch.int, !torch.bool -> !torch.vtensor<[?,?,?],i1>
@ -304,7 +304,7 @@ func @test_reduce_any_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.v
// ----- // -----
// CHECK-LABEL: func @test_reduce_any$basic( // CHECK-LABEL: func.func @test_reduce_any$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],i1> -> tensor<?x?x?x?xi1> // CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],i1> -> tensor<?x?x?x?xi1>
// CHECK: %[[REDUCE1:.*]] = "tosa.reduce_any"(%[[ARG0_BUILTIN]]) {axis = 0 : i64} : (tensor<?x?x?x?xi1>) -> tensor<1x?x?x?xi1> // CHECK: %[[REDUCE1:.*]] = "tosa.reduce_any"(%[[ARG0_BUILTIN]]) {axis = 0 : i64} : (tensor<?x?x?x?xi1>) -> tensor<1x?x?x?xi1>
@ -314,28 +314,28 @@ func @test_reduce_any_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.v
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) {new_shape = [1]} : (tensor<1x1x1x1xi1>) -> tensor<1xi1> // CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) {new_shape = [1]} : (tensor<1x1x1x1xi1>) -> tensor<1xi1>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xi1> -> !torch.vtensor<[1],i1> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xi1> -> !torch.vtensor<[1],i1>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],i1> // CHECK: return %[[RESULT]] : !torch.vtensor<[1],i1>
func @test_reduce_any$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> { func.func @test_reduce_any$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> {
%0 = torch.aten.any %arg0 : !torch.vtensor<[?,?,?,?],i1> -> !torch.vtensor<[1],i1> %0 = torch.aten.any %arg0 : !torch.vtensor<[?,?,?,?],i1> -> !torch.vtensor<[1],i1>
return %0 : !torch.vtensor<[1],i1> return %0 : !torch.vtensor<[1],i1>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.rsqrt$basic( // CHECK-LABEL: func.func @torch.aten.rsqrt$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = "tosa.rsqrt"(%[[VAL_1]]) : (tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: %[[VAL_2:.*]] = "tosa.rsqrt"(%[[VAL_1]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
// CHECK: } // CHECK: }
func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.rsqrt %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> %0 = torch.aten.rsqrt %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32> return %0 : !torch.vtensor<[?,?],f32>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.maximum$basic( // CHECK-LABEL: func.func @torch.aten.maximum$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
@ -344,14 +344,14 @@ func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
// CHECK: } // CHECK: }
func @torch.aten.maximum$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.maximum$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.maximum %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> %0 = torch.aten.maximum %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32> return %0 : !torch.vtensor<[?,?],f32>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.minimum$basic( // CHECK-LABEL: func.func @torch.aten.minimum$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
@ -360,14 +360,14 @@ func @torch.aten.maximum$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.v
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
// CHECK: } // CHECK: }
func @torch.aten.minimum$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.minimum$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.minimum %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> %0 = torch.aten.minimum %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32> return %0 : !torch.vtensor<[?,?],f32>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.pow.Tensor_Scalar$basic( // CHECK-LABEL: func.func @torch.aten.pow.Tensor_Scalar$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00 // CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00
@ -376,7 +376,7 @@ func @torch.aten.minimum$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.v
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
// CHECK: } // CHECK: }
func @torch.aten.pow.Tensor_Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.pow.Tensor_Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%fp0 = torch.constant.float 3.123400e+00 %fp0 = torch.constant.float 3.123400e+00
%0 = torch.aten.pow.Tensor_Scalar %arg0, %fp0 : !torch.vtensor<[?,?],f32>, !torch.float -> !torch.vtensor<[?,?],f32> %0 = torch.aten.pow.Tensor_Scalar %arg0, %fp0 : !torch.vtensor<[?,?],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32> return %0 : !torch.vtensor<[?,?],f32>
@ -384,7 +384,7 @@ func @torch.aten.pow.Tensor_Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !t
// ----- // -----
// CHECK-LABEL: func @torch.aten.rsub.Scalar$basic( // CHECK-LABEL: func.func @torch.aten.rsub.Scalar$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00 // CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00
@ -396,7 +396,7 @@ func @torch.aten.pow.Tensor_Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !t
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32>
// CHECK: } // CHECK: }
func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%other = torch.constant.float 3.123400e+00 %other = torch.constant.float 3.123400e+00
%alpha = torch.constant.float 6.432100e+00 %alpha = torch.constant.float 6.432100e+00
%0 = torch.aten.rsub.Scalar %arg0, %other, %alpha : !torch.vtensor<[?,?],f32>, !torch.float, !torch.float -> !torch.vtensor<[?,?],f32> %0 = torch.aten.rsub.Scalar %arg0, %other, %alpha : !torch.vtensor<[?,?],f32>, !torch.float, !torch.float -> !torch.vtensor<[?,?],f32>
@ -405,7 +405,7 @@ func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v
// ----- // -----
// CHECK-LABEL: func @torch.aten.rsub.Scalar$basic( // CHECK-LABEL: func.func @torch.aten.rsub.Scalar$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00 // CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00
@ -417,7 +417,7 @@ func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32>
// CHECK: } // CHECK: }
func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%other = torch.constant.float 3.123400e+00 %other = torch.constant.float 3.123400e+00
%alpha = torch.constant.int 1 %alpha = torch.constant.int 1
%0 = torch.aten.rsub.Scalar %arg0, %other, %alpha : !torch.vtensor<[?,?],f32>, !torch.float, !torch.int -> !torch.vtensor<[?,?],f32> %0 = torch.aten.rsub.Scalar %arg0, %other, %alpha : !torch.vtensor<[?,?],f32>, !torch.float, !torch.int -> !torch.vtensor<[?,?],f32>
@ -426,7 +426,7 @@ func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v
// ----- // -----
// CHECK-LABEL: func @torch.aten.gt.Tensor$basic( // CHECK-LABEL: func.func @torch.aten.gt.Tensor$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
@ -435,14 +435,14 @@ func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1> // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1> // CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
// CHECK: } // CHECK: }
func @torch.aten.gt.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { func.func @torch.aten.gt.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
%0 = torch.aten.gt.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1> %0 = torch.aten.gt.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1>
return %0 : !torch.vtensor<[?,?],i1> return %0 : !torch.vtensor<[?,?],i1>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.lt.Tensor$basic( // CHECK-LABEL: func.func @torch.aten.lt.Tensor$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
@ -451,14 +451,14 @@ func @torch.aten.gt.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1> // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1> // CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
// CHECK: } // CHECK: }
func @torch.aten.lt.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { func.func @torch.aten.lt.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
%0 = torch.aten.lt.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1> %0 = torch.aten.lt.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1>
return %0 : !torch.vtensor<[?,?],i1> return %0 : !torch.vtensor<[?,?],i1>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.eq.Tensor$basic( // CHECK-LABEL: func.func @torch.aten.eq.Tensor$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
@ -467,14 +467,14 @@ func @torch.aten.lt.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1> // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1> // CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
// CHECK: } // CHECK: }
func @torch.aten.eq.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { func.func @torch.aten.eq.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
%0 = torch.aten.eq.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1> %0 = torch.aten.eq.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1>
return %0 : !torch.vtensor<[?,?],i1> return %0 : !torch.vtensor<[?,?],i1>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.reshape$basic( // CHECK-LABEL: func.func @torch.aten.reshape$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?],f32> { // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32> // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int -1 // CHECK: %[[VAL_2:.*]] = torch.constant.int -1
@ -483,7 +483,7 @@ func @torch.aten.eq.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?xf32> -> !torch.vtensor<[?],f32> // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?xf32> -> !torch.vtensor<[?],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?],f32> // CHECK: return %[[VAL_5]] : !torch.vtensor<[?],f32>
// CHECK: } // CHECK: }
func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?],f32> { func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?],f32> {
%dim0 = torch.constant.int -1 %dim0 = torch.constant.int -1
%shape = torch.prim.ListConstruct %dim0 : (!torch.int) -> !torch.list<int> %shape = torch.prim.ListConstruct %dim0 : (!torch.int) -> !torch.list<int>
%0 = torch.aten.reshape %arg0, %shape : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?],f32> %0 = torch.aten.reshape %arg0, %shape : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?],f32>
@ -492,7 +492,7 @@ func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.v
// ----- // -----
// CHECK-LABEL: func @torch.aten.native_batch_norm$basic( // CHECK-LABEL: func.func @torch.aten.native_batch_norm$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,4,3],f32>) -> !torch.vtensor<[10,4,3],f32> { // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,4,3],f32>) -> !torch.vtensor<[10,4,3],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[10,4,3],f32> -> tensor<10x4x3xf32> // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[10,4,3],f32> -> tensor<10x4x3xf32>
// CHECK: %[[VAL_2:.*]] = "tosa.const"() {value = dense<[5.000000e-01, 4.000000e-01, 3.000000e-01, 6.000000e-01]> : tensor<4xf32>} : () -> tensor<4xf32> // CHECK: %[[VAL_2:.*]] = "tosa.const"() {value = dense<[5.000000e-01, 4.000000e-01, 3.000000e-01, 6.000000e-01]> : tensor<4xf32>} : () -> tensor<4xf32>
@ -515,7 +515,7 @@ func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.v
// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<10x4x3xf32> -> !torch.vtensor<[10,4,3],f32> // CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<10x4x3xf32> -> !torch.vtensor<[10,4,3],f32>
// CHECK: return %[[VAL_19]] : !torch.vtensor<[10,4,3],f32> // CHECK: return %[[VAL_19]] : !torch.vtensor<[10,4,3],f32>
// CHECK: } // CHECK: }
func @torch.aten.native_batch_norm$basic(%arg0: !torch.vtensor<[10,4,3],f32> ) -> !torch.vtensor<[10,4,3],f32> { func.func @torch.aten.native_batch_norm$basic(%arg0: !torch.vtensor<[10,4,3],f32> ) -> !torch.vtensor<[10,4,3],f32> {
%0 = torch.vtensor.literal(dense<[5.000000e-01, 4.000000e-01, 3.000000e-01, 6.000000e-01]> : tensor<4xf32>) : !torch.vtensor<[4],f32> %0 = torch.vtensor.literal(dense<[5.000000e-01, 4.000000e-01, 3.000000e-01, 6.000000e-01]> : tensor<4xf32>) : !torch.vtensor<[4],f32>
%1 = torch.vtensor.literal(dense<[3.000000e+00, 2.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<4xf32>) : !torch.vtensor<[4],f32> %1 = torch.vtensor.literal(dense<[3.000000e+00, 2.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<4xf32>) : !torch.vtensor<[4],f32>
%float1.000000e-01 = torch.constant.float 1.000000e-01 %float1.000000e-01 = torch.constant.float 1.000000e-01
@ -528,7 +528,7 @@ func @torch.aten.native_batch_norm$basic(%arg0: !torch.vtensor<[10,4,3],f32> ) -
// ----- // -----
// CHECK-LABEL: func @forward( // CHECK-LABEL: func.func @forward(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,3,8,9,3,4],f32>) -> !torch.vtensor<[10,3,?,4],f32> { // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,3,8,9,3,4],f32>) -> !torch.vtensor<[10,3,?,4],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[10,3,8,9,3,4],f32> -> tensor<10x3x8x9x3x4xf32> // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[10,3,8,9,3,4],f32> -> tensor<10x3x8x9x3x4xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 4 // CHECK: %[[VAL_2:.*]] = torch.constant.int 4
@ -538,7 +538,7 @@ func @torch.aten.native_batch_norm$basic(%arg0: !torch.vtensor<[10,4,3],f32> ) -
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<10x3x?x4xf32> -> !torch.vtensor<[10,3,?,4],f32> // CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<10x3x?x4xf32> -> !torch.vtensor<[10,3,?,4],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[10,3,?,4],f32> // CHECK: return %[[VAL_6]] : !torch.vtensor<[10,3,?,4],f32>
// CHECK: } // CHECK: }
func @forward(%arg0: !torch.vtensor<[10,3,8,9,3,4],f32> ) -> !torch.vtensor<[10,3,?,4],f32> { func.func @forward(%arg0: !torch.vtensor<[10,3,8,9,3,4],f32> ) -> !torch.vtensor<[10,3,?,4],f32> {
%int4 = torch.constant.int 4 %int4 = torch.constant.int 4
%int2 = torch.constant.int 2 %int2 = torch.constant.int 2
%0 = torch.aten.flatten.using_ints %arg0, %int2, %int4 : !torch.vtensor<[10,3,8,9,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[10,3,?,4],f32> %0 = torch.aten.flatten.using_ints %arg0, %int2, %int4 : !torch.vtensor<[10,3,8,9,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[10,3,?,4],f32>
@ -547,7 +547,7 @@ func @forward(%arg0: !torch.vtensor<[10,3,8,9,3,4],f32> ) -> !torch.vtensor<[10,
// ----- // -----
// CHECK-LABEL: func @forward( // CHECK-LABEL: func.func @forward(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,2,2,3],f32>, // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,2,2,3],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,2,3],f32>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,2,3],f32>,
// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[2,2,3],f32>) -> !torch.vtensor<[5,2,2,3],f32> { // CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[2,2,3],f32>) -> !torch.vtensor<[5,2,2,3],f32> {
@ -584,7 +584,7 @@ func @forward(%arg0: !torch.vtensor<[10,3,8,9,3,4],f32> ) -> !torch.vtensor<[10,
// CHECK: %[[VAL_33:.*]] = torch_c.from_builtin_tensor %[[VAL_32]] : tensor<5x2x2x3xf32> -> !torch.vtensor<[5,2,2,3],f32> // CHECK: %[[VAL_33:.*]] = torch_c.from_builtin_tensor %[[VAL_32]] : tensor<5x2x2x3xf32> -> !torch.vtensor<[5,2,2,3],f32>
// CHECK: return %[[VAL_33]] : !torch.vtensor<[5,2,2,3],f32> // CHECK: return %[[VAL_33]] : !torch.vtensor<[5,2,2,3],f32>
// CHECK: } // CHECK: }
func @forward(%arg0: !torch.vtensor<[5,2,2,3],f32> , %arg1: !torch.vtensor<[2,2,3],f32> , %arg2: !torch.vtensor<[2,2,3],f32> ) -> !torch.vtensor<[5,2,2,3],f32> { func.func @forward(%arg0: !torch.vtensor<[5,2,2,3],f32> , %arg1: !torch.vtensor<[2,2,3],f32> , %arg2: !torch.vtensor<[2,2,3],f32> ) -> !torch.vtensor<[5,2,2,3],f32> {
%float5.000000e-01 = torch.constant.float 5.000000e-01 %float5.000000e-01 = torch.constant.float 5.000000e-01
%int3 = torch.constant.int 3 %int3 = torch.constant.int 3
%int2 = torch.constant.int 2 %int2 = torch.constant.int 2
@ -595,7 +595,7 @@ func @forward(%arg0: !torch.vtensor<[5,2,2,3],f32> , %arg1: !torch.vtensor<[2,2,
// ----- // -----
// CHECK-LABEL: func @torch.aten.ne.Tensor$basic( // CHECK-LABEL: func.func @torch.aten.ne.Tensor$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
@ -605,14 +605,14 @@ func @forward(%arg0: !torch.vtensor<[5,2,2,3],f32> , %arg1: !torch.vtensor<[2,2,
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1> // CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],i1> // CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],i1>
// CHECK: } // CHECK: }
func @torch.aten.ne.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { func.func @torch.aten.ne.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
%0 = torch.aten.ne.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1> %0 = torch.aten.ne.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1>
return %0 : !torch.vtensor<[?,?],i1> return %0 : !torch.vtensor<[?,?],i1>
} }
// ----- // -----
// CHECK-LABEL: func @forward( // CHECK-LABEL: func.func @forward(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,2,4],f32> { // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,2,4],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,2],f32> -> tensor<3x4x2xf32> // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,2],f32> -> tensor<3x4x2xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1 // CHECK: %[[VAL_2:.*]] = torch.constant.int 1
@ -624,7 +624,7 @@ func @torch.aten.ne.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<3x2x4xf32> -> !torch.vtensor<[3,2,4],f32> // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<3x2x4xf32> -> !torch.vtensor<[3,2,4],f32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[3,2,4],f32> // CHECK: return %[[VAL_8]] : !torch.vtensor<[3,2,4],f32>
// CHECK: } // CHECK: }
func @forward(%arg0: !torch.vtensor<[3,4,2],f32> ) -> !torch.vtensor<[3,2,4],f32> { func.func @forward(%arg0: !torch.vtensor<[3,4,2],f32> ) -> !torch.vtensor<[3,2,4],f32> {
%int1 = torch.constant.int 1 %int1 = torch.constant.int 1
%int2 = torch.constant.int 2 %int2 = torch.constant.int 2
%int0 = torch.constant.int 0 %int0 = torch.constant.int 0
@ -635,7 +635,7 @@ func @forward(%arg0: !torch.vtensor<[3,4,2],f32> ) -> !torch.vtensor<[3,2,4],f32
// ----- // -----
// CHECK-LABEL: func @torch.aten.bitwise_and.Tensor$basic( // CHECK-LABEL: func.func @torch.aten.bitwise_and.Tensor$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>, // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> { // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32> // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
@ -644,14 +644,14 @@ func @forward(%arg0: !torch.vtensor<[3,4,2],f32> ) -> !torch.vtensor<[3,2,4],f32
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi32> -> !torch.vtensor<[?,?],si32> // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi32> -> !torch.vtensor<[?,?],si32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],si32> // CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],si32>
// CHECK: } // CHECK: }
func @torch.aten.bitwise_and.Tensor$basic(%arg0: !torch.vtensor<[?,?],si32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> { func.func @torch.aten.bitwise_and.Tensor$basic(%arg0: !torch.vtensor<[?,?],si32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
%0 = torch.aten.bitwise_and.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32> %0 = torch.aten.bitwise_and.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32>
return %0 : !torch.vtensor<[?,?],si32> return %0 : !torch.vtensor<[?,?],si32>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.log2$basic( // CHECK-LABEL: func.func @torch.aten.log2$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = "tosa.const"() {value = dense<0.693147182> : tensor<1x1xf32>} : () -> tensor<1x1xf32> // CHECK: %[[VAL_2:.*]] = "tosa.const"() {value = dense<0.693147182> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
@ -661,14 +661,14 @@ func @torch.aten.bitwise_and.Tensor$basic(%arg0: !torch.vtensor<[?,?],si32>, %ar
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32>
// CHECK: } // CHECK: }
func @torch.aten.log2$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.log2$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.log2 %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> %0 = torch.aten.log2 %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32> return %0 : !torch.vtensor<[?,?],f32>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.zeros$basic() -> !torch.vtensor<[3,4],f32> { // CHECK-LABEL: func.func @torch.aten.zeros$basic() -> !torch.vtensor<[3,4],f32> {
// CHECK: %[[VAL_0:.*]] = torch.constant.int 4 // CHECK: %[[VAL_0:.*]] = torch.constant.int 4
// CHECK: %[[VAL_1:.*]] = torch.constant.int 3 // CHECK: %[[VAL_1:.*]] = torch.constant.int 3
// CHECK: %[[VAL_2:.*]] = torch.constant.none // CHECK: %[[VAL_2:.*]] = torch.constant.none
@ -678,7 +678,7 @@ func @torch.aten.log2$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vtensor
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> // CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32> // CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32>
// CHECK: } // CHECK: }
func @torch.aten.zeros$basic() -> !torch.vtensor<[3,4],f32> { func.func @torch.aten.zeros$basic() -> !torch.vtensor<[3,4],f32> {
%int4 = torch.constant.int 4 %int4 = torch.constant.int 4
%int3 = torch.constant.int 3 %int3 = torch.constant.int 3
%none = torch.constant.none %none = torch.constant.none
@ -689,7 +689,7 @@ func @torch.aten.zeros$basic() -> !torch.vtensor<[3,4],f32> {
// ----- // -----
// CHECK-LABEL: func @torch.aten.unsqueeze$basic( // CHECK-LABEL: func.func @torch.aten.unsqueeze$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,3],si32>) -> !torch.vtensor<[4,1,3],si32> { // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,3],si32>) -> !torch.vtensor<[4,1,3],si32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],si32> -> tensor<4x3xi32> // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],si32> -> tensor<4x3xi32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1 // CHECK: %[[VAL_2:.*]] = torch.constant.int 1
@ -698,7 +698,7 @@ func @torch.aten.zeros$basic() -> !torch.vtensor<[3,4],f32> {
// CHECK: return %[[VAL_4]] : !torch.vtensor<[4,1,3],si32> // CHECK: return %[[VAL_4]] : !torch.vtensor<[4,1,3],si32>
// CHECK: } // CHECK: }
func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[4,3],si32> ) -> !torch.vtensor<[4,1,3],si32> { func.func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[4,3],si32> ) -> !torch.vtensor<[4,1,3],si32> {
%int1 = torch.constant.int 1 %int1 = torch.constant.int 1
%0 = torch.aten.unsqueeze %arg0, %int1 : !torch.vtensor<[4,3],si32>, !torch.int -> !torch.vtensor<[4,1,3],si32> %0 = torch.aten.unsqueeze %arg0, %int1 : !torch.vtensor<[4,3],si32>, !torch.int -> !torch.vtensor<[4,1,3],si32>
return %0 : !torch.vtensor<[4,1,3],si32> return %0 : !torch.vtensor<[4,1,3],si32>
@ -706,14 +706,14 @@ func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[4,3],si32> ) -> !torch.v
// ----- // -----
// CHECK-LABEL: func @torch.aten.contiguous$basic( // CHECK-LABEL: func.func @torch.aten.contiguous$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 // CHECK: %[[VAL_2:.*]] = torch.constant.int 0
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
// CHECK: } // CHECK: }
func @torch.aten.contiguous$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.contiguous$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vtensor<[?,?],f32> {
%int0 = torch.constant.int 0 %int0 = torch.constant.int 0
%0 = torch.aten.contiguous %arg0, %int0 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32> %0 = torch.aten.contiguous %arg0, %int0 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32> return %0 : !torch.vtensor<[?,?],f32>
@ -721,7 +721,7 @@ func @torch.aten.contiguous$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.v
// ----- // -----
// CHECK-LABEL: func @torch.aten.ones$basic() -> !torch.vtensor<[3,4],f32> { // CHECK-LABEL: func.func @torch.aten.ones$basic() -> !torch.vtensor<[3,4],f32> {
// CHECK: %[[VAL_0:.*]] = torch.constant.int 4 // CHECK: %[[VAL_0:.*]] = torch.constant.int 4
// CHECK: %[[VAL_1:.*]] = torch.constant.int 3 // CHECK: %[[VAL_1:.*]] = torch.constant.int 3
// CHECK: %[[VAL_2:.*]] = torch.constant.none // CHECK: %[[VAL_2:.*]] = torch.constant.none
@ -731,7 +731,7 @@ func @torch.aten.contiguous$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.v
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> // CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32> // CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32>
// CHECK: } // CHECK: }
func @torch.aten.ones$basic() -> !torch.vtensor<[3,4],f32> { func.func @torch.aten.ones$basic() -> !torch.vtensor<[3,4],f32> {
%int4 = torch.constant.int 4 %int4 = torch.constant.int 4
%int3 = torch.constant.int 3 %int3 = torch.constant.int 3
%none = torch.constant.none %none = torch.constant.none
@ -742,7 +742,7 @@ func @torch.aten.ones$basic() -> !torch.vtensor<[3,4],f32> {
// ----- // -----
// CHECK-LABEL: func @torch.aten.dropout$basic( // CHECK-LABEL: func.func @torch.aten.dropout$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.float 0.000000e+00 // CHECK: %[[VAL_2:.*]] = torch.constant.float 0.000000e+00
@ -751,7 +751,7 @@ func @torch.aten.ones$basic() -> !torch.vtensor<[3,4],f32> {
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
// CHECK: } // CHECK: }
func @torch.aten.dropout$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.dropout$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vtensor<[?,?],f32> {
%float0.000000e00 = torch.constant.float 0.000000e+00 %float0.000000e00 = torch.constant.float 0.000000e+00
%false = torch.constant.bool false %false = torch.constant.bool false
%0 = torch.aten.dropout %arg0, %float0.000000e00, %false : !torch.vtensor<[?,?],f32>, !torch.float, !torch.bool -> !torch.vtensor<[?,?],f32> %0 = torch.aten.dropout %arg0, %float0.000000e00, %false : !torch.vtensor<[?,?],f32>, !torch.float, !torch.bool -> !torch.vtensor<[?,?],f32>

View File

@ -4,31 +4,31 @@ torch.class_type @c {
torch.attr "float" : !torch.float torch.attr "float" : !torch.float
torch.method "calls_free_function", @calls_free_function torch.method "calls_free_function", @calls_free_function
} }
// CHECK-LABEL: func private // CHECK-LABEL: func.func private
// CHECK-SAME: @free_function$[[$MONOMORPHIZE_TAG0:.*]]( // CHECK-SAME: @free_function$[[$MONOMORPHIZE_TAG0:.*]](
// CHECK-SAME: %[[F:.*]]: !torch.float) -> !torch.float { // CHECK-SAME: %[[F:.*]]: !torch.float) -> !torch.float {
// CHECK: return %[[F]] : !torch.float // CHECK: return %[[F]] : !torch.float
// CHECK: } // CHECK: }
func private @free_function(%arg0: !torch.float, %arg1: !torch.nn.Module<"c">) -> !torch.float { func.func private @free_function(%arg0: !torch.float, %arg1: !torch.nn.Module<"c">) -> !torch.float {
return %arg0 : !torch.float return %arg0 : !torch.float
} }
// CHECK-LABEL: func private // CHECK-LABEL: func.func private
// CHECK-SAME: @free_function_no_module_args$[[$MONOMORPHIZE_TAG1:.*]]( // CHECK-SAME: @free_function_no_module_args$[[$MONOMORPHIZE_TAG1:.*]](
// CHECK-SAME: %[[F:.*]]: !torch.float) -> !torch.float { // CHECK-SAME: %[[F:.*]]: !torch.float) -> !torch.float {
// CHECK: return %[[F]] : !torch.float // CHECK: return %[[F]] : !torch.float
// CHECK: } // CHECK: }
func private @free_function_no_module_args(%arg0: !torch.float) -> !torch.float { func.func private @free_function_no_module_args(%arg0: !torch.float) -> !torch.float {
return %arg0 : !torch.float return %arg0 : !torch.float
} }
// CHECK-LABEL: func @calls_free_function() -> !torch.float { // CHECK-LABEL: func.func @calls_free_function() -> !torch.float {
// CHECK: %[[F1:.*]] = torch.global_slot.get @float : !torch.float // CHECK: %[[F1:.*]] = torch.global_slot.get @float : !torch.float
// CHECK: %[[F2:.*]] = call @free_function$[[$MONOMORPHIZE_TAG0]](%[[F1]]) : (!torch.float) -> !torch.float // CHECK: %[[F2:.*]] = call @free_function$[[$MONOMORPHIZE_TAG0]](%[[F1]]) : (!torch.float) -> !torch.float
// CHECK: %[[RET:.*]] = call @free_function_no_module_args$[[$MONOMORPHIZE_TAG1]](%[[F2]]) : (!torch.float) -> !torch.float // CHECK: %[[RET:.*]] = call @free_function_no_module_args$[[$MONOMORPHIZE_TAG1]](%[[F2]]) : (!torch.float) -> !torch.float
// CHECK: return %[[RET]] : !torch.float // CHECK: return %[[RET]] : !torch.float
// CHECK: } // CHECK: }
func private @calls_free_function(%arg0: !torch.nn.Module<"c">) -> !torch.float { func.func private @calls_free_function(%arg0: !torch.nn.Module<"c">) -> !torch.float {
%0 = torch.prim.GetAttr %arg0["float"] : !torch.nn.Module<"c"> -> !torch.float %0 = torch.prim.GetAttr %arg0["float"] : !torch.nn.Module<"c"> -> !torch.float
%1 = call @free_function(%0, %arg0) : (!torch.float, !torch.nn.Module<"c">) -> !torch.float %1 = call @free_function(%0, %arg0) : (!torch.float, !torch.nn.Module<"c">) -> !torch.float
%2 = call @free_function_no_module_args(%1) : (!torch.float) -> !torch.float %2 = call @free_function_no_module_args(%1) : (!torch.float) -> !torch.float

View File

@ -7,28 +7,28 @@ torch.class_type @c {
torch.method "test_call", @test_call torch.method "test_call", @test_call
} }
// CHECK-LABEL: func @test_get() -> !torch.float { // CHECK-LABEL: func.func @test_get() -> !torch.float {
// CHECK: %[[V:.*]] = torch.global_slot.get @float : !torch.float // CHECK: %[[V:.*]] = torch.global_slot.get @float : !torch.float
// CHECK: return %[[V]] : !torch.float // CHECK: return %[[V]] : !torch.float
func private @test_get(%arg0: !torch.nn.Module<"c">) -> !torch.float { func.func private @test_get(%arg0: !torch.nn.Module<"c">) -> !torch.float {
%0 = torch.prim.GetAttr %arg0["float"] : !torch.nn.Module<"c"> -> !torch.float %0 = torch.prim.GetAttr %arg0["float"] : !torch.nn.Module<"c"> -> !torch.float
return %0 : !torch.float return %0 : !torch.float
} }
// CHECK-LABEL: func @test_set( // CHECK-LABEL: func.func @test_set(
// CHECK-SAME: %[[A:.*]]: !torch.float) { // CHECK-SAME: %[[A:.*]]: !torch.float) {
// CHECK: torch.global_slot.set @float = %[[A]] : !torch.float // CHECK: torch.global_slot.set @float = %[[A]] : !torch.float
// CHECK: return // CHECK: return
func private @test_set(%arg0: !torch.nn.Module<"c">, %arg1: !torch.float) { func.func private @test_set(%arg0: !torch.nn.Module<"c">, %arg1: !torch.float) {
torch.prim.SetAttr %arg0["float"] = %arg1 : !torch.nn.Module<"c">, !torch.float torch.prim.SetAttr %arg0["float"] = %arg1 : !torch.nn.Module<"c">, !torch.float
return return
} }
// CHECK-LABEL: func @test_call( // CHECK-LABEL: func.func @test_call(
// CHECK-SAME: %[[A:.*]]: !torch.float) -> !torch.float { // CHECK-SAME: %[[A:.*]]: !torch.float) -> !torch.float {
// CHECK: %[[V:.*]] = call @test_call(%[[A]]) : (!torch.float) -> !torch.float // CHECK: %[[V:.*]] = call @test_call(%[[A]]) : (!torch.float) -> !torch.float
// CHECK: return %[[V]] : !torch.float // CHECK: return %[[V]] : !torch.float
func private @test_call(%arg0: !torch.nn.Module<"c">, %arg1: !torch.float) -> !torch.float { func.func private @test_call(%arg0: !torch.nn.Module<"c">, %arg1: !torch.float) -> !torch.float {
%0 = call @test_call(%arg0, %arg1) : (!torch.nn.Module<"c">, !torch.float) -> !torch.float %0 = call @test_call(%arg0, %arg1) : (!torch.nn.Module<"c">, !torch.float) -> !torch.float
return %0 : !torch.float return %0 : !torch.float
} }

View File

@ -4,7 +4,7 @@ torch.class_type @parent {
torch.method "module_type_return", @module_type_return torch.method "module_type_return", @module_type_return
} }
func private @module_type_return(%arg0: !torch.nn.Module<"parent">) { func.func private @module_type_return(%arg0: !torch.nn.Module<"parent">) {
// expected-error @+1 {{unsupported use of a torch.nn.Module. Expected only method calls or attribute get/set}} // expected-error @+1 {{unsupported use of a torch.nn.Module. Expected only method calls or attribute get/set}}
torch.prim.ListConstruct %arg0 : (!torch.nn.Module<"parent">) -> !torch.list<nn.Module<"parent">> torch.prim.ListConstruct %arg0 : (!torch.nn.Module<"parent">) -> !torch.list<nn.Module<"parent">>
return return

View File

@ -10,8 +10,8 @@ torch.class_type @parent {
torch.method "method_call", @method_call torch.method "method_call", @method_call
} }
// CHECK-LABEL: func @get_attr_returns_module_type() -> !torch.float { // CHECK-LABEL: func.func @get_attr_returns_module_type() -> !torch.float {
func private @get_attr_returns_module_type(%arg0: !torch.nn.Module<"parent">) -> !torch.float { func.func private @get_attr_returns_module_type(%arg0: !torch.nn.Module<"parent">) -> !torch.float {
%0 = torch.prim.GetAttr %arg0["m"] : !torch.nn.Module<"parent"> -> !torch.nn.Module<"child"> %0 = torch.prim.GetAttr %arg0["m"] : !torch.nn.Module<"parent"> -> !torch.nn.Module<"child">
// CHECK-NEXT: %[[V:.*]] = torch.global_slot.get @m.float : !torch.float // CHECK-NEXT: %[[V:.*]] = torch.global_slot.get @m.float : !torch.float
%1 = torch.prim.GetAttr %0["float"] : !torch.nn.Module<"child"> -> !torch.float %1 = torch.prim.GetAttr %0["float"] : !torch.nn.Module<"child"> -> !torch.float
@ -21,15 +21,15 @@ func private @get_attr_returns_module_type(%arg0: !torch.nn.Module<"parent">) ->
return %1 : !torch.float return %1 : !torch.float
} }
// CHECK-LABEL: func @module_type_argument( // CHECK-LABEL: func.func @module_type_argument(
// CHECK-SAME: %[[F:.*]]: !torch.float) -> !torch.none { // CHECK-SAME: %[[F:.*]]: !torch.float) -> !torch.none {
func private @module_type_argument(%arg0: !torch.nn.Module<"parent">, %arg1: !torch.nn.Module<"parent">, %arg2: !torch.float, %arg3: !torch.nn.Module<"parent">) -> !torch.none { func.func private @module_type_argument(%arg0: !torch.nn.Module<"parent">, %arg1: !torch.nn.Module<"parent">, %arg2: !torch.float, %arg3: !torch.nn.Module<"parent">) -> !torch.none {
%0 = torch.constant.none %0 = torch.constant.none
return %0 : !torch.none return %0 : !torch.none
} }
// CHECK-LABEL: func @method_call() -> !torch.none { // CHECK-LABEL: func.func @method_call() -> !torch.none {
func private @method_call(%arg0: !torch.nn.Module<"parent">) -> !torch.none { func.func private @method_call(%arg0: !torch.nn.Module<"parent">) -> !torch.none {
// CHECK-NEXT: %[[C:.*]] = torch.constant.float 4.300000e+01 // CHECK-NEXT: %[[C:.*]] = torch.constant.float 4.300000e+01
%c = torch.constant.float 43.0 %c = torch.constant.float 43.0
// CHECK-NEXT: %[[F:.*]] = call @module_type_argument(%[[C]]) : (!torch.float) -> !torch.none // CHECK-NEXT: %[[F:.*]] = call @module_type_argument(%[[C]]) : (!torch.float) -> !torch.none

View File

@ -1,9 +1,9 @@
// RUN: torch-mlir-opt -torch-globalize-object-graph -verify-diagnostics -split-input-file %s // RUN: torch-mlir-opt -torch-globalize-object-graph -verify-diagnostics -split-input-file %s
func private @__torch__.Submodule.forward(%arg0: !torch.nn.Module<"__torch__.Submodule">, %arg1: !torch.nn.Module<"__torch__.Submodule">) { func.func private @__torch__.Submodule.forward(%arg0: !torch.nn.Module<"__torch__.Submodule">, %arg1: !torch.nn.Module<"__torch__.Submodule">) {
return return
} }
func private @__torch__.TestModule.forward(%arg0: !torch.nn.Module<"__torch__.TestModule">) { func.func private @__torch__.TestModule.forward(%arg0: !torch.nn.Module<"__torch__.TestModule">) {
%5 = torch.prim.GetAttr %arg0["s1"] : !torch.nn.Module<"__torch__.TestModule"> -> !torch.nn.Module<"__torch__.Submodule"> %5 = torch.prim.GetAttr %arg0["s1"] : !torch.nn.Module<"__torch__.TestModule"> -> !torch.nn.Module<"__torch__.Submodule">
%6 = torch.prim.GetAttr %arg0["s2"] : !torch.nn.Module<"__torch__.TestModule"> -> !torch.nn.Module<"__torch__.Submodule"> %6 = torch.prim.GetAttr %arg0["s2"] : !torch.nn.Module<"__torch__.TestModule"> -> !torch.nn.Module<"__torch__.Submodule">
call @__torch__.Submodule.forward(%5, %6) : (!torch.nn.Module<"__torch__.Submodule">, !torch.nn.Module<"__torch__.Submodule">) -> () call @__torch__.Submodule.forward(%5, %6) : (!torch.nn.Module<"__torch__.Submodule">, !torch.nn.Module<"__torch__.Submodule">) -> ()

View File

@ -34,10 +34,10 @@ torch.class_type @__torch__.Submodule {
} : !torch.nn.Module<"__torch__.TestModule"> } : !torch.nn.Module<"__torch__.TestModule">
// CHECK-LABEL: func @forward() { // CHECK-LABEL: func.func @forward() {
// CHECK: call @__torch__.free_function$[[$MONOMORPHIZE_TAG0:.*]]() : () -> () // CHECK: call @__torch__.free_function$[[$MONOMORPHIZE_TAG0:.*]]() : () -> ()
// CHECK: call @__torch__.free_function$[[$MONOMORPHIZE_TAG1:.*]]() : () -> () // CHECK: call @__torch__.free_function$[[$MONOMORPHIZE_TAG1:.*]]() : () -> ()
func private @__torch__.TestModule.forward(%arg0: !torch.nn.Module<"__torch__.TestModule">) { func.func private @__torch__.TestModule.forward(%arg0: !torch.nn.Module<"__torch__.TestModule">) {
%4 = torch.prim.GetAttr %arg0["s1"] : !torch.nn.Module<"__torch__.TestModule"> -> !torch.nn.Module<"__torch__.Submodule"> %4 = torch.prim.GetAttr %arg0["s1"] : !torch.nn.Module<"__torch__.TestModule"> -> !torch.nn.Module<"__torch__.Submodule">
%5 = torch.prim.GetAttr %arg0["s2"] : !torch.nn.Module<"__torch__.TestModule"> -> !torch.nn.Module<"__torch__.Submodule"> %5 = torch.prim.GetAttr %arg0["s2"] : !torch.nn.Module<"__torch__.TestModule"> -> !torch.nn.Module<"__torch__.Submodule">
call @__torch__.free_function(%4, %5) : (!torch.nn.Module<"__torch__.Submodule">, !torch.nn.Module<"__torch__.Submodule">) -> () call @__torch__.free_function(%4, %5) : (!torch.nn.Module<"__torch__.Submodule">, !torch.nn.Module<"__torch__.Submodule">) -> ()
@ -48,27 +48,27 @@ func private @__torch__.TestModule.forward(%arg0: !torch.nn.Module<"__torch__.Te
} }
// s1 called first, then s2 // s1 called first, then s2
// CHECK-LABEL: func private // CHECK-LABEL: func.func private
// CHECK-SAME @__torch__.free_function$[[$MONOMORPHIZE_TAG0]]() { // CHECK-SAME @__torch__.free_function$[[$MONOMORPHIZE_TAG0]]() {
// CHECK: call @s1.forward() : () -> () // CHECK: call @s1.forward() : () -> ()
// CHECK: call @s2.forward() : () -> () // CHECK: call @s2.forward() : () -> ()
// s2 called first, then s1 // s2 called first, then s1
// CHECK-LABEL: func private // CHECK-LABEL: func.func private
// CHECK-SAME: @__torch__.free_function$[[$MONOMORPHIZE_TAG1]]() { // CHECK-SAME: @__torch__.free_function$[[$MONOMORPHIZE_TAG1]]() {
// CHECK: call @s2.forward() : () -> () // CHECK: call @s2.forward() : () -> ()
// CHECK: call @s1.forward() : () -> () // CHECK: call @s1.forward() : () -> ()
func private @__torch__.free_function(%arg0: !torch.nn.Module<"__torch__.Submodule">, %arg1: !torch.nn.Module<"__torch__.Submodule">) { func.func private @__torch__.free_function(%arg0: !torch.nn.Module<"__torch__.Submodule">, %arg1: !torch.nn.Module<"__torch__.Submodule">) {
call @__torch__.Submodule.forward(%arg0) : (!torch.nn.Module<"__torch__.Submodule">) -> () call @__torch__.Submodule.forward(%arg0) : (!torch.nn.Module<"__torch__.Submodule">) -> ()
call @__torch__.Submodule.forward(%arg1) : (!torch.nn.Module<"__torch__.Submodule">) -> () call @__torch__.Submodule.forward(%arg1) : (!torch.nn.Module<"__torch__.Submodule">) -> ()
return return
} }
// CHECK-LABEL: func private @s2.forward() { // CHECK-LABEL: func.func private @s2.forward() {
// CHECK: return // CHECK: return
// CHECK-LABEL: func private @s1.forward() { // CHECK-LABEL: func.func private @s1.forward() {
// CHECK: return // CHECK: return
func private @__torch__.Submodule.forward(%arg0: !torch.nn.Module<"__torch__.Submodule">) { func.func private @__torch__.Submodule.forward(%arg0: !torch.nn.Module<"__torch__.Submodule">) {
return return
} }

View File

@ -32,31 +32,31 @@ torch.class_type @__torch__.Submodule {
} : !torch.nn.Module<"__torch__.TestModule"> } : !torch.nn.Module<"__torch__.TestModule">
// CHECK-LABEL: func @forward() { // CHECK-LABEL: func.func @forward() {
// CHECK: call @s1.forward() : () -> () // CHECK: call @s1.forward() : () -> ()
// CHECK: call @s2.forward() : () -> () // CHECK: call @s2.forward() : () -> ()
// CHECK: return // CHECK: return
func private @__torch__.TestModule.forward(%arg0: !torch.nn.Module<"__torch__.TestModule">) { func.func private @__torch__.TestModule.forward(%arg0: !torch.nn.Module<"__torch__.TestModule">) {
%4 = torch.prim.GetAttr %arg0["s1"] : !torch.nn.Module<"__torch__.TestModule"> -> !torch.nn.Module<"__torch__.Submodule"> %4 = torch.prim.GetAttr %arg0["s1"] : !torch.nn.Module<"__torch__.TestModule"> -> !torch.nn.Module<"__torch__.Submodule">
%5 = torch.prim.GetAttr %arg0["s2"] : !torch.nn.Module<"__torch__.TestModule"> -> !torch.nn.Module<"__torch__.Submodule"> %5 = torch.prim.GetAttr %arg0["s2"] : !torch.nn.Module<"__torch__.TestModule"> -> !torch.nn.Module<"__torch__.Submodule">
call @__torch__.Submodule.forward(%4) : (!torch.nn.Module<"__torch__.Submodule">) -> () call @__torch__.Submodule.forward(%4) : (!torch.nn.Module<"__torch__.Submodule">) -> ()
call @__torch__.Submodule.forward(%5) : (!torch.nn.Module<"__torch__.Submodule">) -> () call @__torch__.Submodule.forward(%5) : (!torch.nn.Module<"__torch__.Submodule">) -> ()
return return
} }
// CHECK-LABEL: func private @s1.forward() { // CHECK-LABEL: func.func private @s1.forward() {
// CHECK: %[[C3:.*]] = torch.constant.int 3 // CHECK: %[[C3:.*]] = torch.constant.int 3
// CHECK: %[[N:.*]] = torch.global_slot.get @s1.n : !torch.int // CHECK: %[[N:.*]] = torch.global_slot.get @s1.n : !torch.int
// CHECK: %[[NEWVAL:.*]] = torch.aten.add.int %[[N]], %[[C3]] : !torch.int, !torch.int -> !torch.int // CHECK: %[[NEWVAL:.*]] = torch.aten.add.int %[[N]], %[[C3]] : !torch.int, !torch.int -> !torch.int
// CHECK: torch.global_slot.set @s1.n = %[[NEWVAL]] : !torch.int // CHECK: torch.global_slot.set @s1.n = %[[NEWVAL]] : !torch.int
// CHECK: return // CHECK: return
// CHECK-LABEL: func private @s2.forward() { // CHECK-LABEL: func.func private @s2.forward() {
// CHECK: %[[C3:.*]] = torch.constant.int 3 // CHECK: %[[C3:.*]] = torch.constant.int 3
// CHECK: %[[N:.*]] = torch.global_slot.get @s2.n : !torch.int // CHECK: %[[N:.*]] = torch.global_slot.get @s2.n : !torch.int
// CHECK: %[[NEWVAL:.*]] = torch.aten.add.int %[[N]], %[[C3]] : !torch.int, !torch.int -> !torch.int // CHECK: %[[NEWVAL:.*]] = torch.aten.add.int %[[N]], %[[C3]] : !torch.int, !torch.int -> !torch.int
// CHECK: torch.global_slot.set @s2.n = %[[NEWVAL]] : !torch.int // CHECK: torch.global_slot.set @s2.n = %[[NEWVAL]] : !torch.int
// CHECK: return // CHECK: return
func private @__torch__.Submodule.forward(%arg0: !torch.nn.Module<"__torch__.Submodule">) { func.func private @__torch__.Submodule.forward(%arg0: !torch.nn.Module<"__torch__.Submodule">) {
%int3 = torch.constant.int 3 %int3 = torch.constant.int 3
%5 = torch.prim.GetAttr %arg0["n"] : !torch.nn.Module<"__torch__.Submodule"> -> !torch.int %5 = torch.prim.GetAttr %arg0["n"] : !torch.nn.Module<"__torch__.Submodule"> -> !torch.int
%6 = torch.aten.add.int %5, %int3 : !torch.int, !torch.int -> !torch.int %6 = torch.aten.add.int %5, %int3 : !torch.int, !torch.int -> !torch.int

View File

@ -6,8 +6,8 @@ torch.class_type @c {
torch.method private "forward", @method torch.method private "forward", @method
} }
// CHECK: func private @forward() { // CHECK: func.func private @forward() {
func private @method(%arg0: !torch.nn.Module<"c">) { func.func private @method(%arg0: !torch.nn.Module<"c">) {
return return
} }

View File

@ -1,22 +1,22 @@
// RUN: torch-mlir-opt -torch-adjust-calling-conventions -allow-unregistered-dialect -split-input-file %s | FileCheck %s // RUN: torch-mlir-opt -torch-adjust-calling-conventions -allow-unregistered-dialect -split-input-file %s | FileCheck %s
// CHECK-LABEL: func @basic( // CHECK-LABEL: func.func @basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
// CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor // CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
// CHECK: %[[NONVAL_TENSOR:.*]] = torch.copy.to_tensor %[[ERASED]] : !torch.tensor // CHECK: %[[NONVAL_TENSOR:.*]] = torch.copy.to_tensor %[[ERASED]] : !torch.tensor
// CHECK: return %[[NONVAL_TENSOR]] : !torch.tensor // CHECK: return %[[NONVAL_TENSOR]] : !torch.tensor
func @basic(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[2,3,?],f32>}) -> !torch.tensor { func.func @basic(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[2,3,?],f32>}) -> !torch.tensor {
return %arg0 : !torch.tensor return %arg0 : !torch.tensor
} }
// CHECK-LABEL: func @no_type_bound( // CHECK-LABEL: func.func @no_type_bound(
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.tensor { // CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: return %[[ARG]] : !torch.tensor // CHECK: return %[[ARG]] : !torch.tensor
func @no_type_bound(%arg0: !torch.tensor) -> !torch.tensor { func.func @no_type_bound(%arg0: !torch.tensor) -> !torch.tensor {
return %arg0 : !torch.tensor return %arg0 : !torch.tensor
} }
// CHECK-LABEL: func @call( // CHECK-LABEL: func.func @call(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
// CHECK: %[[ARG_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor // CHECK: %[[ARG_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
// CHECK: %[[ARG_NONVAL:.*]] = torch.copy.to_tensor %[[ARG_ERASED]] : !torch.tensor // CHECK: %[[ARG_NONVAL:.*]] = torch.copy.to_tensor %[[ARG_ERASED]] : !torch.tensor
@ -24,31 +24,31 @@ func @no_type_bound(%arg0: !torch.tensor) -> !torch.tensor {
// CHECK: %[[CALL_ARG:.*]] = torch.copy.to_vtensor %[[INFO_ADDED]] : !torch.vtensor<[2,3,?],f32> // CHECK: %[[CALL_ARG:.*]] = torch.copy.to_vtensor %[[INFO_ADDED]] : !torch.vtensor<[2,3,?],f32>
// CHECK: %[[CALL_RES:.*]] = call @call(%[[CALL_ARG]]) : (!torch.vtensor<[2,3,?],f32>) -> !torch.tensor // CHECK: %[[CALL_RES:.*]] = call @call(%[[CALL_ARG]]) : (!torch.vtensor<[2,3,?],f32>) -> !torch.tensor
// CHECK: return %[[ARG_NONVAL]] : !torch.tensor // CHECK: return %[[ARG_NONVAL]] : !torch.tensor
func @call(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[2,3,?],f32>}) -> !torch.tensor { func.func @call(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[2,3,?],f32>}) -> !torch.tensor {
%0 = call @call(%arg0) : (!torch.tensor) -> !torch.tensor %0 = call @call(%arg0) : (!torch.tensor) -> !torch.tensor
return %arg0 : !torch.tensor return %arg0 : !torch.tensor
} }
// CHECK-LABEL: func @none_return() { // CHECK-LABEL: func.func @none_return() {
// CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: return // CHECK: return
func @none_return() -> !torch.none { func.func @none_return() -> !torch.none {
%1 = torch.constant.none %1 = torch.constant.none
return %1 : !torch.none return %1 : !torch.none
} }
// CHECK-LABEL: func @none_call_return() { // CHECK-LABEL: func.func @none_call_return() {
// CHECK: call @none_return() : () -> () // CHECK: call @none_return() : () -> ()
// CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: "test.use"(%[[NONE]]) : (!torch.none) -> () // CHECK: "test.use"(%[[NONE]]) : (!torch.none) -> ()
// CHECK: return // CHECK: return
func @none_call_return() { func.func @none_call_return() {
%0 = call @none_return() : () -> !torch.none %0 = call @none_return() : () -> !torch.none
"test.use"(%0) : (!torch.none) -> () "test.use"(%0) : (!torch.none) -> ()
return return
} }
// CHECK-LABEL: func @tuple_return( // CHECK-LABEL: func.func @tuple_return(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) { // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) {
// CHECK: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor // CHECK: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor
@ -64,13 +64,13 @@ func @none_call_return() {
// CHECK: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] : // CHECK: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] :
// CHECK-SAME: !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor // CHECK-SAME: !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
// CHECK: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor // CHECK: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor
func @tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}, func.func @tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>},
%arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple<tensor, tensor> { %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple<tensor, tensor> {
%1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor> %1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor>
return %1 : !torch.tuple<tensor, tensor> return %1 : !torch.tuple<tensor, tensor>
} }
// CHECK-LABEL: func @call_tuple_return( // CHECK-LABEL: func.func @call_tuple_return(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) { // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) {
// CHECK: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor // CHECK: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor
@ -92,7 +92,7 @@ func @tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f
// CHECK: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] : // CHECK: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] :
// CHECK-SAME: !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor // CHECK-SAME: !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
// CHECK: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor // CHECK: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor
func @call_tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}, func.func @call_tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>},
%arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple<tensor, tensor> { %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple<tensor, tensor> {
%0 = call @tuple_return(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tuple<tensor, tensor> %0 = call @tuple_return(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tuple<tensor, tensor>
return %0 : !torch.tuple<tensor, tensor> return %0 : !torch.tuple<tensor, tensor>

File diff suppressed because it is too large Load Diff

View File

@ -1,8 +1,8 @@
// RUN: torch-mlir-opt -torch-decompose-complex-ops -split-input-file %s | FileCheck %s // RUN: torch-mlir-opt -torch-decompose-complex-ops -split-input-file %s | FileCheck %s
// CHECK-LABEL: func @matmul_no_decompose // CHECK-LABEL: func.func @matmul_no_decompose
// CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor // CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
func @matmul_no_decompose(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor { func.func @matmul_no_decompose(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
return %0 : !torch.tensor return %0 : !torch.tensor
} }
@ -10,23 +10,23 @@ func @matmul_no_decompose(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.
// ----- // -----
// CHECK-LABEL: func @matmul_decompose_2d // CHECK-LABEL: func.func @matmul_decompose_2d
// CHECK: torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.tensor // CHECK: torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.tensor
func @matmul_decompose_2d(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.tensor { func.func @matmul_decompose_2d(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.tensor {
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.tensor %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.tensor
return %0 : !torch.tensor return %0 : !torch.tensor
} }
// ----- // -----
// CHECK-LABEL: func @matmul_decompose_3d( // CHECK-LABEL: func.func @matmul_decompose_3d(
// CHECK: torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor // CHECK: torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor { func.func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
return %0 : !torch.tensor return %0 : !torch.tensor
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.softmax.int( // CHECK-LABEL: func.func @torch.aten.softmax.int(
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>, // CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>,
// CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor<[2,3],f32> { // CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor<[2,3],f32> {
// CHECK: %[[DTYPE:.*]] = torch.constant.none // CHECK: %[[DTYPE:.*]] = torch.constant.none
@ -45,7 +45,7 @@ func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vten
// CHECK: %[[SOFTMAX:.*]] = torch.aten.div.Tensor %[[EXP]], %[[SUM]] : !torch.tensor<[2,3],f32>, !torch.tensor<[?,?],f32> -> !torch.tensor<[2,3],f32> // CHECK: %[[SOFTMAX:.*]] = torch.aten.div.Tensor %[[EXP]], %[[SUM]] : !torch.tensor<[2,3],f32>, !torch.tensor<[?,?],f32> -> !torch.tensor<[2,3],f32>
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<[2,3],f32> to !torch.tensor<[2,3],f32> // CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<[2,3],f32> to !torch.tensor<[2,3],f32>
// CHECK: return %[[RET]] : !torch.tensor<[2,3],f32> // CHECK: return %[[RET]] : !torch.tensor<[2,3],f32>
func @torch.aten.softmax.int(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -> !torch.tensor<[2,3],f32> { func.func @torch.aten.softmax.int(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -> !torch.tensor<[2,3],f32> {
%dtype = torch.constant.none %dtype = torch.constant.none
%ret = torch.aten.softmax.int %t, %dim, %dtype: !torch.tensor<[2,3],f32>, !torch.int, !torch.none -> !torch.tensor<[2,3],f32> %ret = torch.aten.softmax.int %t, %dim, %dtype: !torch.tensor<[2,3],f32>, !torch.int, !torch.none -> !torch.tensor<[2,3],f32>
return %ret : !torch.tensor<[2,3],f32> return %ret : !torch.tensor<[2,3],f32>
@ -53,7 +53,7 @@ func @torch.aten.softmax.int(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) ->
// ----- // -----
// CHECK-LABEL: func @torch.aten.softmax.int$cst_dim( // CHECK-LABEL: func.func @torch.aten.softmax.int$cst_dim(
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>) -> !torch.tensor<[2,3],f32> { // CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>) -> !torch.tensor<[2,3],f32> {
// CHECK: %[[DTYPE:.*]] = torch.constant.none // CHECK: %[[DTYPE:.*]] = torch.constant.none
// CHECK: %[[DIM:.*]] = torch.constant.int 1 // CHECK: %[[DIM:.*]] = torch.constant.int 1
@ -72,7 +72,7 @@ func @torch.aten.softmax.int(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) ->
// CHECK: %[[SOFTMAX:.*]] = torch.aten.div.Tensor %[[EXP]], %[[SUM]] : !torch.tensor<[2,3],f32>, !torch.tensor<[2,1],f32> -> !torch.tensor<[2,3],f32> // CHECK: %[[SOFTMAX:.*]] = torch.aten.div.Tensor %[[EXP]], %[[SUM]] : !torch.tensor<[2,3],f32>, !torch.tensor<[2,1],f32> -> !torch.tensor<[2,3],f32>
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<[2,3],f32> to !torch.tensor<[2,3],f32> // CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<[2,3],f32> to !torch.tensor<[2,3],f32>
// CHECK: return %[[RET]] : !torch.tensor<[2,3],f32> // CHECK: return %[[RET]] : !torch.tensor<[2,3],f32>
func @torch.aten.softmax.int$cst_dim(%t: !torch.tensor<[2,3],f32>) -> !torch.tensor<[2,3],f32> { func.func @torch.aten.softmax.int$cst_dim(%t: !torch.tensor<[2,3],f32>) -> !torch.tensor<[2,3],f32> {
%none = torch.constant.none %none = torch.constant.none
%dim = torch.constant.int 1 %dim = torch.constant.int 1
%ret = torch.aten.softmax.int %t, %dim, %none : !torch.tensor<[2,3],f32>, !torch.int, !torch.none -> !torch.tensor<[2,3],f32> %ret = torch.aten.softmax.int %t, %dim, %none : !torch.tensor<[2,3],f32>, !torch.int, !torch.none -> !torch.tensor<[2,3],f32>
@ -80,7 +80,7 @@ func @torch.aten.softmax.int$cst_dim(%t: !torch.tensor<[2,3],f32>) -> !torch.ten
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.softmax.int$dyn_shape( // CHECK-LABEL: func.func @torch.aten.softmax.int$dyn_shape(
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> { // CHECK-SAME: %[[T:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> {
// CHECK: %[[DTYPE:.*]] = torch.constant.none // CHECK: %[[DTYPE:.*]] = torch.constant.none
// CHECK: %[[DIM:.*]] = torch.constant.int 1 // CHECK: %[[DIM:.*]] = torch.constant.int 1
@ -99,7 +99,7 @@ func @torch.aten.softmax.int$cst_dim(%t: !torch.tensor<[2,3],f32>) -> !torch.ten
// CHECK: %[[SOFTMAX:.*]] = torch.aten.div.Tensor %[[EXP]], %[[SUM]] : !torch.tensor<[?,?],f32>, !torch.tensor<[?,1],f32> -> !torch.tensor<[?,?],f32> // CHECK: %[[SOFTMAX:.*]] = torch.aten.div.Tensor %[[EXP]], %[[SUM]] : !torch.tensor<[?,?],f32>, !torch.tensor<[?,1],f32> -> !torch.tensor<[?,?],f32>
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<[?,?],f32> to !torch.tensor<[?,?],f32> // CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<[?,?],f32> to !torch.tensor<[?,?],f32>
// CHECK: return %[[RET]] : !torch.tensor<[?,?],f32> // CHECK: return %[[RET]] : !torch.tensor<[?,?],f32>
func @torch.aten.softmax.int$dyn_shape(%t: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> { func.func @torch.aten.softmax.int$dyn_shape(%t: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> {
%none = torch.constant.none %none = torch.constant.none
%dim = torch.constant.int 1 %dim = torch.constant.int 1
%ret = torch.aten.softmax.int %t, %dim, %none : !torch.tensor<[?,?],f32>, !torch.int, !torch.none -> !torch.tensor<[?,?],f32> %ret = torch.aten.softmax.int %t, %dim, %none : !torch.tensor<[?,?],f32>, !torch.int, !torch.none -> !torch.tensor<[?,?],f32>
@ -107,7 +107,7 @@ func @torch.aten.softmax.int$dyn_shape(%t: !torch.tensor<[?,?],f32>) -> !torch.t
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.softmax.int$unknown_shape( // CHECK-LABEL: func.func @torch.aten.softmax.int$unknown_shape(
// CHECK-SAME: %[[T:.*]]: !torch.tensor<*,f32>) -> !torch.tensor<*,f32> { // CHECK-SAME: %[[T:.*]]: !torch.tensor<*,f32>) -> !torch.tensor<*,f32> {
// CHECK: %[[DTYPE:.*]] = torch.constant.none // CHECK: %[[DTYPE:.*]] = torch.constant.none
// CHECK: %[[DIM:.*]] = torch.constant.int 1 // CHECK: %[[DIM:.*]] = torch.constant.int 1
@ -126,7 +126,7 @@ func @torch.aten.softmax.int$dyn_shape(%t: !torch.tensor<[?,?],f32>) -> !torch.t
// CHECK: %[[SOFTMAX:.*]] = torch.aten.div.Tensor %[[EXP]], %[[SUM]] : !torch.tensor<*,f32>, !torch.tensor<*,f32> -> !torch.tensor<*,f32> // CHECK: %[[SOFTMAX:.*]] = torch.aten.div.Tensor %[[EXP]], %[[SUM]] : !torch.tensor<*,f32>, !torch.tensor<*,f32> -> !torch.tensor<*,f32>
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<*,f32> to !torch.tensor<*,f32> // CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<*,f32> to !torch.tensor<*,f32>
// CHECK: return %[[RET]] : !torch.tensor<*,f32> // CHECK: return %[[RET]] : !torch.tensor<*,f32>
func @torch.aten.softmax.int$unknown_shape(%t: !torch.tensor<*,f32>) -> !torch.tensor<*,f32> { func.func @torch.aten.softmax.int$unknown_shape(%t: !torch.tensor<*,f32>) -> !torch.tensor<*,f32> {
%none = torch.constant.none %none = torch.constant.none
%dim = torch.constant.int 1 %dim = torch.constant.int 1
%ret = torch.aten.softmax.int %t, %dim, %none : !torch.tensor<*,f32>, !torch.int, !torch.none -> !torch.tensor<*,f32> %ret = torch.aten.softmax.int %t, %dim, %none : !torch.tensor<*,f32>, !torch.int, !torch.none -> !torch.tensor<*,f32>
@ -134,7 +134,7 @@ func @torch.aten.softmax.int$unknown_shape(%t: !torch.tensor<*,f32>) -> !torch.t
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.size( // CHECK-LABEL: func.func @torch.aten.size(
// CHECK-SAME: %[[T:.*]]: !torch.vtensor<[?,3],f32>) -> !torch.list<int> { // CHECK-SAME: %[[T:.*]]: !torch.vtensor<[?,3],f32>) -> !torch.list<int> {
// CHECK: %[[CST0:.*]] = torch.constant.int 0 // CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[T]], %[[CST0]] : !torch.vtensor<[?,3],f32>, !torch.int -> !torch.int // CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[T]], %[[CST0]] : !torch.vtensor<[?,3],f32>, !torch.int -> !torch.int
@ -142,13 +142,13 @@ func @torch.aten.softmax.int$unknown_shape(%t: !torch.tensor<*,f32>) -> !torch.t
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[T]], %[[CST1]] : !torch.vtensor<[?,3],f32>, !torch.int -> !torch.int // CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[T]], %[[CST1]] : !torch.vtensor<[?,3],f32>, !torch.int -> !torch.int
// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]] : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: return %[[SIZE]] : !torch.list<int> // CHECK: return %[[SIZE]] : !torch.list<int>
func @torch.aten.size(%arg0: !torch.vtensor<[?,3],f32>) -> !torch.list<int> { func.func @torch.aten.size(%arg0: !torch.vtensor<[?,3],f32>) -> !torch.list<int> {
%0 = torch.aten.size %arg0 : !torch.vtensor<[?,3],f32> -> !torch.list<int> %0 = torch.aten.size %arg0 : !torch.vtensor<[?,3],f32> -> !torch.list<int>
return %0 : !torch.list<int> return %0 : !torch.list<int>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.arange() -> !torch.vtensor<[?],si64> { // CHECK-LABEL: func.func @torch.aten.arange() -> !torch.vtensor<[?],si64> {
// CHECK: %[[CST5:.*]] = torch.constant.int 5 // CHECK: %[[CST5:.*]] = torch.constant.int 5
// CHECK: %[[CSTN:.*]] = torch.constant.none // CHECK: %[[CSTN:.*]] = torch.constant.none
// CHECK: %[[CST0:.*]] = torch.constant.int 0 // CHECK: %[[CST0:.*]] = torch.constant.int 0
@ -156,7 +156,7 @@ func @torch.aten.size(%arg0: !torch.vtensor<[?,3],f32>) -> !torch.list<int> {
// CHECK: %[[RESULT:.*]] = torch.aten.arange.start_step %[[CST0]], %[[CST5]], %[[CST1]], %[[CSTN]], %[[CSTN]], %[[CSTN]], %[[CSTN]] : // CHECK: %[[RESULT:.*]] = torch.aten.arange.start_step %[[CST0]], %[[CST5]], %[[CST1]], %[[CSTN]], %[[CSTN]], %[[CSTN]], %[[CSTN]] :
// CHECK-SAME: !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],si64> // CHECK-SAME: !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],si64>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?],si64> // CHECK: return %[[RESULT]] : !torch.vtensor<[?],si64>
func @torch.aten.arange() -> !torch.vtensor<[?],si64> { func.func @torch.aten.arange() -> !torch.vtensor<[?],si64> {
%int5 = torch.constant.int 5 %int5 = torch.constant.int 5
%none = torch.constant.none %none = torch.constant.none
%0 = torch.aten.arange %int5, %none, %none, %none, %none : !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],si64> %0 = torch.aten.arange %int5, %none, %none, %none, %none : !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],si64>
@ -164,7 +164,7 @@ func @torch.aten.arange() -> !torch.vtensor<[?],si64> {
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.arange.start() -> !torch.vtensor<[?],si64> { // CHECK-LABEL: func.func @torch.aten.arange.start() -> !torch.vtensor<[?],si64> {
// CHECK: %[[CST10:.*]] = torch.constant.int 10 // CHECK: %[[CST10:.*]] = torch.constant.int 10
// CHECK: %[[CST0:.*]] = torch.constant.int 0 // CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[CSTN:.*]] = torch.constant.none // CHECK: %[[CSTN:.*]] = torch.constant.none
@ -172,7 +172,7 @@ func @torch.aten.arange() -> !torch.vtensor<[?],si64> {
// CHECK: %[[RESULT:.*]] = torch.aten.arange.start_step %[[CST0]], %[[CST10]], %[[CST1]], %[[CSTN]], %[[CSTN]], %[[CSTN]], %[[CSTN]] : // CHECK: %[[RESULT:.*]] = torch.aten.arange.start_step %[[CST0]], %[[CST10]], %[[CST1]], %[[CSTN]], %[[CSTN]], %[[CSTN]], %[[CSTN]] :
// CHECK-SAME: !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],si64> // CHECK-SAME: !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],si64>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?],si64> // CHECK: return %[[RESULT]] : !torch.vtensor<[?],si64>
func @torch.aten.arange.start() -> !torch.vtensor<[?],si64> { func.func @torch.aten.arange.start() -> !torch.vtensor<[?],si64> {
%int10 = torch.constant.int 10 %int10 = torch.constant.int 10
%int0 = torch.constant.int 0 %int0 = torch.constant.int 0
%none = torch.constant.none %none = torch.constant.none
@ -181,14 +181,14 @@ func @torch.aten.arange.start() -> !torch.vtensor<[?],si64> {
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.argmax( // CHECK-LABEL: func.func @torch.aten.argmax(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[1,?],si64> { // CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[1,?],si64> {
// CHECK: %[[CST0:.*]] = torch.constant.int 0 // CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[TRUE:.*]] = torch.constant.bool true // CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[INP]], %[[CST0]], %[[TRUE]] : // CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[INP]], %[[CST0]], %[[TRUE]] :
// CHECK-SAME: !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,?],f32>, !torch.vtensor<[1,?],si64> // CHECK-SAME: !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,?],f32>, !torch.vtensor<[1,?],si64>
// CHECK: return %[[IND]] : !torch.vtensor<[1,?],si64> // CHECK: return %[[IND]] : !torch.vtensor<[1,?],si64>
func @torch.aten.argmax(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[1,?],si64> { func.func @torch.aten.argmax(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[1,?],si64> {
%int0 = torch.constant.int 0 %int0 = torch.constant.int 0
%true = torch.constant.bool true %true = torch.constant.bool true
%0 = torch.aten.argmax %arg0, %int0, %true : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,?],si64> %0 = torch.aten.argmax %arg0, %int0, %true : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,?],si64>
@ -196,7 +196,7 @@ func @torch.aten.argmax(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[1,?
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.argmax$reduceall( // CHECK-LABEL: func.func @torch.aten.argmax$reduceall(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],si64> { // CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],si64> {
// CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false // CHECK: %[[FALSE:.*]] = torch.constant.bool false
@ -207,7 +207,7 @@ func @torch.aten.argmax(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[1,?
// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[FLATTEN]], %[[CST0]], %[[FALSE]] : // CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[FLATTEN]], %[[CST0]], %[[FALSE]] :
// CHECK-SAME: !torch.vtensor<[?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[],f32>, !torch.vtensor<[],si64> // CHECK-SAME: !torch.vtensor<[?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[],f32>, !torch.vtensor<[],si64>
// CHECK: return %[[IND]] : !torch.vtensor<[],si64> // CHECK: return %[[IND]] : !torch.vtensor<[],si64>
func @torch.aten.argmax$reduceall(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],si64> { func.func @torch.aten.argmax$reduceall(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],si64> {
%none = torch.constant.none %none = torch.constant.none
%false = torch.constant.bool false %false = torch.constant.bool false
%0 = torch.aten.argmax %arg0, %none, %false : !torch.vtensor<[?,?],f32>, !torch.none, !torch.bool -> !torch.vtensor<[],si64> %0 = torch.aten.argmax %arg0, %none, %false : !torch.vtensor<[?,?],f32>, !torch.none, !torch.bool -> !torch.vtensor<[],si64>
@ -215,18 +215,18 @@ func @torch.aten.argmax$reduceall(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vt
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.square( // CHECK-LABEL: func.func @torch.aten.square(
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { // CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[SQUARE:.*]] = torch.aten.mul.Tensor %[[INPUT]], %[[INPUT]] : // CHECK: %[[SQUARE:.*]] = torch.aten.mul.Tensor %[[INPUT]], %[[INPUT]] :
// CHECK-SAME: !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32> // CHECK-SAME: !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[SQUARE]] : !torch.vtensor<[?,?,?],f32> // CHECK: return %[[SQUARE]] : !torch.vtensor<[?,?,?],f32>
func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { func.func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
%0 = torch.aten.square %arg0 : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32> %0 = torch.aten.square %arg0 : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?],f32> return %0 : !torch.vtensor<[?,?,?],f32>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.var$unbiased( // CHECK-LABEL: func.func @torch.aten.var$unbiased(
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> { // CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
// CHECK: %[[UNBIASED:.*]] = torch.constant.bool true // CHECK: %[[UNBIASED:.*]] = torch.constant.bool true
// CHECK: %[[DTYPE:.*]] = torch.constant.none // CHECK: %[[DTYPE:.*]] = torch.constant.none
@ -242,14 +242,14 @@ func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?
// CHECK: %[[NUM_ELEMENTS_SUB1:.*]] = torch.aten.sub.int %[[SUB_MEAN_SQUARE_NUM_ELEMENTS]], %[[CST1]] : !torch.int, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS_SUB1:.*]] = torch.aten.sub.int %[[SUB_MEAN_SQUARE_NUM_ELEMENTS]], %[[CST1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_SUB1]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32> // CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_SUB1]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: return %[[UNBIASED_VAR]] : !torch.vtensor<[],f32> // CHECK: return %[[UNBIASED_VAR]] : !torch.vtensor<[],f32>
func @torch.aten.var$unbiased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> { func.func @torch.aten.var$unbiased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
%true = torch.constant.bool true %true = torch.constant.bool true
%0 = torch.aten.var %arg0, %true: !torch.vtensor<[?,?,?],f32>, !torch.bool -> !torch.vtensor<[],f32> %0 = torch.aten.var %arg0, %true: !torch.vtensor<[?,?,?],f32>, !torch.bool -> !torch.vtensor<[],f32>
return %0 : !torch.vtensor<[],f32> return %0 : !torch.vtensor<[],f32>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.var$biased( // CHECK-LABEL: func.func @torch.aten.var$biased(
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> { // CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
// CHECK: %[[UNBIASED:.*]] = torch.constant.bool false // CHECK: %[[UNBIASED:.*]] = torch.constant.bool false
// CHECK: %[[DTYPE:.*]] = torch.constant.none // CHECK: %[[DTYPE:.*]] = torch.constant.none
@ -263,14 +263,14 @@ func @torch.aten.var$unbiased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vten
// CHECK: %[[SUB_MEAN_SQUARE_NUM_ELEMENTS:.*]] = torch.aten.numel %[[SUB_MEAN_SQUARE]] : !torch.vtensor<[?,?,?],f32> -> !torch.int // CHECK: %[[SUB_MEAN_SQUARE_NUM_ELEMENTS:.*]] = torch.aten.numel %[[SUB_MEAN_SQUARE]] : !torch.vtensor<[?,?,?],f32> -> !torch.int
// CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[SUB_MEAN_SQUARE_NUM_ELEMENTS]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32> // CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[SUB_MEAN_SQUARE_NUM_ELEMENTS]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: return %[[BIASED_VAR]] : !torch.vtensor<[],f32> // CHECK: return %[[BIASED_VAR]] : !torch.vtensor<[],f32>
func @torch.aten.var$biased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> { func.func @torch.aten.var$biased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
%false = torch.constant.bool false %false = torch.constant.bool false
%0 = torch.aten.var %arg0, %false: !torch.vtensor<[?,?,?],f32>, !torch.bool -> !torch.vtensor<[],f32> %0 = torch.aten.var %arg0, %false: !torch.vtensor<[?,?,?],f32>, !torch.bool -> !torch.vtensor<[],f32>
return %0 : !torch.vtensor<[],f32> return %0 : !torch.vtensor<[],f32>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.std$unbiased( // CHECK-LABEL: func.func @torch.aten.std$unbiased(
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> { // CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
// CHECK: %[[UNBIASED:.*]] = torch.constant.bool true // CHECK: %[[UNBIASED:.*]] = torch.constant.bool true
// CHECK: %[[DTYPE:.*]] = torch.constant.none // CHECK: %[[DTYPE:.*]] = torch.constant.none
@ -287,14 +287,14 @@ func @torch.aten.var$biased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtenso
// CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_SUB1]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32> // CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_SUB1]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[UNBIASED_STD:.*]] = torch.aten.sqrt %[[UNBIASED_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> // CHECK: %[[UNBIASED_STD:.*]] = torch.aten.sqrt %[[UNBIASED_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
// CHECK: return %[[UNBIASED_STD]] : !torch.vtensor<[],f32> // CHECK: return %[[UNBIASED_STD]] : !torch.vtensor<[],f32>
func @torch.aten.std$unbiased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> { func.func @torch.aten.std$unbiased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
%true = torch.constant.bool true %true = torch.constant.bool true
%0 = torch.aten.std %arg0, %true: !torch.vtensor<[?,?,?],f32>, !torch.bool -> !torch.vtensor<[],f32> %0 = torch.aten.std %arg0, %true: !torch.vtensor<[?,?,?],f32>, !torch.bool -> !torch.vtensor<[],f32>
return %0 : !torch.vtensor<[],f32> return %0 : !torch.vtensor<[],f32>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.std$biased( // CHECK-LABEL: func.func @torch.aten.std$biased(
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> { // CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
// CHECK: %[[UNBIASED:.*]] = torch.constant.bool false // CHECK: %[[UNBIASED:.*]] = torch.constant.bool false
// CHECK: %[[DTYPE:.*]] = torch.constant.none // CHECK: %[[DTYPE:.*]] = torch.constant.none
@ -309,20 +309,20 @@ func @torch.aten.std$unbiased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vten
// CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[SUB_MEAN_SQUARE_NUM_ELEMENTS]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32> // CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[SUB_MEAN_SQUARE_NUM_ELEMENTS]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[BIASED_STD:.*]] = torch.aten.sqrt %[[BIASED_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> // CHECK: %[[BIASED_STD:.*]] = torch.aten.sqrt %[[BIASED_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
// CHECK: return %[[BIASED_STD]] : !torch.vtensor<[],f32> // CHECK: return %[[BIASED_STD]] : !torch.vtensor<[],f32>
func @torch.aten.std$biased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> { func.func @torch.aten.std$biased(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
%false = torch.constant.bool false %false = torch.constant.bool false
%0 = torch.aten.std %arg0, %false: !torch.vtensor<[?,?,?],f32>, !torch.bool -> !torch.vtensor<[],f32> %0 = torch.aten.std %arg0, %false: !torch.vtensor<[?,?,?],f32>, !torch.bool -> !torch.vtensor<[],f32>
return %0 : !torch.vtensor<[],f32> return %0 : !torch.vtensor<[],f32>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten._unsafe_view$static // CHECK-LABEL: func.func @torch.aten._unsafe_view$static
// CHECK-SAME: (%[[ARG0:.*]]: !torch.vtensor<[1,512,32],f32>) // CHECK-SAME: (%[[ARG0:.*]]: !torch.vtensor<[1,512,32],f32>)
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct
// CHECK-NOT: torch.aten._unsafe_view // CHECK-NOT: torch.aten._unsafe_view
// CHECK-NEXT: %[[RES:.*]] = torch.aten.view %[[ARG0]], %[[LIST]] // CHECK-NEXT: %[[RES:.*]] = torch.aten.view %[[ARG0]], %[[LIST]]
// CHECK-NEXT: return // CHECK-NEXT: return
func @torch.aten._unsafe_view$static(%arg0: !torch.vtensor<[1,512,32],f32>) -> !torch.vtensor<[1,2,256,32],f32> { func.func @torch.aten._unsafe_view$static(%arg0: !torch.vtensor<[1,512,32],f32>) -> !torch.vtensor<[1,2,256,32],f32> {
%c1 = torch.constant.int 1 %c1 = torch.constant.int 1
%c2 = torch.constant.int 2 %c2 = torch.constant.int 2
%c256 = torch.constant.int 256 %c256 = torch.constant.int 256
@ -333,14 +333,14 @@ func @torch.aten._unsafe_view$static(%arg0: !torch.vtensor<[1,512,32],f32>) -> !
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten._reshape_alias$static // CHECK-LABEL: func.func @torch.aten._reshape_alias$static
// CHECK-SAME: (%[[ARG0:.*]]: !torch.vtensor<[1],f32>) // CHECK-SAME: (%[[ARG0:.*]]: !torch.vtensor<[1],f32>)
// CHECK: %[[LIST1:.*]] = torch.prim.ListConstruct // CHECK: %[[LIST1:.*]] = torch.prim.ListConstruct
// CHECK: %[[LIST2:.*]] = torch.prim.ListConstruct // CHECK: %[[LIST2:.*]] = torch.prim.ListConstruct
// CHECK-NOT: torch.aten._reshape_alias // CHECK-NOT: torch.aten._reshape_alias
// CHECK-NEXT: %[[RES:.*]] = torch.aten.view %[[ARG0]], %[[LIST1]] // CHECK-NEXT: %[[RES:.*]] = torch.aten.view %[[ARG0]], %[[LIST1]]
// CHECK-NEXT: return // CHECK-NEXT: return
func @torch.aten._reshape_alias$static(%arg0: !torch.vtensor<[1],f32>) -> !torch.vtensor<[12,32],f32> { func.func @torch.aten._reshape_alias$static(%arg0: !torch.vtensor<[1],f32>) -> !torch.vtensor<[12,32],f32> {
%int1 = torch.constant.int 1 %int1 = torch.constant.int 1
%int32 = torch.constant.int 32 %int32 = torch.constant.int 32
%int12 = torch.constant.int 12 %int12 = torch.constant.int 12
@ -351,13 +351,13 @@ func @torch.aten._reshape_alias$static(%arg0: !torch.vtensor<[1],f32>) -> !torch
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten._unsafe_view$dynamic // CHECK-LABEL: func.func @torch.aten._unsafe_view$dynamic
// CHECK-SAME: (%[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) // CHECK-SAME: (%[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>)
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct
// CHECK-NOT: torch.aten._unsafe_view // CHECK-NOT: torch.aten._unsafe_view
// CHECK-NEXT: %[[RES:.*]] = torch.aten.view %[[ARG0]], %[[LIST]] // CHECK-NEXT: %[[RES:.*]] = torch.aten.view %[[ARG0]], %[[LIST]]
// CHECK-NEXT: return // CHECK-NEXT: return
func @torch.aten._unsafe_view$dynamic(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[512,32],f32> { func.func @torch.aten._unsafe_view$dynamic(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[512,32],f32> {
%c256 = torch.constant.int 512 %c256 = torch.constant.int 512
%c32 = torch.constant.int 32 %c32 = torch.constant.int 32
%0 = torch.prim.ListConstruct %c256, %c32 : (!torch.int, !torch.int) -> !torch.list<int> %0 = torch.prim.ListConstruct %c256, %c32 : (!torch.int, !torch.int) -> !torch.list<int>
@ -366,7 +366,7 @@ func @torch.aten._unsafe_view$dynamic(%arg0: !torch.vtensor<[?,?,?],f32>) -> !to
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten._log_softmax( // CHECK-LABEL: func.func @torch.aten._log_softmax(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { // CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[INT0:.*]] = torch.constant.int 0 // CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[FALSE:.*]] = torch.constant.bool false // CHECK: %[[FALSE:.*]] = torch.constant.bool false
@ -386,7 +386,7 @@ func @torch.aten._unsafe_view$dynamic(%arg0: !torch.vtensor<[?,?,?],f32>) -> !to
// CHECK: %[[SUB1:.*]] = torch.aten.sub.Tensor %[[SUB]], %[[LOG]], %[[FLOAT_1]] : !torch.vtensor<[?,?,?],f32>, // CHECK: %[[SUB1:.*]] = torch.aten.sub.Tensor %[[SUB]], %[[LOG]], %[[FLOAT_1]] : !torch.vtensor<[?,?,?],f32>,
// CHECK-SAME: !torch.vtensor<[1,?,?],f32>, !torch.float -> !torch.vtensor<[?,?,?],f32> // CHECK-SAME: !torch.vtensor<[1,?,?],f32>, !torch.float -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[SUB1]] : !torch.vtensor<[?,?,?],f32> // CHECK: return %[[SUB1]] : !torch.vtensor<[?,?,?],f32>
func @torch.aten._log_softmax(%arg0: !torch.vtensor<[?,?,?],f32> loc(unknown)) -> !torch.vtensor<[?,?,?],f32> { func.func @torch.aten._log_softmax(%arg0: !torch.vtensor<[?,?,?],f32> loc(unknown)) -> !torch.vtensor<[?,?,?],f32> {
%int0 = torch.constant.int 0 %int0 = torch.constant.int 0
%false = torch.constant.bool false %false = torch.constant.bool false
%0 = torch.aten._log_softmax %arg0, %int0, %false : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?,?,?],f32> %0 = torch.aten._log_softmax %arg0, %int0, %false : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?,?,?],f32>
@ -394,7 +394,7 @@ func @torch.aten._log_softmax(%arg0: !torch.vtensor<[?,?,?],f32> loc(unknown)) -
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.bernoulli // CHECK-LABEL: func.func @torch.aten.bernoulli
// CHECK-SAME: (%[[INP:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor { // CHECK-SAME: (%[[INP:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
// CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT7:.*]] = torch.constant.int 7 // CHECK: %[[INT7:.*]] = torch.constant.int 7
@ -424,7 +424,7 @@ func @torch.aten._log_softmax(%arg0: !torch.vtensor<[?,?,?],f32> loc(unknown)) -
// CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64> // CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f64> to !torch.vtensor // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor // CHECK: return %[[CAST]] : !torch.vtensor
func @torch.aten.bernoulli(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor { func.func @torch.aten.bernoulli(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
%none = torch.constant.none %none = torch.constant.none
%0 = torch.aten.bernoulli %arg0, %none : !torch.vtensor<[?,?,?],f64>, !torch.none -> !torch.vtensor<[?,?,?],f64> %0 = torch.aten.bernoulli %arg0, %none : !torch.vtensor<[?,?,?],f64>, !torch.none -> !torch.vtensor<[?,?,?],f64>
%1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f64> to !torch.vtensor %1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
@ -432,7 +432,7 @@ func @torch.aten.bernoulli(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor
} }
// ----- // -----
// CHECK-LABEL: func @torch.valsem.aten.bernoulli.float // CHECK-LABEL: func.func @torch.valsem.aten.bernoulli.float
// CHECK-SAME: (%[[INP:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor { // CHECK-SAME: (%[[INP:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
// CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[PROB:.*]] = torch.constant.float 4.000000e-01 // CHECK: %[[PROB:.*]] = torch.constant.float 4.000000e-01
@ -463,7 +463,7 @@ func @torch.aten.bernoulli(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor
// CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64> // CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f64> to !torch.vtensor // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor // CHECK: return %[[CAST]] : !torch.vtensor
func @torch.valsem.aten.bernoulli.float(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor { func.func @torch.valsem.aten.bernoulli.float(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
%none = torch.constant.none %none = torch.constant.none
%prob = torch.constant.float 4.000000e-01 %prob = torch.constant.float 4.000000e-01
%0 = torch.valsem.aten.bernoulli.float %arg0, %prob, %none : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64> %0 = torch.valsem.aten.bernoulli.float %arg0, %prob, %none : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
@ -472,7 +472,7 @@ func @torch.valsem.aten.bernoulli.float(%arg0: !torch.vtensor<[?,?,?],f64>) -> !
} }
// ----- // -----
// CHECK-LABEL: func @torch.valsem.aten.bernoulli.Tensor( // CHECK-LABEL: func.func @torch.valsem.aten.bernoulli.Tensor(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?,?],f64>, // CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?,?],f64>,
// CHECK-SAME: %[[PROB:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor { // CHECK-SAME: %[[PROB:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
// CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[NONE:.*]] = torch.constant.none
@ -502,7 +502,7 @@ func @torch.valsem.aten.bernoulli.float(%arg0: !torch.vtensor<[?,?,?],f64>) -> !
// CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64> // CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f64> to !torch.vtensor // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor // CHECK: return %[[CAST]] : !torch.vtensor
func @torch.valsem.aten.bernoulli.Tensor(%arg0: !torch.vtensor<[?,?,?],f64>, %arg1: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor { func.func @torch.valsem.aten.bernoulli.Tensor(%arg0: !torch.vtensor<[?,?,?],f64>, %arg1: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
%none = torch.constant.none %none = torch.constant.none
%0 = torch.valsem.aten.bernoulli.Tensor %arg0, %arg1, %none : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64>, !torch.none -> !torch.vtensor<[?,?,?],f64> %0 = torch.valsem.aten.bernoulli.Tensor %arg0, %arg1, %none : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64>, !torch.none -> !torch.vtensor<[?,?,?],f64>
%1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f64> to !torch.vtensor %1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
@ -510,7 +510,7 @@ func @torch.valsem.aten.bernoulli.Tensor(%arg0: !torch.vtensor<[?,?,?],f64>, %ar
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.rand_like( // CHECK-LABEL: func.func @torch.aten.rand_like(
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor { // CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
// CHECK: %[[INT6:.*]] = torch.constant.int 6 // CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[NONE_0:.*]] = torch.constant.none // CHECK: %[[NONE_0:.*]] = torch.constant.none
@ -528,7 +528,7 @@ func @torch.valsem.aten.bernoulli.Tensor(%arg0: !torch.vtensor<[?,?,?],f64>, %ar
// CHECK: %[[UNIFORM:.*]] = torch.valsem.aten.uniform %[[EMPTY]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_1]] : !torch.vtensor<[?,?,?],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f32> // CHECK: %[[UNIFORM:.*]] = torch.valsem.aten.uniform %[[EMPTY]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_1]] : !torch.vtensor<[?,?,?],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[UNIFORM]] : !torch.vtensor<[?,?,?],f32> to !torch.vtensor // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[UNIFORM]] : !torch.vtensor<[?,?,?],f32> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor // CHECK: return %[[CAST]] : !torch.vtensor
func @torch.aten.rand_like(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor { func.func @torch.aten.rand_like(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
%int6 = torch.constant.int 6 %int6 = torch.constant.int 6
%none = torch.constant.none %none = torch.constant.none
%0 = torch.aten.rand_like %arg0, %int6, %none, %none, %none, %none : !torch.vtensor<[?,?,?],f64>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?,?],f32> %0 = torch.aten.rand_like %arg0, %int6, %none, %none, %none, %none : !torch.vtensor<[?,?,?],f64>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?,?],f32>
@ -537,7 +537,7 @@ func @torch.aten.rand_like(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.select.int( // CHECK-LABEL: func.func @torch.aten.select.int(
// CHECK-SAME: %[[T:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?],si64> { // CHECK-SAME: %[[T:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?],si64> {
// CHECK: %[[CST0:.*]] = torch.constant.int 0 // CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[CST1:.*]] = torch.constant.int 1 // CHECK: %[[CST1:.*]] = torch.constant.int 1
@ -547,14 +547,14 @@ func @torch.aten.rand_like(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor
// CHECK: %[[SELECT:.*]] = torch.aten.squeeze.dim %[[SLICE]], %[[CST0]] : // CHECK: %[[SELECT:.*]] = torch.aten.squeeze.dim %[[SLICE]], %[[CST0]] :
// CHECK-SAME: !torch.vtensor<[1,?],si64>, !torch.int -> !torch.vtensor<[?],si64> // CHECK-SAME: !torch.vtensor<[1,?],si64>, !torch.int -> !torch.vtensor<[?],si64>
// CHECK: return %[[SELECT]] : !torch.vtensor<[?],si64> // CHECK: return %[[SELECT]] : !torch.vtensor<[?],si64>
func @torch.aten.select.int(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?],si64> { func.func @torch.aten.select.int(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?],si64> {
%int0 = torch.constant.int 0 %int0 = torch.constant.int 0
%0 = torch.aten.select.int %arg0, %int0, %int0 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[?],si64> %0 = torch.aten.select.int %arg0, %int0, %int0 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[?],si64>
return %0 : !torch.vtensor<[?],si64> return %0 : !torch.vtensor<[?],si64>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.hardsigmoid( // CHECK-LABEL: func.func @torch.aten.hardsigmoid(
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[CST1:.*]] = torch.constant.int 1 // CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[CST2:.*]] = torch.constant.int 3 // CHECK: %[[CST2:.*]] = torch.constant.int 3
@ -576,13 +576,13 @@ func @torch.aten.select.int(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor
// CHECK: %[[RET:.*]] = torch.aten.maximum %[[CST0_TENSOR]], %[[MIN]] : !torch.vtensor<[],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[RET:.*]] = torch.aten.maximum %[[CST0_TENSOR]], %[[MIN]] : !torch.vtensor<[],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RET]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[RET]] : !torch.vtensor<[?,?],f32>
// CHECK: } // CHECK: }
func @torch.aten.hardsigmoid(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.hardsigmoid(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.hardsigmoid %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> %0 = torch.aten.hardsigmoid %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32> return %0 : !torch.vtensor<[?,?],f32>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.hardswish( // CHECK-LABEL: func.func @torch.aten.hardswish(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[INT3:.*]] = torch.constant.int 3 // CHECK: %[[INT3:.*]] = torch.constant.int 3
@ -599,13 +599,13 @@ func @torch.aten.hardsigmoid(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor
// CHECK: %[[DIV:.*]] = torch.aten.div.Scalar %[[MIN]], %[[INT6]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32> // CHECK: %[[DIV:.*]] = torch.aten.div.Scalar %[[MIN]], %[[INT6]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: %[[MUL:.*]] = torch.aten.mul.Tensor %[[DIV]], %[[INP]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[MUL:.*]] = torch.aten.mul.Tensor %[[DIV]], %[[INP]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[MUL]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[MUL]] : !torch.vtensor<[?,?],f32>
func @torch.aten.hardswish(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.hardswish(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.hardswish %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> %0 = torch.aten.hardswish %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32> return %0 : !torch.vtensor<[?,?],f32>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.hardtanh( // CHECK-LABEL: func.func @torch.aten.hardtanh(
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?],f32>, // CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?],f32>,
// CHECK-SAME: %[[MIN_VAL:.*]]: !torch.float, // CHECK-SAME: %[[MIN_VAL:.*]]: !torch.float,
// CHECK-SAME: %[[MAX_VAL:.*]]: !torch.float) -> !torch.vtensor<[?],f32> { // CHECK-SAME: %[[MAX_VAL:.*]]: !torch.float) -> !torch.vtensor<[?],f32> {
@ -622,13 +622,13 @@ func @torch.aten.hardswish(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
// CHECK: %[[MAX_TENSOR:.*]] = torch.valsem.aten.fill.Scalar %[[VAL_10]], %[[MAX_VAL]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32> // CHECK: %[[MAX_TENSOR:.*]] = torch.valsem.aten.fill.Scalar %[[VAL_10]], %[[MAX_VAL]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK: %[[RET:.*]] = torch.aten.minimum %[[MAX_TENSOR]], %[[MIN]] : !torch.vtensor<[],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?],f32> // CHECK: %[[RET:.*]] = torch.aten.minimum %[[MAX_TENSOR]], %[[MIN]] : !torch.vtensor<[],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?],f32>
// CHECK: return %[[RET]] : !torch.vtensor<[?],f32> // CHECK: return %[[RET]] : !torch.vtensor<[?],f32>
func @torch.aten.hardtanh(%arg0: !torch.vtensor<[?],f32>, %min: !torch.float, %max: !torch.float) -> !torch.vtensor<[?],f32> { func.func @torch.aten.hardtanh(%arg0: !torch.vtensor<[?],f32>, %min: !torch.float, %max: !torch.float) -> !torch.vtensor<[?],f32> {
%0 = torch.aten.hardtanh %arg0, %min, %max : !torch.vtensor<[?],f32>, !torch.float, !torch.float -> !torch.vtensor<[?],f32> %0 = torch.aten.hardtanh %arg0, %min, %max : !torch.vtensor<[?],f32>, !torch.float, !torch.float -> !torch.vtensor<[?],f32>
return %0 : !torch.vtensor<[?],f32> return %0 : !torch.vtensor<[?],f32>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.new_zeros // CHECK-LABEL: func.func @torch.aten.new_zeros
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[2,3],f32> { // CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[2,3],f32> {
// CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT2:.*]] = torch.constant.int 2 // CHECK: %[[INT2:.*]] = torch.constant.int 2
@ -638,7 +638,7 @@ func @torch.aten.hardtanh(%arg0: !torch.vtensor<[?],f32>, %min: !torch.float, %m
// CHECK: %[[RES:.*]] = torch.aten.zeros %[[SIZE]], %[[INT6]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32> // CHECK: %[[RES:.*]] = torch.aten.zeros %[[SIZE]], %[[INT6]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[2,3],f32> // CHECK: return %[[RES]] : !torch.vtensor<[2,3],f32>
// CHECK: } // CHECK: }
func @torch.aten.new_zeros(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[2,3],f32> { func.func @torch.aten.new_zeros(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[2,3],f32> {
%none = torch.constant.none %none = torch.constant.none
%int2 = torch.constant.int 2 %int2 = torch.constant.int 2
%int3 = torch.constant.int 3 %int3 = torch.constant.int 3
@ -648,7 +648,7 @@ func @torch.aten.new_zeros(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.new_ones // CHECK-LABEL: func.func @torch.aten.new_ones
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[3,4],si64> { // CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[3,4],si64> {
// CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT3:.*]] = torch.constant.int 3 // CHECK: %[[INT3:.*]] = torch.constant.int 3
@ -658,7 +658,7 @@ func @torch.aten.new_zeros(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
// CHECK: %[[RES:.*]] = torch.aten.ones %[[SIZE]], %[[INT4_0]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4],si64> // CHECK: %[[RES:.*]] = torch.aten.ones %[[SIZE]], %[[INT4_0]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4],si64>
// CHECK: return %[[RES]] : !torch.vtensor<[3,4],si64> // CHECK: return %[[RES]] : !torch.vtensor<[3,4],si64>
// CHECK: } // CHECK: }
func @torch.aten.new_ones(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[3,4],si64> { func.func @torch.aten.new_ones(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[3,4],si64> {
%none = torch.constant.none %none = torch.constant.none
%int3 = torch.constant.int 3 %int3 = torch.constant.int 3
%int4 = torch.constant.int 4 %int4 = torch.constant.int 4
@ -668,18 +668,18 @@ func @torch.aten.new_ones(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.silu( // CHECK-LABEL: func.func @torch.aten.silu(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor { // CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor {
// CHECK: %[[SIGMOID:.*]] = torch.aten.sigmoid %[[INP]] : !torch.vtensor<[?,?],f32> -> !torch.vtensor // CHECK: %[[SIGMOID:.*]] = torch.aten.sigmoid %[[INP]] : !torch.vtensor<[?,?],f32> -> !torch.vtensor
// CHECK: %[[MUL:.*]] = torch.aten.mul.Tensor %[[SIGMOID]], %[[INP]] : !torch.vtensor, !torch.vtensor<[?,?],f32> -> !torch.vtensor // CHECK: %[[MUL:.*]] = torch.aten.mul.Tensor %[[SIGMOID]], %[[INP]] : !torch.vtensor, !torch.vtensor<[?,?],f32> -> !torch.vtensor
// CHECK: return %[[MUL]] : !torch.vtensor // CHECK: return %[[MUL]] : !torch.vtensor
func @torch.aten.silu(%arg0: !torch.vtensor<[?,?],f32> loc(unknown)) -> !torch.vtensor { func.func @torch.aten.silu(%arg0: !torch.vtensor<[?,?],f32> loc(unknown)) -> !torch.vtensor {
%0 = torch.aten.silu %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor %0 = torch.aten.silu %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor
return %0 : !torch.vtensor return %0 : !torch.vtensor
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.full // CHECK-LABEL: func.func @torch.aten.full
// CHECK-SAME: () -> !torch.vtensor<[2,3],f32> { // CHECK-SAME: () -> !torch.vtensor<[2,3],f32> {
// CHECK: %[[FLOAT5:.*]] = torch.constant.float 5.000000e+00 // CHECK: %[[FLOAT5:.*]] = torch.constant.float 5.000000e+00
// CHECK: %[[INT3:.*]] = torch.constant.int 3 // CHECK: %[[INT3:.*]] = torch.constant.int 3
@ -690,7 +690,7 @@ func @torch.aten.silu(%arg0: !torch.vtensor<[?,?],f32> loc(unknown)) -> !torch.v
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[MEM_FORMAT]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32> // CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[MEM_FORMAT]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
// CHECK: %[[RES:.*]] = torch.valsem.aten.fill.Scalar %[[EMPTY]], %[[FLOAT5]] : !torch.vtensor<[2,3],f32>, !torch.float -> !torch.vtensor<[2,3],f32> // CHECK: %[[RES:.*]] = torch.valsem.aten.fill.Scalar %[[EMPTY]], %[[FLOAT5]] : !torch.vtensor<[2,3],f32>, !torch.float -> !torch.vtensor<[2,3],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[2,3],f32> // CHECK: return %[[RES]] : !torch.vtensor<[2,3],f32>
func @torch.aten.full() -> !torch.vtensor<[2,3],f32> { func.func @torch.aten.full() -> !torch.vtensor<[2,3],f32> {
%float5.000000e00 = torch.constant.float 5.000000e+00 %float5.000000e00 = torch.constant.float 5.000000e+00
%int3 = torch.constant.int 3 %int3 = torch.constant.int 3
%int2 = torch.constant.int 2 %int2 = torch.constant.int 2
@ -701,7 +701,7 @@ func @torch.aten.full() -> !torch.vtensor<[2,3],f32> {
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.full_like( // CHECK-LABEL: func.func @torch.aten.full_like(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[INT5:.*]] = torch.constant.int 5 // CHECK: %[[INT5:.*]] = torch.constant.int 5
// CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[NONE:.*]] = torch.constant.none
@ -713,7 +713,7 @@ func @torch.aten.full() -> !torch.vtensor<[2,3],f32> {
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?],f32> // CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?],f32>
// CHECK: %[[RES:.*]] = torch.valsem.aten.fill.Scalar %[[EMPTY]], %[[INT5]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32> // CHECK: %[[RES:.*]] = torch.valsem.aten.fill.Scalar %[[EMPTY]], %[[INT5]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[RES]] : !torch.vtensor<[?,?],f32>
func @torch.aten.full_like(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.full_like(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%int5 = torch.constant.int 5 %int5 = torch.constant.int 5
%none = torch.constant.none %none = torch.constant.none
%0 = torch.aten.full_like %arg0, %int5, %none, %none, %none, %none, %none : !torch.vtensor<[?,?],f32>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?],f32> %0 = torch.aten.full_like %arg0, %int5, %none, %none, %none, %none, %none : !torch.vtensor<[?,?],f32>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?],f32>
@ -721,7 +721,7 @@ func @torch.aten.full_like(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.index_put( // CHECK-LABEL: func.func @torch.aten.index_put(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?],f32>, %[[INDEX:.*]]: !torch.vtensor<[?],si64>, // CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?],f32>, %[[INDEX:.*]]: !torch.vtensor<[?],si64>,
// CHECK-SAME: %[[VALUES:.*]]: !torch.vtensor<[?],f32>, // CHECK-SAME: %[[VALUES:.*]]: !torch.vtensor<[?],f32>,
// CHECK-SAME: %[[ACCUM:.*]]: !torch.bool) -> !torch.vtensor<[?],f32> { // CHECK-SAME: %[[ACCUM:.*]]: !torch.bool) -> !torch.vtensor<[?],f32> {
@ -729,14 +729,14 @@ func @torch.aten.full_like(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
// CHECK: %[[FALSE:.*]] = torch.constant.bool false // CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[RES:.*]] = torch.valsem.aten.index_put_impl %[[INP]], %[[INDICES]], %[[VALUES]], %[[ACCUM]], %[[FALSE]] : !torch.vtensor<[?],f32>, !torch.list<vtensor>, !torch.vtensor<[?],f32>, !torch.bool, !torch.bool -> !torch.vtensor<[?],f32> // CHECK: %[[RES:.*]] = torch.valsem.aten.index_put_impl %[[INP]], %[[INDICES]], %[[VALUES]], %[[ACCUM]], %[[FALSE]] : !torch.vtensor<[?],f32>, !torch.list<vtensor>, !torch.vtensor<[?],f32>, !torch.bool, !torch.bool -> !torch.vtensor<[?],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[?],f32> // CHECK: return %[[RES]] : !torch.vtensor<[?],f32>
func @torch.aten.index_put(%input: !torch.vtensor<[?],f32>, %index: !torch.vtensor<[?],si64>, %values: !torch.vtensor<[?],f32>, %accumulate : !torch.bool) -> !torch.vtensor<[?],f32> { func.func @torch.aten.index_put(%input: !torch.vtensor<[?],f32>, %index: !torch.vtensor<[?],si64>, %values: !torch.vtensor<[?],f32>, %accumulate : !torch.bool) -> !torch.vtensor<[?],f32> {
%indices = torch.prim.ListConstruct %index : (!torch.vtensor<[?],si64>) -> !torch.list<vtensor> %indices = torch.prim.ListConstruct %index : (!torch.vtensor<[?],si64>) -> !torch.list<vtensor>
%0 = torch.aten.index_put %input, %indices, %values, %accumulate : !torch.vtensor<[?],f32>, !torch.list<vtensor>, !torch.vtensor<[?],f32>, !torch.bool -> !torch.vtensor<[?],f32> %0 = torch.aten.index_put %input, %indices, %values, %accumulate : !torch.vtensor<[?],f32>, !torch.list<vtensor>, !torch.vtensor<[?],f32>, !torch.bool -> !torch.vtensor<[?],f32>
return %0 : !torch.vtensor<[?],f32> return %0 : !torch.vtensor<[?],f32>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.expand_as( // CHECK-LABEL: func.func @torch.aten.expand_as(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,1,1],f32>, %[[OTHER:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { // CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,1,1],f32>, %[[OTHER:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[INT0:.*]] = torch.constant.int 0 // CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[OTHER]], %[[INT0]] : !torch.vtensor<[?,?,?],f32>, !torch.int -> !torch.int // CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[OTHER]], %[[INT0]] : !torch.vtensor<[?,?,?],f32>, !torch.int -> !torch.int
@ -747,13 +747,13 @@ func @torch.aten.index_put(%input: !torch.vtensor<[?],f32>, %index: !torch.vtens
// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]], %[[DIM2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]], %[[DIM2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[RES:.*]] = torch.aten.broadcast_to %[[INP]], %[[SIZE]] : !torch.vtensor<[?,1,1],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?],f32> // CHECK: %[[RES:.*]] = torch.aten.broadcast_to %[[INP]], %[[SIZE]] : !torch.vtensor<[?,1,1],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[?,?,?],f32> // CHECK: return %[[RES]] : !torch.vtensor<[?,?,?],f32>
func @torch.aten.expand_as(%arg0: !torch.vtensor<[?,1,1],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { func.func @torch.aten.expand_as(%arg0: !torch.vtensor<[?,1,1],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
%0 = torch.aten.expand_as %arg0, %arg1 : !torch.vtensor<[?,1,1],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32> %0 = torch.aten.expand_as %arg0, %arg1 : !torch.vtensor<[?,1,1],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?],f32> return %0 : !torch.vtensor<[?,?,?],f32>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten._to_copy( // CHECK-LABEL: func.func @torch.aten._to_copy(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { // CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[FALSE:.*]] = torch.constant.bool false // CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[NONE:.*]] = torch.constant.none
@ -767,7 +767,7 @@ func @torch.aten.expand_as(%arg0: !torch.vtensor<[?,1,1],f32>, %arg1: !torch.vte
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?,?],f32> // CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[RES:.*]] = torch.valsem.aten.copy %[[EMPTY]], %[[INP]], %[[FALSE]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32>, !torch.bool -> !torch.vtensor<[?,?,?],f32> // CHECK: %[[RES:.*]] = torch.valsem.aten.copy %[[EMPTY]], %[[INP]], %[[FALSE]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32>, !torch.bool -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[?,?,?],f32> // CHECK: return %[[RES]] : !torch.vtensor<[?,?,?],f32>
func @torch.aten._to_copy(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { func.func @torch.aten._to_copy(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
%false = torch.constant.bool false %false = torch.constant.bool false
%none = torch.constant.none %none = torch.constant.none
%0 = torch.aten._to_copy %arg0, %none, %none, %none, %none, %false, %none : !torch.vtensor<[?,?,?],f32>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32> %0 = torch.aten._to_copy %arg0, %none, %none, %none, %none, %false, %none : !torch.vtensor<[?,?,?],f32>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32>
@ -775,12 +775,12 @@ func @torch.aten._to_copy(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.dropout$eval( // CHECK-LABEL: func.func @torch.aten.dropout$eval(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[PROB:.*]] = torch.constant.float 1.000000e-01 // CHECK: %[[PROB:.*]] = torch.constant.float 1.000000e-01
// CHECK: %[[TRAIN:.*]] = torch.constant.bool false // CHECK: %[[TRAIN:.*]] = torch.constant.bool false
// CHECK: return %[[INP:.*]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[INP:.*]] : !torch.vtensor<[?,?],f32>
func @torch.aten.dropout$eval(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.dropout$eval(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%float1.000000e-01 = torch.constant.float 1.000000e-01 %float1.000000e-01 = torch.constant.float 1.000000e-01
%false = torch.constant.bool false %false = torch.constant.bool false
%0 = torch.aten.dropout %arg0, %float1.000000e-01, %false : !torch.vtensor<[?,?],f32>, !torch.float, !torch.bool -> !torch.vtensor<[?,?],f32> %0 = torch.aten.dropout %arg0, %float1.000000e-01, %false : !torch.vtensor<[?,?],f32>, !torch.float, !torch.bool -> !torch.vtensor<[?,?],f32>
@ -788,8 +788,8 @@ func @torch.aten.dropout$eval(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtenso
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.dropout$train( // CHECK-LABEL: func.func @torch.aten.dropout$train(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[PROB:.*]] = torch.constant.float 3.000000e-01 // CHECK: %[[PROB:.*]] = torch.constant.float 3.000000e-01
// CHECK: %[[TRAIN:.*]] = torch.constant.bool true // CHECK: %[[TRAIN:.*]] = torch.constant.bool true
// CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[NONE:.*]] = torch.constant.none
@ -821,7 +821,7 @@ func @torch.aten.dropout$eval(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtenso
// CHECK: %[[MASK_INP:.*]] = torch.aten.mul.Tensor %[[BOOL_MASK]], %[[INP]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[MASK_INP:.*]] = torch.aten.mul.Tensor %[[BOOL_MASK]], %[[INP]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: %[[OUT:.*]] = torch.aten.div.Scalar %[[MASK_INP]], %[[ONEMINUSP]] : !torch.vtensor<[?,?],f32>, !torch.float -> !torch.vtensor<[?,?],f32> // CHECK: %[[OUT:.*]] = torch.aten.div.Scalar %[[MASK_INP]], %[[ONEMINUSP]] : !torch.vtensor<[?,?],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[OUT]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[OUT]] : !torch.vtensor<[?,?],f32>
func @torch.aten.dropout$train(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.dropout$train(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%float3.000000e-01 = torch.constant.float 3.000000e-01 %float3.000000e-01 = torch.constant.float 3.000000e-01
%true = torch.constant.bool true %true = torch.constant.bool true
%0 = torch.aten.dropout %arg0, %float3.000000e-01, %true : !torch.vtensor<[?,?],f32>, !torch.float, !torch.bool -> !torch.vtensor<[?,?],f32> %0 = torch.aten.dropout %arg0, %float3.000000e-01, %true : !torch.vtensor<[?,?],f32>, !torch.float, !torch.bool -> !torch.vtensor<[?,?],f32>
@ -829,18 +829,18 @@ func @torch.aten.dropout$train(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtens
} }
// ----- // -----
// CHECK-LABEL: func @torch.valsem.aten.zero( // CHECK-LABEL: func.func @torch.valsem.aten.zero(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[ZERO:.*]] = torch.constant.int 0 // CHECK: %[[ZERO:.*]] = torch.constant.int 0
// CHECK: %[[OUT:.*]] = torch.valsem.aten.fill.Scalar %[[INP]], %[[ZERO]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32> // CHECK: %[[OUT:.*]] = torch.valsem.aten.fill.Scalar %[[INP]], %[[ZERO]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[OUT]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[OUT]] : !torch.vtensor<[?,?],f32>
func @torch.valsem.aten.zero(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.valsem.aten.zero(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.valsem.aten.zero %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> %0 = torch.valsem.aten.zero %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32> return %0 : !torch.vtensor<[?,?],f32>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.new_empty // CHECK-LABEL: func.func @torch.aten.new_empty
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[2,3],f32> { // CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[2,3],f32> {
// CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT2:.*]] = torch.constant.int 2 // CHECK: %[[INT2:.*]] = torch.constant.int 2
@ -850,7 +850,7 @@ func @torch.valsem.aten.zero(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor
// CHECK: %[[INT6:.*]] = torch.constant.int 6 // CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[RES:.*]] = torch.aten.empty.memory_format %[[SIZE]], %[[INT6]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE_0]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32> // CHECK: %[[RES:.*]] = torch.aten.empty.memory_format %[[SIZE]], %[[INT6]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE_0]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[2,3],f32> // CHECK: return %[[RES]] : !torch.vtensor<[2,3],f32>
func @torch.aten.new_empty(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[2,3],f32> { func.func @torch.aten.new_empty(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[2,3],f32> {
%none = torch.constant.none %none = torch.constant.none
%int2 = torch.constant.int 2 %int2 = torch.constant.int 2
%int3 = torch.constant.int 3 %int3 = torch.constant.int 3
@ -860,7 +860,7 @@ func @torch.aten.new_empty(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.where.Scalar( // CHECK-LABEL: func.func @torch.aten.where.Scalar(
// CHECK-SAME: %[[COND:.*]]: !torch.vtensor<[?,?,?],i1>) -> !torch.vtensor<[?,?,?],f32> { // CHECK-SAME: %[[COND:.*]]: !torch.vtensor<[?,?,?],i1>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[CST8:.*]] = torch.constant.float 8.000000e+00 // CHECK: %[[CST8:.*]] = torch.constant.float 8.000000e+00
// CHECK: %[[CST4:.*]] = torch.constant.float 4.000000e+00 // CHECK: %[[CST4:.*]] = torch.constant.float 4.000000e+00
@ -874,7 +874,7 @@ func @torch.aten.new_empty(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
// CHECK: %[[FILL_OTHER:.*]] = torch.valsem.aten.fill.Scalar %[[ALLOC2]], %[[CST8]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32> // CHECK: %[[FILL_OTHER:.*]] = torch.valsem.aten.fill.Scalar %[[ALLOC2]], %[[CST8]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK: %[[OUT:.*]] = torch.aten.where.self %[[COND]], %[[FILL_SELF]], %[[FILL_OTHER]] : !torch.vtensor<[?,?,?],i1>, !torch.vtensor<[],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[?,?,?],f32> // CHECK: %[[OUT:.*]] = torch.aten.where.self %[[COND]], %[[FILL_SELF]], %[[FILL_OTHER]] : !torch.vtensor<[?,?,?],i1>, !torch.vtensor<[],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[OUT]] : !torch.vtensor<[?,?,?],f32> // CHECK: return %[[OUT]] : !torch.vtensor<[?,?,?],f32>
func @torch.aten.where.Scalar(%arg0: !torch.vtensor<[?,?,?],i1>) -> !torch.vtensor<[?,?,?],f32> { func.func @torch.aten.where.Scalar(%arg0: !torch.vtensor<[?,?,?],i1>) -> !torch.vtensor<[?,?,?],f32> {
%cst8 = torch.constant.float 8.000000e+00 %cst8 = torch.constant.float 8.000000e+00
%cst4 = torch.constant.float 4.000000e+00 %cst4 = torch.constant.float 4.000000e+00
%0 = torch.aten.where.Scalar %arg0, %cst4, %cst8 : !torch.vtensor<[?,?,?],i1>, !torch.float, !torch.float -> !torch.vtensor<[?,?,?],f32> %0 = torch.aten.where.Scalar %arg0, %cst4, %cst8 : !torch.vtensor<[?,?,?],i1>, !torch.float, !torch.float -> !torch.vtensor<[?,?,?],f32>
@ -882,7 +882,7 @@ func @torch.aten.where.Scalar(%arg0: !torch.vtensor<[?,?,?],i1>) -> !torch.vtens
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.where.ScalarSelf( // CHECK-LABEL: func.func @torch.aten.where.ScalarSelf(
// CHECK-SAME: %[[COND:.*]]: !torch.vtensor<[?,?,?],i1>, %[[OTHER:.*]]: !torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?,?],f64> { // CHECK-SAME: %[[COND:.*]]: !torch.vtensor<[?,?,?],i1>, %[[OTHER:.*]]: !torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?,?],f64> {
// CHECK: %[[CST:.*]] = torch.constant.float 4.000000e+00 // CHECK: %[[CST:.*]] = torch.constant.float 4.000000e+00
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int> // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
@ -891,14 +891,14 @@ func @torch.aten.where.Scalar(%arg0: !torch.vtensor<[?,?,?],i1>) -> !torch.vtens
// CHECK: %[[FILL:.*]] = torch.valsem.aten.fill.Scalar %[[ALLOC]], %[[CST]] : !torch.vtensor<[],f64>, !torch.float -> !torch.vtensor<[],f64> // CHECK: %[[FILL:.*]] = torch.valsem.aten.fill.Scalar %[[ALLOC]], %[[CST]] : !torch.vtensor<[],f64>, !torch.float -> !torch.vtensor<[],f64>
// CHECK: %[[OUT:.*]] = torch.aten.where.self %[[COND]], %[[FILL]], %[[OTHER]] : !torch.vtensor<[?,?,?],i1>, !torch.vtensor<[],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?,?],f64> // CHECK: %[[OUT:.*]] = torch.aten.where.self %[[COND]], %[[FILL]], %[[OTHER]] : !torch.vtensor<[?,?,?],i1>, !torch.vtensor<[],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?,?],f64>
// CHECK: return %[[OUT]] : !torch.vtensor<[?,?,?],f64> // CHECK: return %[[OUT]] : !torch.vtensor<[?,?,?],f64>
func @torch.aten.where.ScalarSelf(%arg0: !torch.vtensor<[?,?,?],i1>, %arg1: !torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?,?],f64> { func.func @torch.aten.where.ScalarSelf(%arg0: !torch.vtensor<[?,?,?],i1>, %arg1: !torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?,?],f64> {
%cst = torch.constant.float 4.000000e+00 %cst = torch.constant.float 4.000000e+00
%0 = torch.aten.where.ScalarSelf %arg0, %cst, %arg1 : !torch.vtensor<[?,?,?],i1>, !torch.float, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?,?],f64> %0 = torch.aten.where.ScalarSelf %arg0, %cst, %arg1 : !torch.vtensor<[?,?,?],i1>, !torch.float, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?,?],f64>
return %0 : !torch.vtensor<[?,?,?],f64> return %0 : !torch.vtensor<[?,?,?],f64>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.where.ScalarOther( // CHECK-LABEL: func.func @torch.aten.where.ScalarOther(
// CHECK-SAME: %[[COND:.*]]: !torch.vtensor<[?,?,?],i1>, %[[SELF:.*]]: !torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?,?],f64> { // CHECK-SAME: %[[COND:.*]]: !torch.vtensor<[?,?,?],i1>, %[[SELF:.*]]: !torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?,?],f64> {
// CHECK: %[[CST:.*]] = torch.constant.float 4.000000e+00 // CHECK: %[[CST:.*]] = torch.constant.float 4.000000e+00
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int> // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
@ -907,21 +907,21 @@ func @torch.aten.where.ScalarSelf(%arg0: !torch.vtensor<[?,?,?],i1>, %arg1: !tor
// CHECK: %[[FILL:.*]] = torch.valsem.aten.fill.Scalar %[[ALLOC]], %[[CST]] : !torch.vtensor<[],f64>, !torch.float -> !torch.vtensor<[],f64> // CHECK: %[[FILL:.*]] = torch.valsem.aten.fill.Scalar %[[ALLOC]], %[[CST]] : !torch.vtensor<[],f64>, !torch.float -> !torch.vtensor<[],f64>
// CHECK: %[[OUT:.*]] = torch.aten.where.self %[[COND]], %[[SELF]], %[[FILL]] : !torch.vtensor<[?,?,?],i1>, !torch.vtensor<[?,?],f64>, !torch.vtensor<[],f64> -> !torch.vtensor<[?,?,?],f64> // CHECK: %[[OUT:.*]] = torch.aten.where.self %[[COND]], %[[SELF]], %[[FILL]] : !torch.vtensor<[?,?,?],i1>, !torch.vtensor<[?,?],f64>, !torch.vtensor<[],f64> -> !torch.vtensor<[?,?,?],f64>
// CHECK: return %[[OUT]] : !torch.vtensor<[?,?,?],f64> // CHECK: return %[[OUT]] : !torch.vtensor<[?,?,?],f64>
func @torch.aten.where.ScalarOther(%arg0: !torch.vtensor<[?,?,?],i1>, %arg1: !torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?,?],f64> { func.func @torch.aten.where.ScalarOther(%arg0: !torch.vtensor<[?,?,?],i1>, %arg1: !torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?,?],f64> {
%cst = torch.constant.float 4.000000e+00 %cst = torch.constant.float 4.000000e+00
%0 = torch.aten.where.ScalarOther %arg0, %arg1, %cst : !torch.vtensor<[?,?,?],i1>, !torch.vtensor<[?,?],f64>, !torch.float -> !torch.vtensor<[?,?,?],f64> %0 = torch.aten.where.ScalarOther %arg0, %arg1, %cst : !torch.vtensor<[?,?,?],i1>, !torch.vtensor<[?,?],f64>, !torch.float -> !torch.vtensor<[?,?,?],f64>
return %0 : !torch.vtensor<[?,?,?],f64> return %0 : !torch.vtensor<[?,?,?],f64>
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.pad // CHECK-LABEL: func.func @torch.aten.pad
// CHECK-SAME: (%[[SELF:.*]]: !torch.vtensor<[?,?,?],f64>, %[[VALUE:.*]]: !torch.float) -> !torch.vtensor<[?,?,?],f64> { // CHECK-SAME: (%[[SELF:.*]]: !torch.vtensor<[?,?,?],f64>, %[[VALUE:.*]]: !torch.float) -> !torch.vtensor<[?,?,?],f64> {
// CHECK-NOT: torch.aten.pad // CHECK-NOT: torch.aten.pad
// CHECK: %[[STRING:.*]] = torch.constant.str "constant" // CHECK: %[[STRING:.*]] = torch.constant.str "constant"
// CHECK-NEXT: %[[LIST:.*]] = torch.prim.ListConstruct // CHECK-NEXT: %[[LIST:.*]] = torch.prim.ListConstruct
// CHECK-NEXT: %[[PAD_ND:.*]] = torch.aten.constant_pad_nd %[[SELF]], %[[LIST]], %[[VALUE]] // CHECK-NEXT: %[[PAD_ND:.*]] = torch.aten.constant_pad_nd %[[SELF]], %[[LIST]], %[[VALUE]]
// CHECK-NEXT: return %[[PAD_ND]] // CHECK-NEXT: return %[[PAD_ND]]
func @torch.aten.pad(%arg0: !torch.vtensor<[?,?,?],f64>, %arg1: !torch.float) -> !torch.vtensor<[?,?,?],f64> { func.func @torch.aten.pad(%arg0: !torch.vtensor<[?,?,?],f64>, %arg1: !torch.float) -> !torch.vtensor<[?,?,?],f64> {
%int0 = torch.constant.int 0 %int0 = torch.constant.int 0
%int1 = torch.constant.int 1 %int1 = torch.constant.int 1
%int2 = torch.constant.int 2 %int2 = torch.constant.int 2
@ -933,7 +933,7 @@ func @torch.aten.pad(%arg0: !torch.vtensor<[?,?,?],f64>, %arg1: !torch.float) ->
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.to.dtype_layout( // CHECK-LABEL: func.func @torch.aten.to.dtype_layout(
// CHECK-SAME: %[[SELF:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f64> { // CHECK-SAME: %[[SELF:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f64> {
// CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false // CHECK: %[[FALSE:.*]] = torch.constant.bool false
@ -941,7 +941,7 @@ func @torch.aten.pad(%arg0: !torch.vtensor<[?,?,?],f64>, %arg1: !torch.float) ->
// CHECK: %[[CST7:.*]] = torch.constant.int 7 // CHECK: %[[CST7:.*]] = torch.constant.int 7
// CHECK: %[[OUT:.*]] = torch.aten.to.dtype %[[SELF]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f64> // CHECK: %[[OUT:.*]] = torch.aten.to.dtype %[[SELF]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f64>
// CHECK: return %[[OUT]] : !torch.vtensor<[?,?],f64> // CHECK: return %[[OUT]] : !torch.vtensor<[?,?],f64>
func @torch.aten.to.dtype_layout(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f64> { func.func @torch.aten.to.dtype_layout(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f64> {
%none = torch.constant.none %none = torch.constant.none
%false = torch.constant.bool false %false = torch.constant.bool false
%int0 = torch.constant.int 0 %int0 = torch.constant.int 0

View File

@ -1,11 +1,11 @@
// RUN: torch-mlir-opt -torch-drop-shape-calculations -split-input-file %s | FileCheck %s // RUN: torch-mlir-opt -torch-drop-shape-calculations -split-input-file %s | FileCheck %s
// CHECK-LABEL: func @basic( // CHECK-LABEL: func.func @basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,?],unk>) -> !torch.vtensor { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,?],unk>) -> !torch.vtensor {
// CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG]] : !torch.vtensor<[2,?],unk> -> !torch.vtensor<[2,?],unk> // CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG]] : !torch.vtensor<[2,?],unk> -> !torch.vtensor<[2,?],unk>
// CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[TANH]] : !torch.vtensor<[2,?],unk> to !torch.vtensor // CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[TANH]] : !torch.vtensor<[2,?],unk> to !torch.vtensor
// CHECK: return %[[ERASED]] : !torch.vtensor // CHECK: return %[[ERASED]] : !torch.vtensor
func @basic(%arg0: !torch.vtensor<[2,?],unk>) -> !torch.vtensor { func.func @basic(%arg0: !torch.vtensor<[2,?],unk>) -> !torch.vtensor {
%int2 = torch.constant.int 2 %int2 = torch.constant.int 2
%int1 = torch.constant.int 1 %int1 = torch.constant.int 1
%0 = torch.shape.calculate { %0 = torch.shape.calculate {

View File

@ -16,8 +16,8 @@ torch.global_slot "private" @mutated : !torch.tensor {
torch.global_slot.init %0 : !torch.tensor torch.global_slot.init %0 : !torch.tensor
} }
// CHECK-LABEL: func @forward() -> (!torch.tensor, !torch.tensor, !torch.tensor) { // CHECK-LABEL: func.func @forward() -> (!torch.tensor, !torch.tensor, !torch.tensor) {
func @forward() -> (!torch.tensor, !torch.tensor, !torch.tensor) { func.func @forward() -> (!torch.tensor, !torch.tensor, !torch.tensor) {
// Inlined. // Inlined.
// CHECK: %[[READONLY:.*]] = torch.tensor.literal(dense<0.000000e+00> : tensor<1xf32>) : !torch.tensor // CHECK: %[[READONLY:.*]] = torch.tensor.literal(dense<0.000000e+00> : tensor<1xf32>) : !torch.tensor
%0 = torch.global_slot.get @readonly : !torch.tensor %0 = torch.global_slot.get @readonly : !torch.tensor

View File

@ -1,20 +1,20 @@
// RUN: torch-mlir-opt -split-input-file -allow-unregistered-dialect %s -torch-maximize-value-semantics | FileCheck %s // RUN: torch-mlir-opt -split-input-file -allow-unregistered-dialect %s -torch-maximize-value-semantics | FileCheck %s
// CHECK-LABEL: func @torch.copy.tensor$basic( // CHECK-LABEL: func.func @torch.copy.tensor$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
// CHECK: return %[[ARG0]], %[[ARG0]] : !torch.vtensor, !torch.vtensor // CHECK: return %[[ARG0]], %[[ARG0]] : !torch.vtensor, !torch.vtensor
func @torch.copy.tensor$basic(%arg0: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) { func.func @torch.copy.tensor$basic(%arg0: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
%0 = torch.copy.to_tensor %arg0 : !torch.tensor %0 = torch.copy.to_tensor %arg0 : !torch.tensor
%1 = torch.copy.to_vtensor %0 : !torch.vtensor %1 = torch.copy.to_vtensor %0 : !torch.vtensor
%2 = torch.copy.to_vtensor %0 : !torch.vtensor %2 = torch.copy.to_vtensor %0 : !torch.vtensor
return %1, %2 : !torch.vtensor, !torch.vtensor return %1, %2 : !torch.vtensor, !torch.vtensor
} }
// CHECK-LABEL: func @one_mutation_in_a_block( // CHECK-LABEL: func.func @one_mutation_in_a_block(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) { // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
// CHECK: return %[[ARG0]], %[[ARG1]] : !torch.vtensor, !torch.vtensor // CHECK: return %[[ARG0]], %[[ARG1]] : !torch.vtensor, !torch.vtensor
func @one_mutation_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) { func.func @one_mutation_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
%0 = torch.copy.to_tensor %arg0 : !torch.tensor %0 = torch.copy.to_tensor %arg0 : !torch.tensor
%equal_to_arg0 = torch.copy.to_vtensor %0 : !torch.vtensor %equal_to_arg0 = torch.copy.to_vtensor %0 : !torch.vtensor
torch.overwrite.tensor.contents %arg1 overwrites %0 : !torch.vtensor, !torch.tensor torch.overwrite.tensor.contents %arg1 overwrites %0 : !torch.vtensor, !torch.tensor
@ -22,11 +22,11 @@ func @one_mutation_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> (
return %equal_to_arg0, %equal_to_arg1 : !torch.vtensor, !torch.vtensor return %equal_to_arg0, %equal_to_arg1 : !torch.vtensor, !torch.vtensor
} }
// CHECK-LABEL: func @multiple_mutations_in_a_block( // CHECK-LABEL: func.func @multiple_mutations_in_a_block(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor, %[[ARG1:.*]]: !torch.vtensor, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor, %[[ARG1:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor) { // CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor) {
// CHECK: return %[[ARG0]], %[[ARG1]], %[[ARG1]], %[[ARG2]] : !torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor // CHECK: return %[[ARG0]], %[[ARG1]], %[[ARG1]], %[[ARG2]] : !torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor
func @multiple_mutations_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor, %arg2: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor) { func.func @multiple_mutations_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor, %arg2: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor) {
// The mutable tensor we are overwriting. // The mutable tensor we are overwriting.
%tensor = torch.copy.to_tensor %arg0 : !torch.tensor %tensor = torch.copy.to_tensor %arg0 : !torch.tensor
@ -45,12 +45,12 @@ func @multiple_mutations_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor
return %equal_to_arg0, %equal_to_arg1, %equal_to_arg1_again, %equal_to_arg2 : !torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor return %equal_to_arg0, %equal_to_arg1, %equal_to_arg1_again, %equal_to_arg2 : !torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor
} }
// CHECK-LABEL: func @mutation_followed_by_view_like_ops( // CHECK-LABEL: func.func @mutation_followed_by_view_like_ops(
// CHECK-SAME: %[[VALUE_T:.*]]: !torch.vtensor, %[[OVERWRITER:.*]]: !torch.vtensor, %[[INT_LIST:.*]]: !torch.list<int>) -> !torch.vtensor { // CHECK-SAME: %[[VALUE_T:.*]]: !torch.vtensor, %[[OVERWRITER:.*]]: !torch.vtensor, %[[INT_LIST:.*]]: !torch.list<int>) -> !torch.vtensor {
// CHECK: %[[VIEW:.*]] = torch.aten.view %[[OVERWRITER]], %[[INT_LIST]] : !torch.vtensor, !torch.list<int> -> !torch.vtensor // CHECK: %[[VIEW:.*]] = torch.aten.view %[[OVERWRITER]], %[[INT_LIST]] : !torch.vtensor, !torch.list<int> -> !torch.vtensor
// CHECK: %[[RESULT:.*]] = torch.aten.permute %[[VIEW]], %[[INT_LIST]] : !torch.vtensor, !torch.list<int> -> !torch.vtensor // CHECK: %[[RESULT:.*]] = torch.aten.permute %[[VIEW]], %[[INT_LIST]] : !torch.vtensor, !torch.list<int> -> !torch.vtensor
// CHECK: return %[[RESULT]] : !torch.vtensor // CHECK: return %[[RESULT]] : !torch.vtensor
func @mutation_followed_by_view_like_ops(%value_t: !torch.vtensor, %overwriter: !torch.vtensor, %int_list: !torch.list<int>) -> !torch.vtensor { func.func @mutation_followed_by_view_like_ops(%value_t: !torch.vtensor, %overwriter: !torch.vtensor, %int_list: !torch.list<int>) -> !torch.vtensor {
%t = torch.copy.to_tensor %value_t : !torch.tensor %t = torch.copy.to_tensor %value_t : !torch.tensor
torch.overwrite.tensor.contents %overwriter overwrites %t : !torch.vtensor, !torch.tensor torch.overwrite.tensor.contents %overwriter overwrites %t : !torch.vtensor, !torch.tensor
%view = torch.aten.view %t, %int_list : !torch.tensor, !torch.list<int> -> !torch.tensor %view = torch.aten.view %t, %int_list : !torch.tensor, !torch.list<int> -> !torch.tensor
@ -59,10 +59,10 @@ func @mutation_followed_by_view_like_ops(%value_t: !torch.vtensor, %overwriter:
return %value_result : !torch.vtensor return %value_result : !torch.vtensor
} }
// CHECK-LABEL: func @mutation_of_view_like_op_result( // CHECK-LABEL: func.func @mutation_of_view_like_op_result(
// CHECK-SAME: %[[VALUE_T:.*]]: !torch.vtensor, %[[OVERWRITER:.*]]: !torch.vtensor, %[[INT_LIST:.*]]: !torch.list<int>) -> !torch.vtensor { // CHECK-SAME: %[[VALUE_T:.*]]: !torch.vtensor, %[[OVERWRITER:.*]]: !torch.vtensor, %[[INT_LIST:.*]]: !torch.list<int>) -> !torch.vtensor {
// CHECK: return %[[OVERWRITER]] : !torch.vtensor // CHECK: return %[[OVERWRITER]] : !torch.vtensor
func @mutation_of_view_like_op_result(%value_t: !torch.vtensor, %overwriter: !torch.vtensor, %int_list: !torch.list<int>) -> !torch.vtensor { func.func @mutation_of_view_like_op_result(%value_t: !torch.vtensor, %overwriter: !torch.vtensor, %int_list: !torch.list<int>) -> !torch.vtensor {
%t = torch.copy.to_tensor %value_t : !torch.tensor %t = torch.copy.to_tensor %value_t : !torch.tensor
%view = torch.aten.view %t, %int_list : !torch.tensor, !torch.list<int> -> !torch.tensor %view = torch.aten.view %t, %int_list : !torch.tensor, !torch.list<int> -> !torch.tensor
torch.overwrite.tensor.contents %overwriter overwrites %view : !torch.vtensor, !torch.tensor torch.overwrite.tensor.contents %overwriter overwrites %view : !torch.vtensor, !torch.tensor
@ -70,20 +70,20 @@ func @mutation_of_view_like_op_result(%value_t: !torch.vtensor, %overwriter: !to
return %result : !torch.vtensor return %result : !torch.vtensor
} }
// CHECK-LABEL: func @value_tensor_used_after_copy_was_mutated( // CHECK-LABEL: func.func @value_tensor_used_after_copy_was_mutated(
// CHECK-SAME: %[[VALUE_T:.*]]: !torch.vtensor, // CHECK-SAME: %[[VALUE_T:.*]]: !torch.vtensor,
// CHECK-SAME: %[[OVERWRITER:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) { // CHECK-SAME: %[[OVERWRITER:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
// CHECK: return %[[VALUE_T]], %[[OVERWRITER]] : !torch.vtensor, !torch.vtensor // CHECK: return %[[VALUE_T]], %[[OVERWRITER]] : !torch.vtensor, !torch.vtensor
func @value_tensor_used_after_copy_was_mutated(%value_t: !torch.vtensor, %overwriter: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) { func.func @value_tensor_used_after_copy_was_mutated(%value_t: !torch.vtensor, %overwriter: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
%t = torch.copy.to_tensor %value_t : !torch.tensor %t = torch.copy.to_tensor %value_t : !torch.tensor
torch.overwrite.tensor.contents %overwriter overwrites %t : !torch.vtensor, !torch.tensor torch.overwrite.tensor.contents %overwriter overwrites %t : !torch.vtensor, !torch.tensor
%value_mutated_t = torch.copy.to_vtensor %t : !torch.vtensor %value_mutated_t = torch.copy.to_vtensor %t : !torch.vtensor
return %value_t, %value_mutated_t : !torch.vtensor, !torch.vtensor return %value_t, %value_mutated_t : !torch.vtensor, !torch.vtensor
} }
// CHECK-LABEL: func @unmodeled_mutation( // CHECK-LABEL: func.func @unmodeled_mutation(
// CHECK: torch.overwrite.tensor.contents // CHECK: torch.overwrite.tensor.contents
func @unmodeled_mutation(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor { func.func @unmodeled_mutation(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
%0 = torch.copy.to_tensor %arg0 : !torch.tensor %0 = torch.copy.to_tensor %arg0 : !torch.tensor
torch.overwrite.tensor.contents %arg1 overwrites %0 : !torch.vtensor, !torch.tensor torch.overwrite.tensor.contents %arg1 overwrites %0 : !torch.vtensor, !torch.tensor
"some.op"(%0) : (!torch.tensor) -> () "some.op"(%0) : (!torch.tensor) -> ()
@ -92,9 +92,9 @@ func @unmodeled_mutation(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch
} }
// We don't yet handle nontrivial cases involving control flow. // We don't yet handle nontrivial cases involving control flow.
// CHECK-LABEL: func @unimplemented_control_flow( // CHECK-LABEL: func.func @unimplemented_control_flow(
// CHECK: torch.copy.to_vtensor // CHECK: torch.copy.to_vtensor
func @unimplemented_control_flow(%arg0: !torch.vtensor, %arg1: !torch.vtensor, %cond: !torch.bool) -> (!torch.vtensor, !torch.vtensor) { func.func @unimplemented_control_flow(%arg0: !torch.vtensor, %arg1: !torch.vtensor, %cond: !torch.bool) -> (!torch.vtensor, !torch.vtensor) {
%tensor = torch.copy.to_tensor %arg0 : !torch.tensor %tensor = torch.copy.to_tensor %arg0 : !torch.tensor
%equal_to_arg0 = torch.copy.to_vtensor %tensor : !torch.vtensor %equal_to_arg0 = torch.copy.to_vtensor %tensor : !torch.vtensor
torch.prim.If %cond -> () { torch.prim.If %cond -> () {
@ -107,33 +107,33 @@ func @unimplemented_control_flow(%arg0: !torch.vtensor, %arg1: !torch.vtensor, %
return %equal_to_arg0, %equal_to_arg1 : !torch.vtensor, !torch.vtensor return %equal_to_arg0, %equal_to_arg1 : !torch.vtensor, !torch.vtensor
} }
// CHECK-LABEL: func @non_value_tensor_returned( // CHECK-LABEL: func.func @non_value_tensor_returned(
// CHECK-SAME: %[[VALUE_T:.*]]: !torch.vtensor) -> !torch.tensor { // CHECK-SAME: %[[VALUE_T:.*]]: !torch.vtensor) -> !torch.tensor {
// CHECK: %[[T:.*]] = torch.copy.to_tensor %[[VALUE_T]] : !torch.tensor // CHECK: %[[T:.*]] = torch.copy.to_tensor %[[VALUE_T]] : !torch.tensor
// CHECK: return %[[T]] : !torch.tensor // CHECK: return %[[T]] : !torch.tensor
func @non_value_tensor_returned(%value_t: !torch.vtensor) -> !torch.tensor { func.func @non_value_tensor_returned(%value_t: !torch.vtensor) -> !torch.tensor {
%t = torch.copy.to_tensor %value_t : !torch.tensor %t = torch.copy.to_tensor %value_t : !torch.tensor
return %t : !torch.tensor return %t : !torch.tensor
} }
// CHECK-LABEL: func @non_value_tensor_returned$with_overwrite( // CHECK-LABEL: func.func @non_value_tensor_returned$with_overwrite(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %{{.*}}: !torch.vtensor) -> !torch.tensor { // CHECK-SAME: %{{.*}}: !torch.vtensor) -> !torch.tensor {
// CHECK: %[[RESULT:.*]] = torch.copy.to_tensor %[[ARG0]] : !torch.tensor // CHECK: %[[RESULT:.*]] = torch.copy.to_tensor %[[ARG0]] : !torch.tensor
// CHECK: return %[[RESULT]] : !torch.tensor // CHECK: return %[[RESULT]] : !torch.tensor
func @non_value_tensor_returned$with_overwrite(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.tensor { func.func @non_value_tensor_returned$with_overwrite(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.tensor {
%2 = torch.copy.to_tensor %arg1 : !torch.tensor %2 = torch.copy.to_tensor %arg1 : !torch.tensor
torch.overwrite.tensor.contents %arg0 overwrites %2 : !torch.vtensor, !torch.tensor torch.overwrite.tensor.contents %arg0 overwrites %2 : !torch.vtensor, !torch.tensor
return %2 : !torch.tensor return %2 : !torch.tensor
} }
// CHECK-LABEL: func @non_value_tensor_returned$return_from_multiple_slices( // CHECK-LABEL: func.func @non_value_tensor_returned$return_from_multiple_slices(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> (!torch.tensor, !torch.vtensor, !torch.tensor) { // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> (!torch.tensor, !torch.vtensor, !torch.tensor) {
// CHECK: %[[NON_VALUE_TENSOR0:.*]] = torch.copy.to_tensor %[[ARG0]] : !torch.tensor // CHECK: %[[NON_VALUE_TENSOR0:.*]] = torch.copy.to_tensor %[[ARG0]] : !torch.tensor
// CHECK: %[[NON_VALUE_TENSOR1:.*]] = torch.copy.to_tensor %[[ARG1]] : !torch.tensor // CHECK: %[[NON_VALUE_TENSOR1:.*]] = torch.copy.to_tensor %[[ARG1]] : !torch.tensor
// CHECK: return %[[NON_VALUE_TENSOR0]], %[[ARG0]], %[[NON_VALUE_TENSOR1]] : !torch.tensor, !torch.vtensor, !torch.tensor // CHECK: return %[[NON_VALUE_TENSOR0]], %[[ARG0]], %[[NON_VALUE_TENSOR1]] : !torch.tensor, !torch.vtensor, !torch.tensor
func @non_value_tensor_returned$return_from_multiple_slices(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> (!torch.tensor, !torch.vtensor, !torch.tensor) { func.func @non_value_tensor_returned$return_from_multiple_slices(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> (!torch.tensor, !torch.vtensor, !torch.tensor) {
%0 = torch.copy.to_tensor %arg0 : !torch.tensor %0 = torch.copy.to_tensor %arg0 : !torch.tensor
// Make a vtensor copy and return that, just to have a load-bearing use. // Make a vtensor copy and return that, just to have a load-bearing use.
// This test mainly checks the rewriting of the non-value tensor returns // This test mainly checks the rewriting of the non-value tensor returns
@ -143,12 +143,12 @@ func @non_value_tensor_returned$return_from_multiple_slices(%arg0: !torch.vtenso
return %0, %1, %2 : !torch.tensor, !torch.vtensor, !torch.tensor return %0, %1, %2 : !torch.tensor, !torch.vtensor, !torch.tensor
} }
// CHECK-LABEL: func @viewlike$basic_unsqueeze( // CHECK-LABEL: func.func @viewlike$basic_unsqueeze(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: %[[INT0:.*]] = torch.constant.int 0 // CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[UNSQUEEZE:.*]] = torch.aten.unsqueeze %[[ARG]], %[[INT0]] : !torch.vtensor, !torch.int -> !torch.vtensor // CHECK: %[[UNSQUEEZE:.*]] = torch.aten.unsqueeze %[[ARG]], %[[INT0]] : !torch.vtensor, !torch.int -> !torch.vtensor
// CHECK: return %[[UNSQUEEZE]] : !torch.vtensor // CHECK: return %[[UNSQUEEZE]] : !torch.vtensor
func @viewlike$basic_unsqueeze(%arg0: !torch.vtensor) -> !torch.vtensor { func.func @viewlike$basic_unsqueeze(%arg0: !torch.vtensor) -> !torch.vtensor {
%int0 = torch.constant.int 0 %int0 = torch.constant.int 0
%0 = torch.copy.to_tensor %arg0 : !torch.tensor %0 = torch.copy.to_tensor %arg0 : !torch.tensor
%1 = torch.aten.unsqueeze %0, %int0 : !torch.tensor, !torch.int -> !torch.tensor %1 = torch.aten.unsqueeze %0, %int0 : !torch.tensor, !torch.int -> !torch.tensor
@ -156,13 +156,13 @@ func @viewlike$basic_unsqueeze(%arg0: !torch.vtensor) -> !torch.vtensor {
return %2 : !torch.vtensor return %2 : !torch.vtensor
} }
// CHECK-LABEL: func @viewlike$basic_flatten( // CHECK-LABEL: func.func @viewlike$basic_flatten(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: %[[INT0:.*]] = torch.constant.int 0 // CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INTM1:.*]] = torch.constant.int -1 // CHECK: %[[INTM1:.*]] = torch.constant.int -1
// CHECK: %[[FLATTEN:.*]] = torch.aten.flatten.using_ints %[[ARG]], %[[INT0]], %[[INTM1]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor // CHECK: %[[FLATTEN:.*]] = torch.aten.flatten.using_ints %[[ARG]], %[[INT0]], %[[INTM1]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor
// CHECK: return %[[FLATTEN]] : !torch.vtensor // CHECK: return %[[FLATTEN]] : !torch.vtensor
func @viewlike$basic_flatten(%arg0: !torch.vtensor) -> !torch.vtensor { func.func @viewlike$basic_flatten(%arg0: !torch.vtensor) -> !torch.vtensor {
%start = torch.constant.int 0 %start = torch.constant.int 0
%end = torch.constant.int -1 %end = torch.constant.int -1
%0 = torch.copy.to_tensor %arg0 : !torch.tensor %0 = torch.copy.to_tensor %arg0 : !torch.tensor
@ -171,13 +171,13 @@ func @viewlike$basic_flatten(%arg0: !torch.vtensor) -> !torch.vtensor {
return %2 : !torch.vtensor return %2 : !torch.vtensor
} }
// CHECK-LABEL: func @viewlike$transitive( // CHECK-LABEL: func.func @viewlike$transitive(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: %[[INT0:.*]] = torch.constant.int 0 // CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[UNSQUEEZE0:.*]] = torch.aten.unsqueeze %[[ARG]], %[[INT0]] : !torch.vtensor, !torch.int -> !torch.vtensor // CHECK: %[[UNSQUEEZE0:.*]] = torch.aten.unsqueeze %[[ARG]], %[[INT0]] : !torch.vtensor, !torch.int -> !torch.vtensor
// CHECK: %[[UNSQUEEZE1:.*]] = torch.aten.unsqueeze %[[UNSQUEEZE0]], %[[INT0]] : !torch.vtensor, !torch.int -> !torch.vtensor // CHECK: %[[UNSQUEEZE1:.*]] = torch.aten.unsqueeze %[[UNSQUEEZE0]], %[[INT0]] : !torch.vtensor, !torch.int -> !torch.vtensor
// CHECK: return %[[UNSQUEEZE1]] : !torch.vtensor // CHECK: return %[[UNSQUEEZE1]] : !torch.vtensor
func @viewlike$transitive(%arg0: !torch.vtensor) -> !torch.vtensor { func.func @viewlike$transitive(%arg0: !torch.vtensor) -> !torch.vtensor {
%int0 = torch.constant.int 0 %int0 = torch.constant.int 0
%0 = torch.copy.to_tensor %arg0 : !torch.tensor %0 = torch.copy.to_tensor %arg0 : !torch.tensor
%1 = torch.aten.unsqueeze %0, %int0 : !torch.tensor, !torch.int -> !torch.tensor %1 = torch.aten.unsqueeze %0, %int0 : !torch.tensor, !torch.int -> !torch.tensor
@ -186,14 +186,14 @@ func @viewlike$transitive(%arg0: !torch.vtensor) -> !torch.vtensor {
return %3 : !torch.vtensor return %3 : !torch.vtensor
} }
// CHECK-LABEL: func @viewlike$transitive_tree( // CHECK-LABEL: func.func @viewlike$transitive_tree(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
// CHECK: %[[INT0:.*]] = torch.constant.int 0 // CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[UNSQUEEZE0:.*]] = torch.aten.unsqueeze %[[ARG]], %[[INT0]] : !torch.vtensor, !torch.int -> !torch.vtensor // CHECK: %[[UNSQUEEZE0:.*]] = torch.aten.unsqueeze %[[ARG]], %[[INT0]] : !torch.vtensor, !torch.int -> !torch.vtensor
// CHECK: %[[RET0:.*]] = torch.aten.unsqueeze %[[UNSQUEEZE0]], %[[INT0]] : !torch.vtensor, !torch.int -> !torch.vtensor // CHECK: %[[RET0:.*]] = torch.aten.unsqueeze %[[UNSQUEEZE0]], %[[INT0]] : !torch.vtensor, !torch.int -> !torch.vtensor
// CHECK: %[[RET1:.*]] = torch.aten.unsqueeze %[[UNSQUEEZE0]], %[[INT0]] : !torch.vtensor, !torch.int -> !torch.vtensor // CHECK: %[[RET1:.*]] = torch.aten.unsqueeze %[[UNSQUEEZE0]], %[[INT0]] : !torch.vtensor, !torch.int -> !torch.vtensor
// CHECK: return %[[RET0]], %[[RET1]] : !torch.vtensor, !torch.vtensor // CHECK: return %[[RET0]], %[[RET1]] : !torch.vtensor, !torch.vtensor
func @viewlike$transitive_tree(%arg0: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) { func.func @viewlike$transitive_tree(%arg0: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
%int0 = torch.constant.int 0 %int0 = torch.constant.int 0
%0 = torch.copy.to_tensor %arg0 : !torch.tensor %0 = torch.copy.to_tensor %arg0 : !torch.tensor
// %1 has two users. // %1 has two users.
@ -208,11 +208,11 @@ func @viewlike$transitive_tree(%arg0: !torch.vtensor) -> (!torch.vtensor, !torch
return %3, %5 : !torch.vtensor, !torch.vtensor return %3, %5 : !torch.vtensor, !torch.vtensor
} }
// CHECK-LABEL: func @viewlike$unmodeled_op( // CHECK-LABEL: func.func @viewlike$unmodeled_op(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: %[[UNSQUEEZE:.*]] = torch.aten.unsqueeze {{.*}} : !torch.tensor, !torch.int -> !torch.tensor // CHECK: %[[UNSQUEEZE:.*]] = torch.aten.unsqueeze {{.*}} : !torch.tensor, !torch.int -> !torch.tensor
// CHECK: "some.op"(%[[UNSQUEEZE]]) : (!torch.tensor) -> () // CHECK: "some.op"(%[[UNSQUEEZE]]) : (!torch.tensor) -> ()
func @viewlike$unmodeled_op(%arg0: !torch.vtensor) -> !torch.vtensor { func.func @viewlike$unmodeled_op(%arg0: !torch.vtensor) -> !torch.vtensor {
%int0 = torch.constant.int 0 %int0 = torch.constant.int 0
%0 = torch.copy.to_tensor %arg0 : !torch.tensor %0 = torch.copy.to_tensor %arg0 : !torch.tensor
%1 = torch.aten.unsqueeze %0, %int0 : !torch.tensor, !torch.int -> !torch.tensor %1 = torch.aten.unsqueeze %0, %int0 : !torch.tensor, !torch.int -> !torch.tensor
@ -221,23 +221,23 @@ func @viewlike$unmodeled_op(%arg0: !torch.vtensor) -> !torch.vtensor {
return %2 : !torch.vtensor return %2 : !torch.vtensor
} }
// CHECK-LABEL: func @viewlike$two_inputs_one_copy( // CHECK-LABEL: func.func @viewlike$two_inputs_one_copy(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: %[[EXPAND_AS:.*]] = torch.aten.expand_as %[[ARG]], %[[ARG]] : !torch.vtensor, !torch.vtensor -> !torch.vtensor // CHECK: %[[EXPAND_AS:.*]] = torch.aten.expand_as %[[ARG]], %[[ARG]] : !torch.vtensor, !torch.vtensor -> !torch.vtensor
// CHECK: return %[[EXPAND_AS]] : !torch.vtensor // CHECK: return %[[EXPAND_AS]] : !torch.vtensor
func @viewlike$two_inputs_one_copy(%arg0: !torch.vtensor) -> !torch.vtensor { func.func @viewlike$two_inputs_one_copy(%arg0: !torch.vtensor) -> !torch.vtensor {
%0 = torch.copy.to_tensor %arg0 : !torch.tensor %0 = torch.copy.to_tensor %arg0 : !torch.tensor
%1 = torch.aten.expand_as %0, %0 : !torch.tensor, !torch.tensor -> !torch.tensor %1 = torch.aten.expand_as %0, %0 : !torch.tensor, !torch.tensor -> !torch.tensor
%2 = torch.copy.to_vtensor %1 : !torch.vtensor %2 = torch.copy.to_vtensor %1 : !torch.vtensor
return %2 : !torch.vtensor return %2 : !torch.vtensor
} }
// CHECK-LABEL: func @viewlike$two_inputs_two_copies( // CHECK-LABEL: func.func @viewlike$two_inputs_two_copies(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> !torch.vtensor { // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: %[[EXPAND_AS:.*]] = torch.aten.expand_as %[[ARG0]], %[[ARG1]] : !torch.vtensor, !torch.vtensor -> !torch.vtensor // CHECK: %[[EXPAND_AS:.*]] = torch.aten.expand_as %[[ARG0]], %[[ARG1]] : !torch.vtensor, !torch.vtensor -> !torch.vtensor
// CHECK: return %[[EXPAND_AS]] : !torch.vtensor // CHECK: return %[[EXPAND_AS]] : !torch.vtensor
func @viewlike$two_inputs_two_copies(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor { func.func @viewlike$two_inputs_two_copies(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
%0 = torch.copy.to_tensor %arg0 : !torch.tensor %0 = torch.copy.to_tensor %arg0 : !torch.tensor
%1 = torch.copy.to_tensor %arg1 : !torch.tensor %1 = torch.copy.to_tensor %arg1 : !torch.tensor
%2 = torch.aten.expand_as %0, %1 : !torch.tensor, !torch.tensor -> !torch.tensor %2 = torch.aten.expand_as %0, %1 : !torch.tensor, !torch.tensor -> !torch.tensor

View File

@ -1,52 +1,52 @@
// RUN: torch-mlir-opt %s | torch-mlir-opt | FileCheck %s // RUN: torch-mlir-opt %s | torch-mlir-opt | FileCheck %s
// CHECK-LABEL: func @torch.operator( // CHECK-LABEL: func.func @torch.operator(
func @torch.operator(%arg0: !torch.tensor, %arg1: !torch.tensor) -> !torch.tensor { func.func @torch.operator(%arg0: !torch.tensor, %arg1: !torch.tensor) -> !torch.tensor {
// CHECK: torch.operator "ns.unqual.overload"(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tensor // CHECK: torch.operator "ns.unqual.overload"(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tensor
%0 = torch.operator "ns.unqual.overload"(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tensor %0 = torch.operator "ns.unqual.overload"(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tensor
return %0 : !torch.tensor return %0 : !torch.tensor
} }
func @torch.linear_params.create(%arg0: !torch.tensor, %arg1: !torch.tensor) -> (!torch.LinearParams, !torch.LinearParams) { func.func @torch.linear_params.create(%arg0: !torch.tensor, %arg1: !torch.tensor) -> (!torch.LinearParams, !torch.LinearParams) {
%with_bias = torch.linear_params.create %arg0, %arg1 : !torch.tensor, !torch.tensor %with_bias = torch.linear_params.create %arg0, %arg1 : !torch.tensor, !torch.tensor
%without_bias = torch.linear_params.create %arg0 : !torch.tensor %without_bias = torch.linear_params.create %arg0 : !torch.tensor
return %with_bias, %without_bias : !torch.LinearParams, !torch.LinearParams return %with_bias, %without_bias : !torch.LinearParams, !torch.LinearParams
} }
// CHECK: @tensor.default() -> !torch.tensor // CHECK: @tensor.default() -> !torch.tensor
func private @tensor.default() -> !torch.tensor func.func private @tensor.default() -> !torch.tensor
// CHECK: @tensor.default_explicit() -> !torch.tensor{{$}} // CHECK: @tensor.default_explicit() -> !torch.tensor{{$}}
func private @tensor.default_explicit() -> !torch.tensor<*,unk> func.func private @tensor.default_explicit() -> !torch.tensor<*,unk>
// CHECK: @tensor.value_semantic() -> !torch.vtensor{{$}} // CHECK: @tensor.value_semantic() -> !torch.vtensor{{$}}
func private @tensor.value_semantic() -> !torch.vtensor<*,unk> func.func private @tensor.value_semantic() -> !torch.vtensor<*,unk>
// CHECK: @tensor.dtype() -> !torch.tensor<*,si32> // CHECK: @tensor.dtype() -> !torch.tensor<*,si32>
func private @tensor.dtype() -> !torch.tensor<*,si32> func.func private @tensor.dtype() -> !torch.tensor<*,si32>
// CHECK: @tensor.ranked() -> !torch.tensor<[?,?,?],unk> // CHECK: @tensor.ranked() -> !torch.tensor<[?,?,?],unk>
func private @tensor.ranked() -> !torch.tensor<[?,?,?],unk> func.func private @tensor.ranked() -> !torch.tensor<[?,?,?],unk>
// CHECK: @tensor.some_sizes_known() -> !torch.tensor<[?,2,?,4],unk> // CHECK: @tensor.some_sizes_known() -> !torch.tensor<[?,2,?,4],unk>
func private @tensor.some_sizes_known() -> !torch.tensor<[?,2,?,4],unk> func.func private @tensor.some_sizes_known() -> !torch.tensor<[?,2,?,4],unk>
// CHECK: @tensor.fully_determined() -> !torch.vtensor<[1,2,3,4],f32> // CHECK: @tensor.fully_determined() -> !torch.vtensor<[1,2,3,4],f32>
func private @tensor.fully_determined() -> !torch.vtensor<[1,2,3,4],f32> func.func private @tensor.fully_determined() -> !torch.vtensor<[1,2,3,4],f32>
// CHECK: @tuple.empty() -> !torch.tuple<> // CHECK: @tuple.empty() -> !torch.tuple<>
func private @tuple.empty() -> !torch.tuple<> func.func private @tuple.empty() -> !torch.tuple<>
// CHECK: @tuple.one_element() -> !torch.tuple<tensor> // CHECK: @tuple.one_element() -> !torch.tuple<tensor>
func private @tuple.one_element() -> !torch.tuple<tensor> func.func private @tuple.one_element() -> !torch.tuple<tensor>
// CHECK: @tuple.two_elements() -> !torch.tuple<tensor, tensor> // CHECK: @tuple.two_elements() -> !torch.tuple<tensor, tensor>
func private @tuple.two_elements() -> !torch.tuple<tensor, tensor> func.func private @tuple.two_elements() -> !torch.tuple<tensor, tensor>
// CHECK: @union.empty() -> !torch.union<> // CHECK: @union.empty() -> !torch.union<>
func private @union.empty() -> !torch.union<> func.func private @union.empty() -> !torch.union<>
// CHECK: @union.one_element() -> !torch.union<tensor> // CHECK: @union.one_element() -> !torch.union<tensor>
func private @union.one_element() -> !torch.union<tensor> func.func private @union.one_element() -> !torch.union<tensor>
// CHECK: @union.two_elements() -> !torch.union<tensor, tensor> // CHECK: @union.two_elements() -> !torch.union<tensor, tensor>
func private @union.two_elements() -> !torch.union<tensor, tensor> func.func private @union.two_elements() -> !torch.union<tensor, tensor>
// CHECK: @dict() -> !torch.dict<str, tensor> // CHECK: @dict() -> !torch.dict<str, tensor>
func private @dict() -> !torch.dict<str, tensor> func.func private @dict() -> !torch.dict<str, tensor>
// CHECK-LABEL: func @torch.tensor.literal() { // CHECK-LABEL: func.func @torch.tensor.literal() {
func @torch.tensor.literal() { func.func @torch.tensor.literal() {
// CHECK: torch.tensor.literal(dense<4.200000e+01> : tensor<3x2xf32>) : !torch.tensor // CHECK: torch.tensor.literal(dense<4.200000e+01> : tensor<3x2xf32>) : !torch.tensor
%0 = torch.tensor.literal(dense<42.0> : tensor<3x2xf32>) : !torch.tensor %0 = torch.tensor.literal(dense<42.0> : tensor<3x2xf32>) : !torch.tensor
// CHECK: torch.tensor.literal(dense<4.200000e+01> : tensor<3x2xf32>) : !torch.tensor<[3,2],f32> // CHECK: torch.tensor.literal(dense<4.200000e+01> : tensor<3x2xf32>) : !torch.tensor<[3,2],f32>
@ -54,19 +54,19 @@ func @torch.tensor.literal() {
return return
} }
// CHECK-LABEL: func @torch.vtensor.literal() { // CHECK-LABEL: func.func @torch.vtensor.literal() {
func @torch.vtensor.literal() { func.func @torch.vtensor.literal() {
// CHECK: torch.vtensor.literal(dense<4.200000e+01> : tensor<3x2xf32>) : !torch.vtensor<[3,2],f32> // CHECK: torch.vtensor.literal(dense<4.200000e+01> : tensor<3x2xf32>) : !torch.vtensor<[3,2],f32>
%0 = torch.vtensor.literal(dense<42.0> : tensor<3x2xf32>) : !torch.vtensor<[3,2],f32> %0 = torch.vtensor.literal(dense<42.0> : tensor<3x2xf32>) : !torch.vtensor<[3,2],f32>
return return
} }
func @derefine(%arg0: !torch.tensor) -> !torch.optional<tensor> { func.func @derefine(%arg0: !torch.tensor) -> !torch.optional<tensor> {
%0 = torch.derefine %arg0 : !torch.tensor to !torch.optional<tensor> %0 = torch.derefine %arg0 : !torch.tensor to !torch.optional<tensor>
return %0 : !torch.optional<tensor> return %0 : !torch.optional<tensor>
} }
func @torch.prim.If(%arg0: !torch.bool, %arg1: !torch.int) -> !torch.int { func.func @torch.prim.If(%arg0: !torch.bool, %arg1: !torch.int) -> !torch.int {
%0 = torch.prim.If %arg0 -> (!torch.int) { %0 = torch.prim.If %arg0 -> (!torch.int) {
%1 = torch.aten.add.int %arg1, %arg1 : !torch.int, !torch.int -> !torch.int %1 = torch.aten.add.int %arg1, %arg1 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %1 : !torch.int torch.prim.If.yield %1 : !torch.int
@ -103,7 +103,7 @@ func @torch.prim.If(%arg0: !torch.bool, %arg1: !torch.int) -> !torch.int {
%none = torch.constant.none %none = torch.constant.none
// CHECK: %str = torch.constant.str "some str" // CHECK: %str = torch.constant.str "some str"
%str = torch.constant.str "some str" %str = torch.constant.str "some str"
func private @f(%arg0: !torch.nn.Module<"test">) { func.func private @f(%arg0: !torch.nn.Module<"test">) {
return return
} }
@ -131,7 +131,7 @@ torch.nn_module {
} : !torch.nn.Module<"test"> } : !torch.nn.Module<"test">
func @shape_calculations(%arg0: !torch.vtensor) -> !torch.vtensor { func.func @shape_calculations(%arg0: !torch.vtensor) -> !torch.vtensor {
%0 = torch.shape.calculate { %0 = torch.shape.calculate {
%0 = torch.aten.tanh %arg0 : !torch.vtensor -> !torch.vtensor %0 = torch.aten.tanh %arg0 : !torch.vtensor -> !torch.vtensor
torch.shape.calculate.yield %0 : !torch.vtensor torch.shape.calculate.yield %0 : !torch.vtensor
@ -142,7 +142,7 @@ func @shape_calculations(%arg0: !torch.vtensor) -> !torch.vtensor {
return %0 : !torch.vtensor return %0 : !torch.vtensor
} }
func @number_type_subtypes(%arg0: !torch.tensor, %arg1: !torch.list<int>, %arg2: !torch.union<float, int>) { func.func @number_type_subtypes(%arg0: !torch.tensor, %arg1: !torch.list<int>, %arg2: !torch.union<float, int>) {
%0 = torch.aten.constant_pad_nd %arg0, %arg1, %arg2 : !torch.tensor, !torch.list<int>, !torch.union<float, int> -> !torch.tensor %0 = torch.aten.constant_pad_nd %arg0, %arg1, %arg2 : !torch.tensor, !torch.list<int>, !torch.union<float, int> -> !torch.tensor
return return
} }

View File

@ -6,23 +6,23 @@ torch.class_type @c {
} }
// CHECK-LABEL: func private @test_call_method( // CHECK-LABEL: func.func private @test_call_method(
// CHECK-SAME: %[[RECEIVER:.*]]: !torch.nn.Module<"c">, // CHECK-SAME: %[[RECEIVER:.*]]: !torch.nn.Module<"c">,
// CHECK-SAME: %[[F:.*]]: !torch.float) -> !torch.float { // CHECK-SAME: %[[F:.*]]: !torch.float) -> !torch.float {
// CHECK: %[[RET:.*]] = call @test_call_method(%[[RECEIVER]], %[[F]]) : (!torch.nn.Module<"c">, !torch.float) -> !torch.float // CHECK: %[[RET:.*]] = call @test_call_method(%[[RECEIVER]], %[[F]]) : (!torch.nn.Module<"c">, !torch.float) -> !torch.float
// CHECK: return %[[RET]] : !torch.float // CHECK: return %[[RET]] : !torch.float
func private @test_call_method(%arg0: !torch.nn.Module<"c">, %arg1: !torch.float) -> !torch.float { func.func private @test_call_method(%arg0: !torch.nn.Module<"c">, %arg1: !torch.float) -> !torch.float {
%0 = torch.prim.CallMethod %arg0["test_call_method"] (%arg1) : !torch.nn.Module<"c">, (!torch.float) -> !torch.float %0 = torch.prim.CallMethod %arg0["test_call_method"] (%arg1) : !torch.nn.Module<"c">, (!torch.float) -> !torch.float
return %0 : !torch.float return %0 : !torch.float
} }
// CHECK-LABEL: func private @test_call_indirect( // CHECK-LABEL: func.func private @test_call_indirect(
// CHECK-SAME: %[[RECEIVER:.*]]: !torch.nn.Module<"c">, // CHECK-SAME: %[[RECEIVER:.*]]: !torch.nn.Module<"c">,
// CHECK-SAME: %[[F:.*]]: !torch.float) -> !torch.float { // CHECK-SAME: %[[F:.*]]: !torch.float) -> !torch.float {
// Ensure no func.constant. // Ensure no func.constant.
// CHECK-NEXT: %[[VAL_2:.*]] = call @test_call_method(%[[RECEIVER]], %[[F]]) : (!torch.nn.Module<"c">, !torch.float) -> !torch.float // CHECK-NEXT: %[[VAL_2:.*]] = call @test_call_method(%[[RECEIVER]], %[[F]]) : (!torch.nn.Module<"c">, !torch.float) -> !torch.float
// CHECK-NEXT: return %[[VAL_2]] : !torch.float // CHECK-NEXT: return %[[VAL_2]] : !torch.float
func private @test_call_indirect(%arg0: !torch.nn.Module<"c">, %arg1: !torch.float) -> !torch.float { func.func private @test_call_indirect(%arg0: !torch.nn.Module<"c">, %arg1: !torch.float) -> !torch.float {
%0 = constant @test_call_method : (!torch.nn.Module<"c">, !torch.float) -> !torch.float %0 = constant @test_call_method : (!torch.nn.Module<"c">, !torch.float) -> !torch.float
%1 = call_indirect %0(%arg0, %arg1) : (!torch.nn.Module<"c">, !torch.float) -> !torch.float %1 = call_indirect %0(%arg0, %arg1) : (!torch.nn.Module<"c">, !torch.float) -> !torch.float
return %1 : !torch.float return %1 : !torch.float

View File

@ -2,7 +2,7 @@
// ----- // -----
func @convert_to_value_semantic_tensors_list( %list: !torch.list<tensor>) -> !torch.tensor { func.func @convert_to_value_semantic_tensors_list( %list: !torch.list<tensor>) -> !torch.tensor {
%int1 = torch.constant.int 1 %int1 = torch.constant.int 1
// expected-error@+1 {{failed to legalize operation 'torch.aten.cat' that was explicitly marked illegal}} // expected-error@+1 {{failed to legalize operation 'torch.aten.cat' that was explicitly marked illegal}}
%ret = torch.aten.cat %list, %int1 : !torch.list<tensor>, !torch.int -> !torch.tensor %ret = torch.aten.cat %list, %int1 : !torch.list<tensor>, !torch.int -> !torch.tensor
@ -11,7 +11,7 @@ func @convert_to_value_semantic_tensors_list( %list: !torch.list<tensor>) -> !to
// ----- // -----
func @convert_to_value_semantic_tensors_optional(%tensor_optional: !torch.optional<tensor>, func.func @convert_to_value_semantic_tensors_optional(%tensor_optional: !torch.optional<tensor>,
%t: !torch.tensor, %t: !torch.tensor,
%training: !torch.bool, %training: !torch.bool,
%cudnn_enable: !torch.bool, %cudnn_enable: !torch.bool,

View File

@ -1,17 +1,17 @@
// RUN: torch-mlir-opt -torch-reduce-op-variants %s | FileCheck %s // RUN: torch-mlir-opt -torch-reduce-op-variants %s | FileCheck %s
// CHECK-LABEL: func @convert_to_value_semantic_tensors( // CHECK-LABEL: func.func @convert_to_value_semantic_tensors(
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> { // CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
// CHECK: %[[OPERAND_TENSOR:.*]] = torch.copy.to_vtensor %[[ARG]] : !torch.vtensor<[],f32> // CHECK: %[[OPERAND_TENSOR:.*]] = torch.copy.to_vtensor %[[ARG]] : !torch.vtensor<[],f32>
// CHECK: %[[RESULT_TENSOR:.*]] = torch.aten.tanh %[[OPERAND_TENSOR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> // CHECK: %[[RESULT_TENSOR:.*]] = torch.aten.tanh %[[OPERAND_TENSOR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[RESULT_TENSOR]] : !torch.tensor<[],f32> // CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[RESULT_TENSOR]] : !torch.tensor<[],f32>
// CHECK: return %[[RET]] : !torch.tensor<[],f32> // CHECK: return %[[RET]] : !torch.tensor<[],f32>
func @convert_to_value_semantic_tensors(%arg0: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> { func.func @convert_to_value_semantic_tensors(%arg0: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
%0 = torch.aten.tanh %arg0 : !torch.tensor<[],f32> -> !torch.tensor<[],f32> %0 = torch.aten.tanh %arg0 : !torch.tensor<[],f32> -> !torch.tensor<[],f32>
return %0 : !torch.tensor<[],f32> return %0 : !torch.tensor<[],f32>
} }
// CHECK-LABEL: func @convert_to_value_semantic_tensors_list( // CHECK-LABEL: func.func @convert_to_value_semantic_tensors_list(
// CHECK-SAME: %[[VT0:.*]]: !torch.vtensor, %[[VT1:.*]]: !torch.vtensor, // CHECK-SAME: %[[VT0:.*]]: !torch.vtensor, %[[VT1:.*]]: !torch.vtensor,
// CHECK-SAME: %[[VT2:.*]]: !torch.vtensor) -> !torch.tensor { // CHECK-SAME: %[[VT2:.*]]: !torch.vtensor) -> !torch.tensor {
// CHECK: %[[T0:.*]] = torch.copy.to_tensor %[[VT0]] : !torch.tensor // CHECK: %[[T0:.*]] = torch.copy.to_tensor %[[VT0]] : !torch.tensor
@ -30,7 +30,7 @@ func @convert_to_value_semantic_tensors(%arg0: !torch.tensor<[],f32>) -> !torch.
// CHECK-SAME: !torch.list<vtensor>, !torch.int -> !torch.vtensor // CHECK-SAME: !torch.list<vtensor>, !torch.int -> !torch.vtensor
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor // CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
// CHECK: return %[[RET]] : !torch.tensor // CHECK: return %[[RET]] : !torch.tensor
func @convert_to_value_semantic_tensors_list(%vt0: !torch.vtensor, %vt1: !torch.vtensor, %vt2: !torch.vtensor) -> !torch.tensor { func.func @convert_to_value_semantic_tensors_list(%vt0: !torch.vtensor, %vt1: !torch.vtensor, %vt2: !torch.vtensor) -> !torch.tensor {
%t0 = torch.copy.to_tensor %vt0 : !torch.tensor %t0 = torch.copy.to_tensor %vt0 : !torch.tensor
%t1 = torch.copy.to_tensor %vt1 : !torch.tensor %t1 = torch.copy.to_tensor %vt1 : !torch.tensor
%t2 = torch.copy.to_tensor %vt2 : !torch.tensor %t2 = torch.copy.to_tensor %vt2 : !torch.tensor
@ -40,7 +40,7 @@ func @convert_to_value_semantic_tensors_list(%vt0: !torch.vtensor, %vt1: !torch.
return %ret : !torch.tensor return %ret : !torch.tensor
} }
// CHECK-LABEL: func @convert_to_value_semantic_tensors_optional( // CHECK-LABEL: func.func @convert_to_value_semantic_tensors_optional(
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor, %[[FLOAT_TENSOR:.*]]: !torch.tensor<[4],f32>, // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor, %[[FLOAT_TENSOR:.*]]: !torch.tensor<[4],f32>,
// CHECK-SAME: %[[TRAINING:.*]]: !torch.bool, %[[CUDNN_ENABLE:.*]]: !torch.bool, // CHECK-SAME: %[[TRAINING:.*]]: !torch.bool, %[[CUDNN_ENABLE:.*]]: !torch.bool,
// CHECK-SAME: %[[FLOAT:.*]]: !torch.float) -> !torch.tensor { // CHECK-SAME: %[[FLOAT:.*]]: !torch.float) -> !torch.tensor {
@ -67,7 +67,7 @@ func @convert_to_value_semantic_tensors_list(%vt0: !torch.vtensor, %vt1: !torch.
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor // CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
// CHECK: return %[[RET]] : !torch.tensor // CHECK: return %[[RET]] : !torch.tensor
// CHECK: } // CHECK: }
func @convert_to_value_semantic_tensors_optional(%t: !torch.tensor, func.func @convert_to_value_semantic_tensors_optional(%t: !torch.tensor,
%ft: !torch.tensor<[4],f32>, %ft: !torch.tensor<[4],f32>,
%training: !torch.bool, %training: !torch.bool,
%cudnn_enable: !torch.bool, %cudnn_enable: !torch.bool,
@ -83,7 +83,7 @@ func @convert_to_value_semantic_tensors_optional(%t: !torch.tensor,
return %ret: !torch.tensor return %ret: !torch.tensor
} }
// CHECK-LABEL: func @reduce_trailing_underscore_inplace_variant( // CHECK-LABEL: func.func @reduce_trailing_underscore_inplace_variant(
// CHECK-SAME: %[[ARG0:.*]]: !torch.tensor<[2,2],f32>, // CHECK-SAME: %[[ARG0:.*]]: !torch.tensor<[2,2],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) { // CHECK-SAME: %[[ARG1:.*]]: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) {
// CHECK: %[[C1:.*]] = torch.constant.int 1 // CHECK: %[[C1:.*]] = torch.constant.int 1
@ -97,23 +97,23 @@ func @convert_to_value_semantic_tensors_optional(%t: !torch.tensor,
// CHECK: %[[TENSOR_AGAIN:.*]] = torch.copy.to_vtensor %[[ARRAY_RESULT]] : !torch.vtensor<[2,2],f32> // CHECK: %[[TENSOR_AGAIN:.*]] = torch.copy.to_vtensor %[[ARRAY_RESULT]] : !torch.vtensor<[2,2],f32>
// CHECK: torch.overwrite.tensor.contents %[[TENSOR_AGAIN]] overwrites %[[ARG0]] : !torch.vtensor<[2,2],f32>, !torch.tensor<[2,2],f32> // CHECK: torch.overwrite.tensor.contents %[[TENSOR_AGAIN]] overwrites %[[ARG0]] : !torch.vtensor<[2,2],f32>, !torch.tensor<[2,2],f32>
// CHECK: return %[[ARG0]], %[[ARG0]] : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32> // CHECK: return %[[ARG0]], %[[ARG0]] : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>
func @reduce_trailing_underscore_inplace_variant(%arg0: !torch.tensor<[2,2],f32>, %arg1: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) { func.func @reduce_trailing_underscore_inplace_variant(%arg0: !torch.tensor<[2,2],f32>, %arg1: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) {
%c1 = torch.constant.int 1 %c1 = torch.constant.int 1
%0 = torch.aten.add_.Tensor %arg0, %arg1, %c1 : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>, !torch.int -> !torch.tensor<[2,2],f32> %0 = torch.aten.add_.Tensor %arg0, %arg1, %c1 : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>, !torch.int -> !torch.tensor<[2,2],f32>
return %0, %arg0 : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32> return %0, %arg0 : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>
} }
// CHECK-LABEL: func @torch.tensor.literal() -> !torch.tensor { // CHECK-LABEL: func.func @torch.tensor.literal() -> !torch.tensor {
// CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<0.000000e+00> : tensor<7xf32>) : !torch.vtensor<[7],f32> // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<0.000000e+00> : tensor<7xf32>) : !torch.vtensor<[7],f32>
// CHECK: %[[SIZES_ERASED:.*]] = torch.tensor_static_info_cast %[[VTENSOR]] : !torch.vtensor<[7],f32> to !torch.vtensor // CHECK: %[[SIZES_ERASED:.*]] = torch.tensor_static_info_cast %[[VTENSOR]] : !torch.vtensor<[7],f32> to !torch.vtensor
// CHECK: %[[TENSOR:.*]] = torch.copy.to_tensor %[[SIZES_ERASED]] : !torch.tensor // CHECK: %[[TENSOR:.*]] = torch.copy.to_tensor %[[SIZES_ERASED]] : !torch.tensor
// CHECK: return %[[TENSOR]] : !torch.tensor // CHECK: return %[[TENSOR]] : !torch.tensor
func @torch.tensor.literal() -> !torch.tensor { func.func @torch.tensor.literal() -> !torch.tensor {
%0 = torch.tensor.literal(dense<0.0> : tensor<7xf32>) : !torch.tensor %0 = torch.tensor.literal(dense<0.0> : tensor<7xf32>) : !torch.tensor
return %0 : !torch.tensor return %0 : !torch.tensor
} }
// CHECK-LABEL: func @convert_to_value_semantic_tensors_optional_list( // CHECK-LABEL: func.func @convert_to_value_semantic_tensors_optional_list(
// CHECK-SAME: %[[SELF:.*]]: !torch.tensor<[5],f32>, // CHECK-SAME: %[[SELF:.*]]: !torch.tensor<[5],f32>,
// CHECK-SAME: %[[INDICES:.*]]: !torch.tensor<[2,3],si64>) -> !torch.tensor { // CHECK-SAME: %[[INDICES:.*]]: !torch.tensor<[2,3],si64>) -> !torch.tensor {
// CHECK: %[[INDICES_OPTIONAL_LIST:.*]] = torch.prim.ListConstruct %[[INDICES]] : // CHECK: %[[INDICES_OPTIONAL_LIST:.*]] = torch.prim.ListConstruct %[[INDICES]] :
@ -124,13 +124,13 @@ func @torch.tensor.literal() -> !torch.tensor {
// CHECK: %[[VRET:.*]] = torch.aten.index.Tensor %[[SELF_VTENSOR]], %[[INDICES_LIST]] : !torch.vtensor<[5],f32>, !torch.list<vtensor<[2,3],si64>> -> !torch.vtensor // CHECK: %[[VRET:.*]] = torch.aten.index.Tensor %[[SELF_VTENSOR]], %[[INDICES_LIST]] : !torch.vtensor<[5],f32>, !torch.list<vtensor<[2,3],si64>> -> !torch.vtensor
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor // CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
// CHECK: return %[[RET]] : !torch.tensor // CHECK: return %[[RET]] : !torch.tensor
func @convert_to_value_semantic_tensors_optional_list(%self: !torch.tensor<[5],f32>, %indices: !torch.tensor<[2,3],si64>) -> !torch.tensor { func.func @convert_to_value_semantic_tensors_optional_list(%self: !torch.tensor<[5],f32>, %indices: !torch.tensor<[2,3],si64>) -> !torch.tensor {
%tensor_optional_list = torch.prim.ListConstruct %indices : (!torch.tensor<[2,3],si64>) -> !torch.list<optional<tensor<[2,3],si64>>> %tensor_optional_list = torch.prim.ListConstruct %indices : (!torch.tensor<[2,3],si64>) -> !torch.list<optional<tensor<[2,3],si64>>>
%ret = torch.aten.index.Tensor %self, %tensor_optional_list : !torch.tensor<[5],f32>, !torch.list<optional<tensor<[2,3],si64>>> -> !torch.tensor %ret = torch.aten.index.Tensor %self, %tensor_optional_list : !torch.tensor<[5],f32>, !torch.list<optional<tensor<[2,3],si64>>> -> !torch.tensor
return %ret : !torch.tensor return %ret : !torch.tensor
} }
// CHECK-LABEL: func @torch.aten.uniform_( // CHECK-LABEL: func.func @torch.aten.uniform_(
// CHECK-SAME: %[[T:.*]]: !torch.tensor, %[[MIN:.*]]: !torch.float, %[[MAX:.*]]: !torch.float, // CHECK-SAME: %[[T:.*]]: !torch.tensor, %[[MIN:.*]]: !torch.float, %[[MAX:.*]]: !torch.float,
// CHECK-SAME: %[[GENERATOR:.*]]: !torch.none) -> !torch.tensor { // CHECK-SAME: %[[GENERATOR:.*]]: !torch.none) -> !torch.tensor {
// CHECK: %[[T_VTENSOR:.*]] = torch.copy.to_vtensor %[[T]] : !torch.vtensor // CHECK: %[[T_VTENSOR:.*]] = torch.copy.to_vtensor %[[T]] : !torch.vtensor
@ -140,12 +140,12 @@ func @convert_to_value_semantic_tensors_optional_list(%self: !torch.tensor<[5],f
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor // CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
// CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor // CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
// CHECK: return %[[T]] : !torch.tensor // CHECK: return %[[T]] : !torch.tensor
func @torch.aten.uniform_(%t: !torch.tensor, %min: !torch.float, %max: !torch.float, %generator: !torch.none) -> !torch.tensor { func.func @torch.aten.uniform_(%t: !torch.tensor, %min: !torch.float, %max: !torch.float, %generator: !torch.none) -> !torch.tensor {
%ret = torch.aten.uniform_ %t, %min, %max, %generator: !torch.tensor, !torch.float, !torch.float, !torch.none -> !torch.tensor %ret = torch.aten.uniform_ %t, %min, %max, %generator: !torch.tensor, !torch.float, !torch.float, !torch.none -> !torch.tensor
return %ret : !torch.tensor return %ret : !torch.tensor
} }
// CHECK-LABEL: func @torch.aten.bernoulli_.float( // CHECK-LABEL: func.func @torch.aten.bernoulli_.float(
// CHECK-SAME: %[[T:.*]]: !torch.tensor) -> !torch.tensor { // CHECK-SAME: %[[T:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: %[[GENERATOR:.*]] = torch.constant.none // CHECK: %[[GENERATOR:.*]] = torch.constant.none
// CHECK: %[[P:.*]] = torch.constant.float 5.000000e-01 // CHECK: %[[P:.*]] = torch.constant.float 5.000000e-01
@ -155,14 +155,14 @@ func @torch.aten.uniform_(%t: !torch.tensor, %min: !torch.float, %max: !torch.fl
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor // CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
// CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor // CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
// CHECK: return %[[T]] : !torch.tensor // CHECK: return %[[T]] : !torch.tensor
func @torch.aten.bernoulli_.float(%t: !torch.tensor) -> !torch.tensor { func.func @torch.aten.bernoulli_.float(%t: !torch.tensor) -> !torch.tensor {
%generator = torch.constant.none %generator = torch.constant.none
%p = torch.constant.float 5.000000e-01 %p = torch.constant.float 5.000000e-01
%ret = torch.aten.bernoulli_.float %t, %p, %generator : !torch.tensor, !torch.float, !torch.none -> !torch.tensor %ret = torch.aten.bernoulli_.float %t, %p, %generator : !torch.tensor, !torch.float, !torch.none -> !torch.tensor
return %ret : !torch.tensor return %ret : !torch.tensor
} }
// CHECK-LABEL: func @torch.aten.fill_.Scalar( // CHECK-LABEL: func.func @torch.aten.fill_.Scalar(
// CHECK-SAME: %[[T:.*]]: !torch.tensor) -> !torch.tensor { // CHECK-SAME: %[[T:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: %[[VALUE:.*]] = torch.constant.int 1 // CHECK: %[[VALUE:.*]] = torch.constant.int 1
// CHECK: %[[T_VTENSOR:.*]] = torch.copy.to_vtensor %[[T]] : !torch.vtensor // CHECK: %[[T_VTENSOR:.*]] = torch.copy.to_vtensor %[[T]] : !torch.vtensor
@ -171,13 +171,13 @@ func @torch.aten.bernoulli_.float(%t: !torch.tensor) -> !torch.tensor {
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor // CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
// CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor // CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
// CHECK: return %[[T]] : !torch.tensor // CHECK: return %[[T]] : !torch.tensor
func @torch.aten.fill_.Scalar(%t: !torch.tensor) -> !torch.tensor { func.func @torch.aten.fill_.Scalar(%t: !torch.tensor) -> !torch.tensor {
%value = torch.constant.int 1 %value = torch.constant.int 1
%ret = torch.aten.fill_.Scalar %t, %value : !torch.tensor, !torch.int -> !torch.tensor %ret = torch.aten.fill_.Scalar %t, %value : !torch.tensor, !torch.int -> !torch.tensor
return %ret : !torch.tensor return %ret : !torch.tensor
} }
// CHECK-LABEL: func @torch.aten._index_put_impl_( // CHECK-LABEL: func.func @torch.aten._index_put_impl_(
// CHECK-SAME: %[[SELF:.*]]: !torch.tensor, %[[INDEX:.*]]: !torch.tensor, %[[VALUES:.*]]: !torch.tensor) -> !torch.tensor { // CHECK-SAME: %[[SELF:.*]]: !torch.tensor, %[[INDEX:.*]]: !torch.tensor, %[[VALUES:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: %[[TRUE:.*]] = torch.constant.bool true // CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[FALSE:.*]] = torch.constant.bool false // CHECK: %[[FALSE:.*]] = torch.constant.bool false
@ -191,7 +191,7 @@ func @torch.aten.fill_.Scalar(%t: !torch.tensor) -> !torch.tensor {
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor // CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
// CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[SELF]] : !torch.vtensor, !torch.tensor // CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[SELF]] : !torch.vtensor, !torch.tensor
// CHECK: return %[[SELF:.*]] : !torch.tensor // CHECK: return %[[SELF:.*]] : !torch.tensor
func @torch.aten._index_put_impl_(%self: !torch.tensor, %index: !torch.tensor, %values: !torch.tensor) -> !torch.tensor { func.func @torch.aten._index_put_impl_(%self: !torch.tensor, %index: !torch.tensor, %values: !torch.tensor) -> !torch.tensor {
%true = torch.constant.bool true %true = torch.constant.bool true
%false = torch.constant.bool false %false = torch.constant.bool false
%indicesList = torch.prim.ListConstruct %index : (!torch.tensor) -> !torch.list<optional<tensor>> %indicesList = torch.prim.ListConstruct %index : (!torch.tensor) -> !torch.list<optional<tensor>>
@ -199,7 +199,7 @@ func @torch.aten._index_put_impl_(%self: !torch.tensor, %index: !torch.tensor, %
return %ret : !torch.tensor return %ret : !torch.tensor
} }
// CHECK-LABEL: func @torch.aten.copy_( // CHECK-LABEL: func.func @torch.aten.copy_(
// CHECK-SAME: %[[DST:.*]]: !torch.tensor, // CHECK-SAME: %[[DST:.*]]: !torch.tensor,
// CHECK-SAME: %[[SRC:.*]]: !torch.tensor) -> !torch.tensor { // CHECK-SAME: %[[SRC:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: %[[FALSE:.*]] = torch.constant.bool false // CHECK: %[[FALSE:.*]] = torch.constant.bool false
@ -210,7 +210,7 @@ func @torch.aten._index_put_impl_(%self: !torch.tensor, %index: !torch.tensor, %
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor // CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
// CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[DST]] : !torch.vtensor, !torch.tensor // CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[DST]] : !torch.vtensor, !torch.tensor
// CHECK: return %[[DST]] : !torch.tensor // CHECK: return %[[DST]] : !torch.tensor
func @torch.aten.copy_(%dst: !torch.tensor, %src : !torch.tensor) -> !torch.tensor { func.func @torch.aten.copy_(%dst: !torch.tensor, %src : !torch.tensor) -> !torch.tensor {
%false = torch.constant.bool false %false = torch.constant.bool false
%ret = torch.aten.copy_ %dst, %src, %false : !torch.tensor, !torch.tensor, !torch.bool -> !torch.tensor %ret = torch.aten.copy_ %dst, %src, %false : !torch.tensor, !torch.tensor, !torch.bool -> !torch.tensor
return %ret : !torch.tensor return %ret : !torch.tensor

View File

@ -1,34 +1,34 @@
// RUN: torch-mlir-opt -split-input-file -verify-diagnostics %s -torch-refine-public-return | FileCheck %s // RUN: torch-mlir-opt -split-input-file -verify-diagnostics %s -torch-refine-public-return | FileCheck %s
// CHECK-LABEL: func @basic( // CHECK-LABEL: func.func @basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor<[2,3,?],f32> { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor<[2,3,?],f32> {
// CHECK: return %[[ARG]] : !torch.vtensor<[2,3,?],f32> // CHECK: return %[[ARG]] : !torch.vtensor<[2,3,?],f32>
func @basic(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor { func.func @basic(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
%1 = torch.copy.to_tensor %arg0 : !torch.tensor<[2,3,?],f32> %1 = torch.copy.to_tensor %arg0 : !torch.tensor<[2,3,?],f32>
%2 = torch.tensor_static_info_cast %1 : !torch.tensor<[2,3,?],f32> to !torch.tensor %2 = torch.tensor_static_info_cast %1 : !torch.tensor<[2,3,?],f32> to !torch.tensor
return %2 : !torch.tensor return %2 : !torch.tensor
} }
// CHECK-LABEL: func @multiple_use_non_value_tensor( // CHECK-LABEL: func.func @multiple_use_non_value_tensor(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> !torch.vtensor { // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: %[[NON_VALUE_TENSOR:.*]] = torch.copy.to_tensor %[[ARG0]] : !torch.tensor // CHECK: %[[NON_VALUE_TENSOR:.*]] = torch.copy.to_tensor %[[ARG0]] : !torch.tensor
// CHECK: torch.overwrite.tensor.contents %[[ARG1]] overwrites %[[NON_VALUE_TENSOR]] : !torch.vtensor, !torch.tensor // CHECK: torch.overwrite.tensor.contents %[[ARG1]] overwrites %[[NON_VALUE_TENSOR]] : !torch.vtensor, !torch.tensor
// CHECK: %[[RESULT:.*]] = torch.copy.to_vtensor %[[NON_VALUE_TENSOR]] : !torch.vtensor // CHECK: %[[RESULT:.*]] = torch.copy.to_vtensor %[[NON_VALUE_TENSOR]] : !torch.vtensor
// CHECK: return %[[RESULT]] : !torch.vtensor // CHECK: return %[[RESULT]] : !torch.vtensor
func @multiple_use_non_value_tensor(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.tensor { func.func @multiple_use_non_value_tensor(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.tensor {
%0 = torch.copy.to_tensor %arg0 : !torch.tensor %0 = torch.copy.to_tensor %arg0 : !torch.tensor
torch.overwrite.tensor.contents %arg1 overwrites %0 : !torch.vtensor, !torch.tensor torch.overwrite.tensor.contents %arg1 overwrites %0 : !torch.vtensor, !torch.tensor
return %0 : !torch.tensor return %0 : !torch.tensor
} }
// No conversion on private function. // No conversion on private function.
// CHECK-LABEL: func private @basic_private( // CHECK-LABEL: func.func private @basic_private(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
// CHECK: %[[COPIED:.*]] = torch.copy.to_tensor %[[ARG]] : !torch.tensor<[2,3,?],f32> // CHECK: %[[COPIED:.*]] = torch.copy.to_tensor %[[ARG]] : !torch.tensor<[2,3,?],f32>
// CHECK: %[[CASTED:.*]] = torch.tensor_static_info_cast %[[COPIED]] : !torch.tensor<[2,3,?],f32> to !torch.tensor // CHECK: %[[CASTED:.*]] = torch.tensor_static_info_cast %[[COPIED]] : !torch.tensor<[2,3,?],f32> to !torch.tensor
// CHECK: return %[[CASTED]] : !torch.tensor // CHECK: return %[[CASTED]] : !torch.tensor
func private @basic_private(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor { func.func private @basic_private(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
%1 = torch.copy.to_tensor %arg0 : !torch.tensor<[2,3,?],f32> %1 = torch.copy.to_tensor %arg0 : !torch.tensor<[2,3,?],f32>
%2 = torch.tensor_static_info_cast %1 : !torch.tensor<[2,3,?],f32> to !torch.tensor %2 = torch.tensor_static_info_cast %1 : !torch.tensor<[2,3,?],f32> to !torch.tensor
return %2 : !torch.tensor return %2 : !torch.tensor
@ -38,11 +38,11 @@ func private @basic_private(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor
// Call to public function. // Call to public function.
// expected-error @+1 {{unimplemented}} // expected-error @+1 {{unimplemented}}
func @called(%arg0: tensor<*xf32>) -> tensor<*xf32> { func.func @called(%arg0: tensor<*xf32>) -> tensor<*xf32> {
return %arg0 : tensor<*xf32> return %arg0 : tensor<*xf32>
} }
func private @caller(%arg0: tensor<*xf32>) -> tensor<*xf32> { func.func private @caller(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = call @called(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = call @called(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32> return %0 : tensor<*xf32>
} }
@ -51,7 +51,7 @@ func private @caller(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// Multiple returns. // Multiple returns.
// expected-error @+1 {{unimplemented}} // expected-error @+1 {{unimplemented}}
func @called(%arg0: tensor<*xf32>) -> tensor<*xf32> { func.func @called(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%ctrue = arith.constant true %ctrue = arith.constant true
cf.cond_br %ctrue, ^bb1, ^bb2 cf.cond_br %ctrue, ^bb1, ^bb2
^bb1: ^bb1:

View File

@ -2,7 +2,7 @@
// ----- // -----
// CHECK-LABEL: func @prim.if$branch_merge_type_tensor( // CHECK-LABEL: func.func @prim.if$branch_merge_type_tensor(
// CHECK-SAME: %[[PRED:.*]]: !torch.bool, // CHECK-SAME: %[[PRED:.*]]: !torch.bool,
// CHECK-SAME: %[[T1:.*]]: !torch.tensor, // CHECK-SAME: %[[T1:.*]]: !torch.tensor,
// CHECK-SAME: %[[T2:.*]]: !torch.tensor) -> !torch.bool { // CHECK-SAME: %[[T2:.*]]: !torch.tensor) -> !torch.bool {
@ -18,7 +18,7 @@
// CHECK: %[[RET:.*]] = torch.aten.__isnot__ %[[REFINED]], %[[NONE]] : !torch.tensor, !torch.none -> !torch.bool // CHECK: %[[RET:.*]] = torch.aten.__isnot__ %[[REFINED]], %[[NONE]] : !torch.tensor, !torch.none -> !torch.bool
// CHECK: return %[[RET]] : !torch.bool // CHECK: return %[[RET]] : !torch.bool
func @prim.if$branch_merge_type_tensor(%pred: !torch.bool, %t0: !torch.tensor, %t1: !torch.tensor) -> !torch.bool { func.func @prim.if$branch_merge_type_tensor(%pred: !torch.bool, %t0: !torch.tensor, %t1: !torch.tensor) -> !torch.bool {
%res = torch.prim.If %pred -> (!torch.optional<tensor>) { %res = torch.prim.If %pred -> (!torch.optional<tensor>) {
%optional0 = torch.derefine %t0: !torch.tensor to !torch.optional<tensor> %optional0 = torch.derefine %t0: !torch.tensor to !torch.optional<tensor>
torch.prim.If.yield %optional0: !torch.optional<tensor> torch.prim.If.yield %optional0: !torch.optional<tensor>
@ -33,7 +33,7 @@ func @prim.if$branch_merge_type_tensor(%pred: !torch.bool, %t0: !torch.tensor, %
// ----- // -----
// CHECK-LABEL: func @prim.if$branch_merge_type_optional( // CHECK-LABEL: func.func @prim.if$branch_merge_type_optional(
// CHECK-SAME: %[[PRED:.*]]: !torch.bool, // CHECK-SAME: %[[PRED:.*]]: !torch.bool,
// CHECK-SAME: %[[T:.*]]: !torch.tensor) -> !torch.optional<tensor> { // CHECK-SAME: %[[T:.*]]: !torch.tensor) -> !torch.optional<tensor> {
// CHECK: %[[MERGED:.*]] = torch.prim.If %[[PRED]] -> (!torch.optional<tensor>) { // CHECK: %[[MERGED:.*]] = torch.prim.If %[[PRED]] -> (!torch.optional<tensor>) {
@ -46,7 +46,7 @@ func @prim.if$branch_merge_type_tensor(%pred: !torch.bool, %t0: !torch.tensor, %
// CHECK: } // CHECK: }
// CHECK: return %[[MERGED:.*]] : !torch.optional<tensor> // CHECK: return %[[MERGED:.*]] : !torch.optional<tensor>
func @prim.if$branch_merge_type_optional(%pred: !torch.bool, %t1: !torch.tensor) -> !torch.optional<tensor> { func.func @prim.if$branch_merge_type_optional(%pred: !torch.bool, %t1: !torch.tensor) -> !torch.optional<tensor> {
%res = torch.prim.If %pred -> (!torch.optional<tensor>) { %res = torch.prim.If %pred -> (!torch.optional<tensor>) {
%none = torch.constant.none %none = torch.constant.none
%optional0 = torch.derefine %none: !torch.none to !torch.optional<tensor> %optional0 = torch.derefine %none: !torch.none to !torch.optional<tensor>
@ -60,7 +60,7 @@ func @prim.if$branch_merge_type_optional(%pred: !torch.bool, %t1: !torch.tensor)
// ----- // -----
// CHECK-LABEL: func @prim.if$refined_type_conflicting( // CHECK-LABEL: func.func @prim.if$refined_type_conflicting(
// CHECK-SAME: %[[NONE:.*]]: !torch.none) -> !torch.tensor { // CHECK-SAME: %[[NONE:.*]]: !torch.none) -> !torch.tensor {
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<tensor> // CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<tensor>
// CHECK: %[[NOT_NONE:.*]] = torch.aten.__isnot__ %[[NONE]], %[[NONE]] : !torch.none, !torch.none -> !torch.bool // CHECK: %[[NOT_NONE:.*]] = torch.aten.__isnot__ %[[NONE]], %[[NONE]] : !torch.none, !torch.none -> !torch.bool
@ -73,7 +73,7 @@ func @prim.if$branch_merge_type_optional(%pred: !torch.bool, %t1: !torch.tensor)
// CHECK: } // CHECK: }
// CHECK: return %[[PRED:.*]] : !torch.tensor // CHECK: return %[[PRED:.*]] : !torch.tensor
func @prim.if$refined_type_conflicting(%none: !torch.none) -> !torch.tensor { func.func @prim.if$refined_type_conflicting(%none: !torch.none) -> !torch.tensor {
%optional = torch.derefine %none: !torch.none to !torch.optional<tensor> %optional = torch.derefine %none: !torch.none to !torch.optional<tensor>
%pred = torch.aten.__isnot__ %optional, %none : !torch.optional<tensor>, !torch.none -> !torch.bool %pred = torch.aten.__isnot__ %optional, %none : !torch.optional<tensor>, !torch.none -> !torch.bool
%res = torch.prim.If %pred -> (!torch.tensor) { %res = torch.prim.If %pred -> (!torch.tensor) {
@ -88,7 +88,7 @@ func @prim.if$refined_type_conflicting(%none: !torch.none) -> !torch.tensor {
// ----- // -----
// CHECK-LABEL: func @prim.loop$region_arg_to_internal( // CHECK-LABEL: func.func @prim.loop$region_arg_to_internal(
// CHECK-SAME: %[[ARG_NONE:.*]]: !torch.none) -> !torch.optional<tensor> { // CHECK-SAME: %[[ARG_NONE:.*]]: !torch.none) -> !torch.optional<tensor> {
// CHECK: %[[INT10:.*]] = torch.constant.int 10 // CHECK: %[[INT10:.*]] = torch.constant.int 10
// CHECK: %[[INDV:.*]] = torch.constant.int 0 // CHECK: %[[INDV:.*]] = torch.constant.int 0
@ -105,7 +105,7 @@ func @prim.if$refined_type_conflicting(%none: !torch.none) -> !torch.tensor {
// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<tensor> // CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<tensor>
// CHECK: return %[[OPTIONAL]] : !torch.optional<tensor> // CHECK: return %[[OPTIONAL]] : !torch.optional<tensor>
func @prim.loop$region_arg_to_internal(%none: !torch.none) -> !torch.optional<tensor> { func.func @prim.loop$region_arg_to_internal(%none: !torch.none) -> !torch.optional<tensor> {
%int10 = torch.constant.int 10 %int10 = torch.constant.int 10
%int0 = torch.constant.int 0 %int0 = torch.constant.int 0
%true = torch.constant.bool true %true = torch.constant.bool true
@ -120,11 +120,11 @@ func @prim.loop$region_arg_to_internal(%none: !torch.none) -> !torch.optional<te
// ----- // -----
// CHECK-LABEL: func @f // CHECK-LABEL: func.func @f
// CHECK: %[[ATEN:.*]] = torch.aten.tanh %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32> // CHECK: %[[ATEN:.*]] = torch.aten.tanh %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ATEN]] : !torch.vtensor<*,f32> to !torch.vtensor // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ATEN]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor // CHECK: return %[[CAST]] : !torch.vtensor
func @f(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor { func.func @f(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
%cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor %cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor
cf.br ^bb1(%cast: !torch.vtensor) cf.br ^bb1(%cast: !torch.vtensor)
^bb1(%arg1: !torch.vtensor): ^bb1(%arg1: !torch.vtensor):
@ -134,16 +134,16 @@ func @f(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
// ----- // -----
// CHECK-LABEL: func @f // CHECK-LABEL: func.func @f
// CHECK: func private @callee // CHECK: func.func private @callee
// CHECK-NEXT: torch.aten.tanh %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32> // CHECK-NEXT: torch.aten.tanh %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32>
func @f() { func.func @f() {
builtin.module { builtin.module {
func private @callee(%arg0: !torch.vtensor) { func.func private @callee(%arg0: !torch.vtensor) {
%1 = torch.aten.tanh %arg0 : !torch.vtensor -> !torch.vtensor %1 = torch.aten.tanh %arg0 : !torch.vtensor -> !torch.vtensor
return return
} }
func @caller(%arg0: !torch.vtensor<*,f32>) { func.func @caller(%arg0: !torch.vtensor<*,f32>) {
%cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor %cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor
call @callee(%cast) : (!torch.vtensor) -> () call @callee(%cast) : (!torch.vtensor) -> ()
return return

View File

@ -4,7 +4,7 @@
// function (i.e. new code called from visitOperation). // function (i.e. new code called from visitOperation).
// ----- // -----
// CHECK-LABEL: func @aten.arange.start$int64_dtype( // CHECK-LABEL: func.func @aten.arange.start$int64_dtype(
// CHECK-SAME: %[[START:.*]]: !torch.int, // CHECK-SAME: %[[START:.*]]: !torch.int,
// CHECK-SAME: %[[END:.*]]: !torch.int) -> !torch.vtensor { // CHECK-SAME: %[[END:.*]]: !torch.int) -> !torch.vtensor {
// CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[NONE:.*]] = torch.constant.none
@ -14,14 +14,14 @@
// CHECK-SAME: -> !torch.vtensor<*,si64> // CHECK-SAME: -> !torch.vtensor<*,si64>
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.vtensor<*,si64> to !torch.vtensor // CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.vtensor<*,si64> to !torch.vtensor
// CHECK: return %[[RET]] : !torch.vtensor // CHECK: return %[[RET]] : !torch.vtensor
func @aten.arange.start$int64_dtype(%start: !torch.int, %end: !torch.int) -> !torch.vtensor { func.func @aten.arange.start$int64_dtype(%start: !torch.int, %end: !torch.int) -> !torch.vtensor {
%none = torch.constant.none %none = torch.constant.none
%ret = torch.aten.arange.start %start, %end, %none, %none, %none, %none: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor %ret = torch.aten.arange.start %start, %end, %none, %none, %none, %none: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor
return %ret : !torch.vtensor return %ret : !torch.vtensor
} }
// ----- // -----
// CHECK-LABEL: func @aten.arange.start$float32_dtype( // CHECK-LABEL: func.func @aten.arange.start$float32_dtype(
// CHECK-SAME: %[[START:.*]]: !torch.float, // CHECK-SAME: %[[START:.*]]: !torch.float,
// CHECK-SAME: %[[END:.*]]: !torch.int) -> !torch.vtensor { // CHECK-SAME: %[[END:.*]]: !torch.int) -> !torch.vtensor {
// CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[NONE:.*]] = torch.constant.none
@ -31,14 +31,14 @@ func @aten.arange.start$int64_dtype(%start: !torch.int, %end: !torch.int) -> !to
// CHECK-SAME: -> !torch.vtensor<*,f32> // CHECK-SAME: -> !torch.vtensor<*,f32>
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.vtensor<*,f32> to !torch.vtensor // CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: return %[[RET]] : !torch.vtensor // CHECK: return %[[RET]] : !torch.vtensor
func @aten.arange.start$float32_dtype(%start: !torch.float, %end: !torch.int) -> !torch.vtensor { func.func @aten.arange.start$float32_dtype(%start: !torch.float, %end: !torch.int) -> !torch.vtensor {
%none = torch.constant.none %none = torch.constant.none
%ret = torch.aten.arange.start %start, %end, %none, %none, %none, %none: !torch.float, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor %ret = torch.aten.arange.start %start, %end, %none, %none, %none, %none: !torch.float, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor
return %ret : !torch.vtensor return %ret : !torch.vtensor
} }
// ----- // -----
// CHECK-LABEL: func @aten.arange.start$specified_dtype( // CHECK-LABEL: func.func @aten.arange.start$specified_dtype(
// CHECK-SAME: %[[END:.*]]: !torch.int) -> !torch.vtensor { // CHECK-SAME: %[[END:.*]]: !torch.int) -> !torch.vtensor {
// CHECK: %[[CST6:.*]] = torch.constant.int 6 // CHECK: %[[CST6:.*]] = torch.constant.int 6
// CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[NONE:.*]] = torch.constant.none
@ -48,7 +48,7 @@ func @aten.arange.start$float32_dtype(%start: !torch.float, %end: !torch.int) ->
// CHECK-SAME: -> !torch.vtensor<*,f32> // CHECK-SAME: -> !torch.vtensor<*,f32>
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.vtensor<*,f32> to !torch.vtensor // CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: return %[[RET]] : !torch.vtensor // CHECK: return %[[RET]] : !torch.vtensor
func @aten.arange.start$specified_dtype(%end: !torch.int) -> !torch.vtensor { func.func @aten.arange.start$specified_dtype(%end: !torch.int) -> !torch.vtensor {
%int6 = torch.constant.int 6 %int6 = torch.constant.int 6
%none = torch.constant.none %none = torch.constant.none
%ret = torch.aten.arange %end, %int6, %none, %none, %none: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor %ret = torch.aten.arange %end, %int6, %none, %none, %none: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor
@ -56,20 +56,20 @@ func @aten.arange.start$specified_dtype(%end: !torch.int) -> !torch.vtensor {
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.linear( // CHECK-LABEL: func.func @torch.aten.linear(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3],f32>, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[5,3],f32>, // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[5,3],f32>,
// CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor<[5],f32>) -> !torch.vtensor { // CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor<[5],f32>) -> !torch.vtensor {
// CHECK: %[[LINEAR:.*]] = torch.aten.linear %[[ARG0]], %[[ARG1]], %[[ARG2]] : !torch.vtensor<[?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<*,f32> // CHECK: %[[LINEAR:.*]] = torch.aten.linear %[[ARG0]], %[[ARG1]], %[[ARG2]] : !torch.vtensor<[?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<*,f32>
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[LINEAR]] : !torch.vtensor<*,f32> to !torch.vtensor // CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[LINEAR]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: return %[[RESULT]] : !torch.vtensor // CHECK: return %[[RESULT]] : !torch.vtensor
func @torch.aten.linear(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor { func.func @torch.aten.linear(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor {
%1 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor %1 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor
return %1 : !torch.vtensor return %1 : !torch.vtensor
} }
// ----- // -----
// CHECK-LABEL: func @aten.sum.dim_IntList( // CHECK-LABEL: func.func @aten.sum.dim_IntList(
// CHECK-SAME: %[[T:.*]]: !torch.vtensor<*,si64>) -> !torch.vtensor { // CHECK-SAME: %[[T:.*]]: !torch.vtensor<*,si64>) -> !torch.vtensor {
// CHECK: %[[FALSE:.*]] = torch.constant.bool false // CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[NONE:.*]] = torch.constant.none
@ -82,7 +82,7 @@ func @torch.aten.linear(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.vtensor<
// CHECK-SAME: -> !torch.vtensor<*,si64> // CHECK-SAME: -> !torch.vtensor<*,si64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<*,si64> to !torch.vtensor // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<*,si64> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor // CHECK: return %[[CAST]] : !torch.vtensor
func @aten.sum.dim_IntList(%t: !torch.vtensor<*,si64>) -> !torch.vtensor { func.func @aten.sum.dim_IntList(%t: !torch.vtensor<*,si64>) -> !torch.vtensor {
%false = torch.constant.bool false %false = torch.constant.bool false
%none = torch.constant.none %none = torch.constant.none
%int0 = torch.constant.int 0 %int0 = torch.constant.int 0
@ -93,14 +93,14 @@ func @aten.sum.dim_IntList(%t: !torch.vtensor<*,si64>) -> !torch.vtensor {
} }
// ----- // -----
// CHECK-LABEL: func @aten.any.dim( // CHECK-LABEL: func.func @aten.any.dim(
// CHECK-SAME: %[[T:.*]]: !torch.vtensor<*,i1>) -> !torch.vtensor { // CHECK-SAME: %[[T:.*]]: !torch.vtensor<*,i1>) -> !torch.vtensor {
// CHECK: %[[FALSE:.*]] = torch.constant.bool false // CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[INT_NEG1:.*]] = torch.constant.int -1 // CHECK: %[[INT_NEG1:.*]] = torch.constant.int -1
// CHECK: %[[RET:.*]] = torch.aten.any.dim %[[T]], %[[INT_NEG1]], %[[FALSE]] : !torch.vtensor<*,i1>, !torch.int, !torch.bool -> !torch.vtensor<*,i1> // CHECK: %[[RET:.*]] = torch.aten.any.dim %[[T]], %[[INT_NEG1]], %[[FALSE]] : !torch.vtensor<*,i1>, !torch.int, !torch.bool -> !torch.vtensor<*,i1>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<*,i1> to !torch.vtensor // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<*,i1> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor // CHECK: return %[[CAST]] : !torch.vtensor
func @aten.any.dim(%t: !torch.vtensor<*,i1>) -> !torch.vtensor { func.func @aten.any.dim(%t: !torch.vtensor<*,i1>) -> !torch.vtensor {
%false = torch.constant.bool false %false = torch.constant.bool false
%int-1 = torch.constant.int -1 %int-1 = torch.constant.int -1
%ret = torch.aten.any.dim %t, %int-1, %false : !torch.vtensor<*,i1>, !torch.int, !torch.bool -> !torch.vtensor %ret = torch.aten.any.dim %t, %int-1, %false : !torch.vtensor<*,i1>, !torch.int, !torch.bool -> !torch.vtensor
@ -108,18 +108,18 @@ func @aten.any.dim(%t: !torch.vtensor<*,i1>) -> !torch.vtensor {
} }
// ----- // -----
// CHECK-LABEL: func @aten.any( // CHECK-LABEL: func.func @aten.any(
// CHECK-SAME: %[[T:.*]]: !torch.vtensor<*,i1>) -> !torch.vtensor { // CHECK-SAME: %[[T:.*]]: !torch.vtensor<*,i1>) -> !torch.vtensor {
// CHECK: %[[RET:.*]] = torch.aten.any %[[T]] : !torch.vtensor<*,i1> -> !torch.vtensor<*,i1> // CHECK: %[[RET:.*]] = torch.aten.any %[[T]] : !torch.vtensor<*,i1> -> !torch.vtensor<*,i1>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<*,i1> to !torch.vtensor // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<*,i1> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor // CHECK: return %[[CAST]] : !torch.vtensor
func @aten.any(%t: !torch.vtensor<*,i1>) -> !torch.vtensor { func.func @aten.any(%t: !torch.vtensor<*,i1>) -> !torch.vtensor {
%ret = torch.aten.any %t: !torch.vtensor<*,i1> -> !torch.vtensor %ret = torch.aten.any %t: !torch.vtensor<*,i1> -> !torch.vtensor
return %ret : !torch.vtensor return %ret : !torch.vtensor
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.zeros( // CHECK-LABEL: func.func @torch.aten.zeros(
// CHECK-SAME: %[[DIM0:.*]]: !torch.int) -> !torch.tensor { // CHECK-SAME: %[[DIM0:.*]]: !torch.int) -> !torch.tensor {
// CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT2:.*]] = torch.constant.int 2 // CHECK: %[[INT2:.*]] = torch.constant.int 2
@ -127,7 +127,7 @@ func @aten.any(%t: !torch.vtensor<*,i1>) -> !torch.vtensor {
// CHECK: %[[ZEROS:.*]] = torch.aten.zeros %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor<*,f32> // CHECK: %[[ZEROS:.*]] = torch.aten.zeros %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ZEROS]] : !torch.tensor<*,f32> to !torch.tensor // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ZEROS]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor // CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.zeros(%dim0: !torch.int) -> !torch.tensor { func.func @torch.aten.zeros(%dim0: !torch.int) -> !torch.tensor {
%none = torch.constant.none %none = torch.constant.none
%int2 = torch.constant.int 2 %int2 = torch.constant.int 2
%sizesList = torch.prim.ListConstruct %dim0, %int2 : (!torch.int, !torch.int) -> !torch.list<int> %sizesList = torch.prim.ListConstruct %dim0, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
@ -136,19 +136,19 @@ func @torch.aten.zeros(%dim0: !torch.int) -> !torch.tensor {
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.type_as( // CHECK-LABEL: func.func @torch.aten.type_as(
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?],si64>, // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?],si64>,
// CHECK-SAME: %[[OTHER:.*]]: !torch.tensor<[?,2],f32>) -> !torch.tensor { // CHECK-SAME: %[[OTHER:.*]]: !torch.tensor<[?,2],f32>) -> !torch.tensor {
// CHECK: %[[RET:.*]] = torch.aten.type_as %[[INPUT]], %[[OTHER]] : !torch.tensor<[?],si64>, !torch.tensor<[?,2],f32> -> !torch.tensor<*,f32> // CHECK: %[[RET:.*]] = torch.aten.type_as %[[INPUT]], %[[OTHER]] : !torch.tensor<[?],si64>, !torch.tensor<[?,2],f32> -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor // CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.type_as(%self: !torch.tensor<[?], si64>, %other: !torch.tensor<[?,2],f32>) -> !torch.tensor { func.func @torch.aten.type_as(%self: !torch.tensor<[?], si64>, %other: !torch.tensor<[?,2],f32>) -> !torch.tensor {
%ret = torch.aten.type_as %self, %other : !torch.tensor<[?], si64>, !torch.tensor<[?,2],f32> -> !torch.tensor %ret = torch.aten.type_as %self, %other : !torch.tensor<[?], si64>, !torch.tensor<[?,2],f32> -> !torch.tensor
return %ret: !torch.tensor return %ret: !torch.tensor
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.cat( // CHECK-LABEL: func.func @torch.aten.cat(
// CHECK-SAME: %[[T1:.*]]: !torch.tensor<[?,1,4],f32>, // CHECK-SAME: %[[T1:.*]]: !torch.tensor<[?,1,4],f32>,
// CHECK-SAME: %[[T2:.*]]: !torch.tensor<[2,3,4],f32>) -> !torch.tensor { // CHECK-SAME: %[[T2:.*]]: !torch.tensor<[2,3,4],f32>) -> !torch.tensor {
// CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[INT1:.*]] = torch.constant.int 1
@ -156,7 +156,7 @@ func @torch.aten.type_as(%self: !torch.tensor<[?], si64>, %other: !torch.tensor<
// CHECK: %[[RET:.*]] = torch.aten.cat %[[TENSORS]], %[[INT1]] : !torch.list<tensor>, !torch.int -> !torch.tensor<*,f32> // CHECK: %[[RET:.*]] = torch.aten.cat %[[TENSORS]], %[[INT1]] : !torch.list<tensor>, !torch.int -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor // CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.cat(%t0: !torch.tensor<[?,1,4], f32>, %t1: !torch.tensor<[2,3,4], f32>) -> !torch.tensor { func.func @torch.aten.cat(%t0: !torch.tensor<[?,1,4], f32>, %t1: !torch.tensor<[2,3,4], f32>) -> !torch.tensor {
%int1 = torch.constant.int 1 %int1 = torch.constant.int 1
%tensorList = torch.prim.ListConstruct %t0, %t1: (!torch.tensor<[?,1,4], f32>, !torch.tensor<[2,3,4], f32>) -> !torch.list<tensor> %tensorList = torch.prim.ListConstruct %t0, %t1: (!torch.tensor<[?,1,4], f32>, !torch.tensor<[2,3,4], f32>) -> !torch.list<tensor>
%ret = torch.aten.cat %tensorList, %int1 : !torch.list<tensor>, !torch.int -> !torch.tensor %ret = torch.aten.cat %tensorList, %int1 : !torch.list<tensor>, !torch.int -> !torch.tensor
@ -164,29 +164,29 @@ func @torch.aten.cat(%t0: !torch.tensor<[?,1,4], f32>, %t1: !torch.tensor<[2,3,4
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten._shape_as_tensor( // CHECK-LABEL: func.func @torch.aten._shape_as_tensor(
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?,1,4],f32>) -> !torch.tensor { // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?,1,4],f32>) -> !torch.tensor {
// CHECK: %[[RET:.*]] = torch.aten._shape_as_tensor %[[INPUT]] : !torch.tensor<[?,1,4],f32> -> !torch.tensor<*,si64> // CHECK: %[[RET:.*]] = torch.aten._shape_as_tensor %[[INPUT]] : !torch.tensor<[?,1,4],f32> -> !torch.tensor<*,si64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor // CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten._shape_as_tensor(%input: !torch.tensor<[?,1,4], f32>) -> !torch.tensor { func.func @torch.aten._shape_as_tensor(%input: !torch.tensor<[?,1,4], f32>) -> !torch.tensor {
%ret= torch.aten._shape_as_tensor %input : !torch.tensor<[?,1,4], f32> -> !torch.tensor %ret= torch.aten._shape_as_tensor %input : !torch.tensor<[?,1,4], f32> -> !torch.tensor
return %ret : !torch.tensor return %ret : !torch.tensor
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten._shape_as_tensor$unknown_input_shape( // CHECK-LABEL: func.func @torch.aten._shape_as_tensor$unknown_input_shape(
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor) -> !torch.tensor { // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: %[[RET:.*]] = torch.aten._shape_as_tensor %[[INPUT]] : !torch.tensor -> !torch.tensor<*,si64> // CHECK: %[[RET:.*]] = torch.aten._shape_as_tensor %[[INPUT]] : !torch.tensor -> !torch.tensor<*,si64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor // CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten._shape_as_tensor$unknown_input_shape(%input: !torch.tensor) -> !torch.tensor { func.func @torch.aten._shape_as_tensor$unknown_input_shape(%input: !torch.tensor) -> !torch.tensor {
%ret= torch.aten._shape_as_tensor %input : !torch.tensor -> !torch.tensor %ret= torch.aten._shape_as_tensor %input : !torch.tensor -> !torch.tensor
return %ret : !torch.tensor return %ret : !torch.tensor
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.embedding( // CHECK-LABEL: func.func @torch.aten.embedding(
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[104,512],f32>, // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[104,512],f32>,
// CHECK-SAME: %[[INDEXES:.*]]: !torch.tensor<[2,3],si64>) -> !torch.tensor { // CHECK-SAME: %[[INDEXES:.*]]: !torch.tensor<[2,3],si64>) -> !torch.tensor {
// CHECK: %[[FALSE:.*]] = torch.constant.bool false // CHECK: %[[FALSE:.*]] = torch.constant.bool false
@ -194,7 +194,7 @@ func @torch.aten._shape_as_tensor$unknown_input_shape(%input: !torch.tensor) ->
// CHECK: %[[RET:.*]] = torch.aten.embedding %[[INPUT]], %[[INDEXES]], %[[PADDING_IDX]], %[[FALSE]], %[[FALSE]] : !torch.tensor<[104,512],f32>, !torch.tensor<[2,3],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.tensor<*,f32> // CHECK: %[[RET:.*]] = torch.aten.embedding %[[INPUT]], %[[INDEXES]], %[[PADDING_IDX]], %[[FALSE]], %[[FALSE]] : !torch.tensor<[104,512],f32>, !torch.tensor<[2,3],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor // CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.embedding(%weight: !torch.tensor<[104,512],f32>, %indices: !torch.tensor<[2,3], si64>) -> !torch.tensor { func.func @torch.aten.embedding(%weight: !torch.tensor<[104,512],f32>, %indices: !torch.tensor<[2,3], si64>) -> !torch.tensor {
%false = torch.constant.bool false %false = torch.constant.bool false
%int1 = torch.constant.int 1 %int1 = torch.constant.int 1
%ret = torch.aten.embedding %weight, %indices, %int1, %false, %false : !torch.tensor<[104,512],f32>, !torch.tensor<[2,3], si64>, !torch.int, !torch.bool, !torch.bool -> !torch.tensor %ret = torch.aten.embedding %weight, %indices, %int1, %false, %false : !torch.tensor<[104,512],f32>, !torch.tensor<[2,3], si64>, !torch.int, !torch.bool, !torch.bool -> !torch.tensor
@ -202,14 +202,14 @@ func @torch.aten.embedding(%weight: !torch.tensor<[104,512],f32>, %indices: !tor
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.tensor.float( // CHECK-LABEL: func.func @torch.aten.tensor.float(
// CHECK-SAME: %[[t:.*]]: !torch.float) -> !torch.tensor { // CHECK-SAME: %[[t:.*]]: !torch.float) -> !torch.tensor {
// CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false // CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[RET:.*]] = torch.aten.tensor.float %[[t]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.tensor<*,f32> // CHECK: %[[RET:.*]] = torch.aten.tensor.float %[[t]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor // CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.tensor.float(%t: !torch.float) -> !torch.tensor { func.func @torch.aten.tensor.float(%t: !torch.float) -> !torch.tensor {
%none = torch.constant.none %none = torch.constant.none
%false = torch.constant.bool false %false = torch.constant.bool false
%ret = torch.aten.tensor.float %t, %none, %none, %false : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.tensor %ret = torch.aten.tensor.float %t, %none, %none, %false : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.tensor
@ -217,7 +217,7 @@ func @torch.aten.tensor.float(%t: !torch.float) -> !torch.tensor {
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.tensor.float$specified_dtype( // CHECK-LABEL: func.func @torch.aten.tensor.float$specified_dtype(
// CHECK-SAME: %[[t:.*]]: !torch.float) -> !torch.tensor { // CHECK-SAME: %[[t:.*]]: !torch.float) -> !torch.tensor {
// CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[CST11:.*]] = torch.constant.int 11 // CHECK: %[[CST11:.*]] = torch.constant.int 11
@ -225,7 +225,7 @@ func @torch.aten.tensor.float(%t: !torch.float) -> !torch.tensor {
// CHECK: %[[RET:.*]] = torch.aten.tensor.float %[[t]], %[[CST11]], %[[NONE]], %[[FALSE]] : !torch.float, !torch.int, !torch.none, !torch.bool -> !torch.tensor<*,i1> // CHECK: %[[RET:.*]] = torch.aten.tensor.float %[[t]], %[[CST11]], %[[NONE]], %[[FALSE]] : !torch.float, !torch.int, !torch.none, !torch.bool -> !torch.tensor<*,i1>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,i1> to !torch.tensor // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,i1> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor // CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.tensor.float$specified_dtype(%t: !torch.float) -> !torch.tensor { func.func @torch.aten.tensor.float$specified_dtype(%t: !torch.float) -> !torch.tensor {
%none = torch.constant.none %none = torch.constant.none
%int11 = torch.constant.int 11 %int11 = torch.constant.int 11
%false = torch.constant.bool false %false = torch.constant.bool false
@ -234,59 +234,59 @@ func @torch.aten.tensor.float$specified_dtype(%t: !torch.float) -> !torch.tensor
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.softmax.int( // CHECK-LABEL: func.func @torch.aten.softmax.int(
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>, // CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>,
// CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor { // CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor {
// CHECK: %[[DTYPE:.*]] = torch.constant.none // CHECK: %[[DTYPE:.*]] = torch.constant.none
// CHECK: %[[SOFTMAX:.*]] = torch.aten.softmax.int %[[T]], %[[DIM]], %[[DTYPE]] : !torch.tensor<[2,3],f32>, !torch.int, !torch.none -> !torch.tensor<*,f32> // CHECK: %[[SOFTMAX:.*]] = torch.aten.softmax.int %[[T]], %[[DIM]], %[[DTYPE]] : !torch.tensor<[2,3],f32>, !torch.int, !torch.none -> !torch.tensor<*,f32>
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<*,f32> to !torch.tensor // CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[RET]] : !torch.tensor // CHECK: return %[[RET]] : !torch.tensor
func @torch.aten.softmax.int(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -> !torch.tensor { func.func @torch.aten.softmax.int(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -> !torch.tensor {
%none = torch.constant.none %none = torch.constant.none
%ret = torch.aten.softmax.int %t, %dim, %none : !torch.tensor<[2,3],f32>, !torch.int, !torch.none -> !torch.tensor %ret = torch.aten.softmax.int %t, %dim, %none : !torch.tensor<[2,3],f32>, !torch.int, !torch.none -> !torch.tensor
return %ret : !torch.tensor return %ret : !torch.tensor
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.softmax.int$specified_dtype( // CHECK-LABEL: func.func @torch.aten.softmax.int$specified_dtype(
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>, // CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>,
// CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor { // CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor {
// CHECK: %[[DTYPE:.*]] = torch.constant.int 4 // CHECK: %[[DTYPE:.*]] = torch.constant.int 4
// CHECK: %[[SOFTMAX:.*]] = torch.aten.softmax.int %[[T]], %[[DIM]], %[[DTYPE]] : !torch.tensor<[2,3],f32>, !torch.int, !torch.int -> !torch.tensor<*,si64> // CHECK: %[[SOFTMAX:.*]] = torch.aten.softmax.int %[[T]], %[[DIM]], %[[DTYPE]] : !torch.tensor<[2,3],f32>, !torch.int, !torch.int -> !torch.tensor<*,si64>
// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<*,si64> to !torch.tensor // CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<*,si64> to !torch.tensor
// CHECK: return %[[RET]] : !torch.tensor // CHECK: return %[[RET]] : !torch.tensor
func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -> !torch.tensor { func.func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -> !torch.tensor {
%int4 = torch.constant.int 4 %int4 = torch.constant.int 4
%ret = torch.aten.softmax.int %t, %dim, %int4: !torch.tensor<[2,3],f32>, !torch.int, !torch.int -> !torch.tensor %ret = torch.aten.softmax.int %t, %dim, %int4: !torch.tensor<[2,3],f32>, !torch.int, !torch.int -> !torch.tensor
return %ret : !torch.tensor return %ret : !torch.tensor
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.Matmul.Broadcast.Matrix( // CHECK-LABEL: func.func @torch.aten.Matmul.Broadcast.Matrix(
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<*,f32>, // CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<*,f32>,
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor { // CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<*,f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor<*,f32> // CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<*,f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<*,f32> to !torch.tensor // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor // CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.Matmul.Broadcast.Matrix(%arg0: !torch.vtensor<*,f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor { func.func @torch.aten.Matmul.Broadcast.Matrix(%arg0: !torch.vtensor<*,f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<*,f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<*,f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
return %0 : !torch.tensor return %0 : !torch.tensor
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.Matmul.Broadcast.Vector( // CHECK-LABEL: func.func @torch.aten.Matmul.Broadcast.Vector(
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<*,f32>, // CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<*,f32>,
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<*,f32>) -> !torch.tensor { // CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<*,f32>) -> !torch.tensor {
// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<*,f32>, !torch.vtensor<*,f32> -> !torch.tensor<*,f32> // CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<*,f32>, !torch.vtensor<*,f32> -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<*,f32> to !torch.tensor // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor // CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.Matmul.Broadcast.Vector(%arg0: !torch.vtensor<*,f32>, %arg1: !torch.vtensor<*,f32>) -> !torch.tensor { func.func @torch.aten.Matmul.Broadcast.Vector(%arg0: !torch.vtensor<*,f32>, %arg1: !torch.vtensor<*,f32>) -> !torch.tensor {
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<*,f32>, !torch.vtensor<*,f32> -> !torch.tensor %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<*,f32>, !torch.vtensor<*,f32> -> !torch.tensor
return %0 : !torch.tensor return %0 : !torch.tensor
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.to.dtype( // CHECK-LABEL: func.func @torch.aten.to.dtype(
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor // CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor
// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype // CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype
// CHECK-SAME: %[[ARG]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : // CHECK-SAME: %[[ARG]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} :
@ -294,7 +294,7 @@ func @torch.aten.Matmul.Broadcast.Vector(%arg0: !torch.vtensor<*,f32>, %arg1: !t
// CHECK-SAME: -> !torch.tensor<*,si64> // CHECK-SAME: -> !torch.tensor<*,si64>
// CHECK-NEXT: %[[RES:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.tensor<*,si64> to !torch.tensor // CHECK-NEXT: %[[RES:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.tensor<*,si64> to !torch.tensor
// CHECK-NEXT: return %[[RES]] : !torch.tensor // CHECK-NEXT: return %[[RES]] : !torch.tensor
func @torch.aten.to.dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor{ func.func @torch.aten.to.dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor{
%none = torch.constant.none %none = torch.constant.none
%false = torch.constant.bool false %false = torch.constant.bool false
%int4 = torch.constant.int 4 %int4 = torch.constant.int 4
@ -303,18 +303,18 @@ func @torch.aten.to.dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor{
} }
// ----- // -----
// CHECK-LABEL: func @torch.prim.NumToTensor.Scalar( // CHECK-LABEL: func.func @torch.prim.NumToTensor.Scalar(
// CHECK-SAME: %[[SELF:.*]]: !torch.int) -> !torch.tensor { // CHECK-SAME: %[[SELF:.*]]: !torch.int) -> !torch.tensor {
// CHECK: %[[NTT:.*]] = torch.prim.NumToTensor.Scalar %[[SELF]] : !torch.int -> !torch.tensor<*,si64> // CHECK: %[[NTT:.*]] = torch.prim.NumToTensor.Scalar %[[SELF]] : !torch.int -> !torch.tensor<*,si64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[NTT]] : !torch.tensor<*,si64> to !torch.tensor // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[NTT]] : !torch.tensor<*,si64> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor // CHECK: return %[[CAST]] : !torch.tensor
func @torch.prim.NumToTensor.Scalar(%arg0: !torch.int) -> !torch.tensor { func.func @torch.prim.NumToTensor.Scalar(%arg0: !torch.int) -> !torch.tensor {
%0 = torch.prim.NumToTensor.Scalar %arg0: !torch.int -> !torch.tensor %0 = torch.prim.NumToTensor.Scalar %arg0: !torch.int -> !torch.tensor
return %0: !torch.tensor return %0: !torch.tensor
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.tensor( // CHECK-LABEL: func.func @torch.aten.tensor(
// CHECK-SAME: %[[DATA:.*]]: !torch.list<list<float>>) -> !torch.tensor { // CHECK-SAME: %[[DATA:.*]]: !torch.list<list<float>>) -> !torch.tensor {
// CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false // CHECK: %[[FALSE:.*]] = torch.constant.bool false
@ -323,7 +323,7 @@ func @torch.prim.NumToTensor.Scalar(%arg0: !torch.int) -> !torch.tensor {
// CHECK-SAME: -> !torch.tensor<*,f32> // CHECK-SAME: -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor // CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.tensor(%t: !torch.list<list<float>>) -> !torch.tensor { func.func @torch.aten.tensor(%t: !torch.list<list<float>>) -> !torch.tensor {
%none = torch.constant.none %none = torch.constant.none
%false = torch.constant.bool false %false = torch.constant.bool false
%ret = torch.aten.tensor %t, %none, %none, %false : !torch.list<list<float>>, !torch.none, !torch.none, !torch.bool -> !torch.tensor %ret = torch.aten.tensor %t, %none, %none, %false : !torch.list<list<float>>, !torch.none, !torch.none, !torch.bool -> !torch.tensor
@ -331,7 +331,7 @@ func @torch.aten.tensor(%t: !torch.list<list<float>>) -> !torch.tensor {
} }
// ----- // -----
// CHECK-LABEL: func @torch.aten.tensor$specified_dtype( // CHECK-LABEL: func.func @torch.aten.tensor$specified_dtype(
// CHECK-SAME: %[[DATA:.*]]: !torch.list<list<float>>) -> !torch.tensor { // CHECK-SAME: %[[DATA:.*]]: !torch.list<list<float>>) -> !torch.tensor {
// CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT4:.*]] = torch.constant.int 4 // CHECK: %[[INT4:.*]] = torch.constant.int 4
@ -339,7 +339,7 @@ func @torch.aten.tensor(%t: !torch.list<list<float>>) -> !torch.tensor {
// CHECK: %[[RET:.*]] = torch.aten.tensor %[[DATA]], %[[INT4]], %[[NONE]], %[[FALSE]] : !torch.list<list<float>>, !torch.int, !torch.none, !torch.bool -> !torch.tensor<*,si64> // CHECK: %[[RET:.*]] = torch.aten.tensor %[[DATA]], %[[INT4]], %[[NONE]], %[[FALSE]] : !torch.list<list<float>>, !torch.int, !torch.none, !torch.bool -> !torch.tensor<*,si64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor // CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.tensor$specified_dtype(%t: !torch.list<list<float>>) -> !torch.tensor { func.func @torch.aten.tensor$specified_dtype(%t: !torch.list<list<float>>) -> !torch.tensor {
%none = torch.constant.none %none = torch.constant.none
%int4 = torch.constant.int 4 %int4 = torch.constant.int 4
%false = torch.constant.bool false %false = torch.constant.bool false

View File

@ -7,35 +7,35 @@
// should go in refine-types-ops.mlir. // should go in refine-types-ops.mlir.
// ----- // -----
// CHECK-LABEL: func @basic( // CHECK-LABEL: func.func @basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
// CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32> // CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[TANH]] : !torch.vtensor<*,f32> to !torch.vtensor // CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[TANH]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: return %[[RESULT]] : !torch.vtensor // CHECK: return %[[RESULT]] : !torch.vtensor
func @basic(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor { func.func @basic(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
%1 = torch.aten.tanh %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor %1 = torch.aten.tanh %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
return %1 : !torch.vtensor return %1 : !torch.vtensor
} }
// ----- // -----
// CHECK-LABEL: func @keep_existing_shape_information( // CHECK-LABEL: func.func @keep_existing_shape_information(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> {
// CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<[2],f32> // CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<[2],f32>
// CHECK: return %[[TANH]] : !torch.vtensor<[2],f32> // CHECK: return %[[TANH]] : !torch.vtensor<[2],f32>
func @keep_existing_shape_information(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> { func.func @keep_existing_shape_information(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> {
%1 = torch.aten.tanh %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor<[2], f32> %1 = torch.aten.tanh %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor<[2], f32>
return %1 : !torch.vtensor<[2],f32> return %1 : !torch.vtensor<[2],f32>
} }
// ----- // -----
// CHECK-LABEL: func @propagate_through_multiple_ops( // CHECK-LABEL: func.func @propagate_through_multiple_ops(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
// CHECK: %[[TANH0:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32> // CHECK: %[[TANH0:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
// CHECK: %[[TANH1:.*]] = torch.aten.tanh %[[TANH0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32> // CHECK: %[[TANH1:.*]] = torch.aten.tanh %[[TANH0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
// CHECK: %[[TANH2:.*]] = torch.aten.tanh %[[TANH1]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32> // CHECK: %[[TANH2:.*]] = torch.aten.tanh %[[TANH1]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
// CHECK: %[[TANH3:.*]] = torch.tensor_static_info_cast %[[TANH2]] : !torch.vtensor<*,f32> to !torch.vtensor // CHECK: %[[TANH3:.*]] = torch.tensor_static_info_cast %[[TANH2]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: return %[[TANH3]] : !torch.vtensor // CHECK: return %[[TANH3]] : !torch.vtensor
func @propagate_through_multiple_ops(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor { func.func @propagate_through_multiple_ops(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
%1 = torch.aten.tanh %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor %1 = torch.aten.tanh %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
%2 = torch.aten.tanh %1 : !torch.vtensor -> !torch.vtensor %2 = torch.aten.tanh %1 : !torch.vtensor -> !torch.vtensor
%3 = torch.aten.tanh %2 : !torch.vtensor -> !torch.vtensor %3 = torch.aten.tanh %2 : !torch.vtensor -> !torch.vtensor
@ -45,108 +45,108 @@ func @propagate_through_multiple_ops(%arg0: !torch.vtensor<*,f32>) -> !torch.vte
// ----- // -----
// Check rewriting logic in case of mixes of users that do/don't allow type // Check rewriting logic in case of mixes of users that do/don't allow type
// refinement. // refinement.
// CHECK-LABEL: func @mixed_allowing_not_allowing_type_refinement( // CHECK-LABEL: func.func @mixed_allowing_not_allowing_type_refinement(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> (!torch.vtensor, !torch.vtensor) { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> (!torch.vtensor, !torch.vtensor) {
// CHECK: %[[TANH0:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32> // CHECK: %[[TANH0:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
// CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[TANH0]] : !torch.vtensor<*,f32> to !torch.vtensor // CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[TANH0]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: %[[TANH1:.*]] = torch.aten.tanh %[[TANH0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32> // CHECK: %[[TANH1:.*]] = torch.aten.tanh %[[TANH0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
// CHECK: return %[[ERASED]], %[[ERASED]] : !torch.vtensor, !torch.vtensor // CHECK: return %[[ERASED]], %[[ERASED]] : !torch.vtensor, !torch.vtensor
func @mixed_allowing_not_allowing_type_refinement(%arg0: !torch.vtensor<*,f32>) -> (!torch.vtensor, !torch.vtensor) { func.func @mixed_allowing_not_allowing_type_refinement(%arg0: !torch.vtensor<*,f32>) -> (!torch.vtensor, !torch.vtensor) {
%1 = torch.aten.tanh %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor %1 = torch.aten.tanh %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
%3 = torch.aten.tanh %1 : !torch.vtensor -> !torch.vtensor %3 = torch.aten.tanh %1 : !torch.vtensor -> !torch.vtensor
return %1, %1 : !torch.vtensor, !torch.vtensor return %1, %1 : !torch.vtensor, !torch.vtensor
} }
// ----- // -----
// CHECK-LABEL: func @type_promotion$same_category_different_width( // CHECK-LABEL: func.func @type_promotion$same_category_different_width(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],si32>, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],si32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],si64>) -> !torch.vtensor<[?],unk> { // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],si64>) -> !torch.vtensor<[?],unk> {
// CHECK: %[[ALPHA:.*]] = torch.constant.int 3 // CHECK: %[[ALPHA:.*]] = torch.constant.int 3
// CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[ARG0]], %[[ARG1]], %[[ALPHA]] : !torch.vtensor<[?],si32>, !torch.vtensor<[?],si64>, !torch.int -> !torch.vtensor<[?],si64> // CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[ARG0]], %[[ARG1]], %[[ALPHA]] : !torch.vtensor<[?],si32>, !torch.vtensor<[?],si64>, !torch.int -> !torch.vtensor<[?],si64>
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[ADD]] : !torch.vtensor<[?],si64> to !torch.vtensor<[?],unk> // CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[ADD]] : !torch.vtensor<[?],si64> to !torch.vtensor<[?],unk>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?],unk> // CHECK: return %[[RESULT]] : !torch.vtensor<[?],unk>
func @type_promotion$same_category_different_width(%arg0: !torch.vtensor<[?],si32>, %arg1: !torch.vtensor<[?],si64>) -> !torch.vtensor<[?],unk> { func.func @type_promotion$same_category_different_width(%arg0: !torch.vtensor<[?],si32>, %arg1: !torch.vtensor<[?],si64>) -> !torch.vtensor<[?],unk> {
%int3 = torch.constant.int 3 %int3 = torch.constant.int 3
%0 = torch.aten.add.Tensor %arg0, %arg1, %int3 : !torch.vtensor<[?],si32>, !torch.vtensor<[?],si64>, !torch.int -> !torch.vtensor<[?],unk> %0 = torch.aten.add.Tensor %arg0, %arg1, %int3 : !torch.vtensor<[?],si32>, !torch.vtensor<[?],si64>, !torch.int -> !torch.vtensor<[?],unk>
return %0 : !torch.vtensor<[?],unk> return %0 : !torch.vtensor<[?],unk>
} }
// ----- // -----
// CHECK-LABEL: func @type_promotion$different_category( // CHECK-LABEL: func.func @type_promotion$different_category(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],si64>, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],si64>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],unk> { // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],unk> {
// CHECK: %[[ALPHA:.*]] = torch.constant.int 3 // CHECK: %[[ALPHA:.*]] = torch.constant.int 3
// CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[ARG0]], %[[ARG1]], %[[ALPHA]] : !torch.vtensor<[?],si64>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32> // CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[ARG0]], %[[ARG1]], %[[ALPHA]] : !torch.vtensor<[?],si64>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32>
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[ADD]] : !torch.vtensor<[?],f32> to !torch.vtensor<[?],unk> // CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[ADD]] : !torch.vtensor<[?],f32> to !torch.vtensor<[?],unk>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?],unk> // CHECK: return %[[RESULT]] : !torch.vtensor<[?],unk>
func @type_promotion$different_category(%arg0: !torch.vtensor<[?],si64>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],unk> { func.func @type_promotion$different_category(%arg0: !torch.vtensor<[?],si64>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],unk> {
%int3 = torch.constant.int 3 %int3 = torch.constant.int 3
%0 = torch.aten.add.Tensor %arg0, %arg1, %int3 : !torch.vtensor<[?],si64>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],unk> %0 = torch.aten.add.Tensor %arg0, %arg1, %int3 : !torch.vtensor<[?],si64>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],unk>
return %0 : !torch.vtensor<[?],unk> return %0 : !torch.vtensor<[?],unk>
} }
// ----- // -----
// CHECK-LABEL: func @type_promotion$same_category_zero_rank_wider( // CHECK-LABEL: func.func @type_promotion$same_category_zero_rank_wider(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],f64>) -> !torch.vtensor<[?],unk> { // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],f64>) -> !torch.vtensor<[?],unk> {
// CHECK: %[[ALPHA:.*]] = torch.constant.float 2.300000e+00 // CHECK: %[[ALPHA:.*]] = torch.constant.float 2.300000e+00
// CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[ARG0]], %[[ARG1]], %[[ALPHA]] : !torch.vtensor<[?],f32>, !torch.vtensor<[],f64>, !torch.float -> !torch.vtensor<[?],f32> // CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[ARG0]], %[[ARG1]], %[[ALPHA]] : !torch.vtensor<[?],f32>, !torch.vtensor<[],f64>, !torch.float -> !torch.vtensor<[?],f32>
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[ADD]] : !torch.vtensor<[?],f32> to !torch.vtensor<[?],unk> // CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[ADD]] : !torch.vtensor<[?],f32> to !torch.vtensor<[?],unk>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?],unk> // CHECK: return %[[RESULT]] : !torch.vtensor<[?],unk>
func @type_promotion$same_category_zero_rank_wider(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[],f64>) -> !torch.vtensor<[?],unk> { func.func @type_promotion$same_category_zero_rank_wider(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[],f64>) -> !torch.vtensor<[?],unk> {
%float2.300000e00 = torch.constant.float 2.300000e+00 %float2.300000e00 = torch.constant.float 2.300000e+00
%0 = torch.aten.add.Tensor %arg0, %arg1, %float2.300000e00 : !torch.vtensor<[?],f32>, !torch.vtensor<[],f64>, !torch.float -> !torch.vtensor<[?],unk> %0 = torch.aten.add.Tensor %arg0, %arg1, %float2.300000e00 : !torch.vtensor<[?],f32>, !torch.vtensor<[],f64>, !torch.float -> !torch.vtensor<[?],unk>
return %0 : !torch.vtensor<[?],unk> return %0 : !torch.vtensor<[?],unk>
} }
// ----- // -----
// CHECK-LABEL: func @type_promotion$zero_rank_higher_category( // CHECK-LABEL: func.func @type_promotion$zero_rank_higher_category(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],si64>, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],si64>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],unk> { // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],unk> {
// CHECK: %[[ALPHA:.*]] = torch.constant.int 2 // CHECK: %[[ALPHA:.*]] = torch.constant.int 2
// CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[ARG0]], %[[ARG1]], %[[ALPHA]] : !torch.vtensor<[?],si64>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[?],f32> // CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[ARG0]], %[[ARG1]], %[[ALPHA]] : !torch.vtensor<[?],si64>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[?],f32>
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[ADD]] : !torch.vtensor<[?],f32> to !torch.vtensor<[?],unk> // CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[ADD]] : !torch.vtensor<[?],f32> to !torch.vtensor<[?],unk>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?],unk> // CHECK: return %[[RESULT]] : !torch.vtensor<[?],unk>
func @type_promotion$zero_rank_higher_category(%arg0: !torch.vtensor<[?],si64>, %arg1: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],unk> { func.func @type_promotion$zero_rank_higher_category(%arg0: !torch.vtensor<[?],si64>, %arg1: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],unk> {
%int2 = torch.constant.int 2 %int2 = torch.constant.int 2
%0 = torch.aten.add.Tensor %arg0, %arg1, %int2 : !torch.vtensor<[?],si64>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[?],unk> %0 = torch.aten.add.Tensor %arg0, %arg1, %int2 : !torch.vtensor<[?],si64>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[?],unk>
return %0 : !torch.vtensor<[?],unk> return %0 : !torch.vtensor<[?],unk>
} }
// ----- // -----
// CHECK-LABEL: func @type_promotion$alpha_wider( // CHECK-LABEL: func.func @type_promotion$alpha_wider(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],unk> { // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],unk> {
// CHECK: %[[ALPHA:.*]] = torch.constant.float 2.300000e+00 // CHECK: %[[ALPHA:.*]] = torch.constant.float 2.300000e+00
// CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[ARG0]], %[[ARG1]], %[[ALPHA]] : !torch.vtensor<[?],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[?],f32> // CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[ARG0]], %[[ARG1]], %[[ALPHA]] : !torch.vtensor<[?],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[?],f32>
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[ADD]] : !torch.vtensor<[?],f32> to !torch.vtensor<[?],unk> // CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[ADD]] : !torch.vtensor<[?],f32> to !torch.vtensor<[?],unk>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?],unk> // CHECK: return %[[RESULT]] : !torch.vtensor<[?],unk>
func @type_promotion$alpha_wider(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],unk> { func.func @type_promotion$alpha_wider(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],unk> {
%float2.300000e00 = torch.constant.float 2.300000e+00 %float2.300000e00 = torch.constant.float 2.300000e+00
%0 = torch.aten.add.Tensor %arg0, %arg1, %float2.300000e00 : !torch.vtensor<[?],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[?],unk> %0 = torch.aten.add.Tensor %arg0, %arg1, %float2.300000e00 : !torch.vtensor<[?],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[?],unk>
return %0 : !torch.vtensor<[?],unk> return %0 : !torch.vtensor<[?],unk>
} }
// ----- // -----
// CHECK-LABEL: func @type_promotion_scalar_operation( // CHECK-LABEL: func.func @type_promotion_scalar_operation(
// CHECK-SAME: %[[FLOAT:.*]]: !torch.float, // CHECK-SAME: %[[FLOAT:.*]]: !torch.float,
// CHECK-SAME: %[[INT:.*]]: !torch.int) -> !torch.number { // CHECK-SAME: %[[INT:.*]]: !torch.int) -> !torch.number {
// CHECK: %[[ADD:.*]] = torch.aten.add %[[FLOAT]], %[[INT]] : !torch.float, !torch.int -> !torch.float // CHECK: %[[ADD:.*]] = torch.aten.add %[[FLOAT]], %[[INT]] : !torch.float, !torch.int -> !torch.float
// CHECK: %[[RET:.*]] = torch.derefine %[[ADD]] : !torch.float to !torch.number // CHECK: %[[RET:.*]] = torch.derefine %[[ADD]] : !torch.float to !torch.number
// CHECK: return %[[RET]] : !torch.number // CHECK: return %[[RET]] : !torch.number
func @type_promotion_scalar_operation(%float: !torch.float, %int: !torch.int) -> !torch.number { func.func @type_promotion_scalar_operation(%float: !torch.float, %int: !torch.int) -> !torch.number {
%ret = torch.aten.add %float, %int : !torch.float, !torch.int -> !torch.number %ret = torch.aten.add %float, %int : !torch.float, !torch.int -> !torch.number
return %ret : !torch.number return %ret : !torch.number
} }
// ----- // -----
// CHECK-LABEL: func @torch.overwrite.tensor.contents$dynamic_overwrites_static( // CHECK-LABEL: func.func @torch.overwrite.tensor.contents$dynamic_overwrites_static(
// CHECK-SAME: %[[STATIC:.*]]: !torch.vtensor<[2],f32>, // CHECK-SAME: %[[STATIC:.*]]: !torch.vtensor<[2],f32>,
// CHECK-SAME: %[[DYNAMIC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> { // CHECK-SAME: %[[DYNAMIC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[DYNAMIC_COPY:.*]] : !torch.vtensor<[?],f32> to !torch.vtensor<*,f32> // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[DYNAMIC_COPY:.*]] : !torch.vtensor<[?],f32> to !torch.vtensor<*,f32>
// CHECK: %[[CAST2:.*]] = torch.tensor_static_info_cast %[[CAST:.*]] : !torch.vtensor<*,f32> to !torch.vtensor<*,f32> // CHECK: %[[CAST2:.*]] = torch.tensor_static_info_cast %[[CAST:.*]] : !torch.vtensor<*,f32> to !torch.vtensor<*,f32>
// CHECK: torch.overwrite.tensor.contents %[[CAST2]] overwrites %[[STATIC_COPY:.*]] : !torch.vtensor<*,f32>, !torch.tensor<*,f32> // CHECK: torch.overwrite.tensor.contents %[[CAST2]] overwrites %[[STATIC_COPY:.*]] : !torch.vtensor<*,f32>, !torch.tensor<*,f32>
func @torch.overwrite.tensor.contents$dynamic_overwrites_static(%static: !torch.vtensor<[2],f32>, %dynamic: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> { func.func @torch.overwrite.tensor.contents$dynamic_overwrites_static(%static: !torch.vtensor<[2],f32>, %dynamic: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
%static_no_type = torch.tensor_static_info_cast %static : !torch.vtensor<[2],f32> to !torch.vtensor %static_no_type = torch.tensor_static_info_cast %static : !torch.vtensor<[2],f32> to !torch.vtensor
%static_copy = torch.copy.to_tensor %static_no_type : !torch.tensor %static_copy = torch.copy.to_tensor %static_no_type : !torch.tensor
%dynamic_no_type = torch.tensor_static_info_cast %dynamic : !torch.vtensor<[?],f32> to !torch.vtensor %dynamic_no_type = torch.tensor_static_info_cast %dynamic : !torch.vtensor<[?],f32> to !torch.vtensor
@ -157,14 +157,14 @@ func @torch.overwrite.tensor.contents$dynamic_overwrites_static(%static: !torch.
} }
// ----- // -----
// CHECK-LABEL: func @torch.overwrite.tensor.contents$static_overwrites_dynamic( // CHECK-LABEL: func.func @torch.overwrite.tensor.contents$static_overwrites_dynamic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2],f32>, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
// CHECK: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[2],f32> to !torch.vtensor<*,f32> // CHECK: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[2],f32> to !torch.vtensor<*,f32>
// CHECK: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor<*,f32> // CHECK: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor<*,f32>
// CHECK: %[[MUTABLE_COPY:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor<*,f32> // CHECK: %[[MUTABLE_COPY:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor<*,f32>
// CHECK: torch.overwrite.tensor.contents %[[ARG0_ERASED]] overwrites %[[MUTABLE_COPY]] : !torch.vtensor<*,f32>, !torch.tensor<*,f32> // CHECK: torch.overwrite.tensor.contents %[[ARG0_ERASED]] overwrites %[[MUTABLE_COPY]] : !torch.vtensor<*,f32>, !torch.tensor<*,f32>
func @torch.overwrite.tensor.contents$static_overwrites_dynamic(%static: !torch.vtensor<[2],f32>, %dynamic: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { func.func @torch.overwrite.tensor.contents$static_overwrites_dynamic(%static: !torch.vtensor<[2],f32>, %dynamic: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
%static_no_type = torch.tensor_static_info_cast %static : !torch.vtensor<[2],f32> to !torch.vtensor %static_no_type = torch.tensor_static_info_cast %static : !torch.vtensor<[2],f32> to !torch.vtensor
%dynamic_no_type = torch.tensor_static_info_cast %dynamic : !torch.vtensor<[?],f32> to !torch.vtensor %dynamic_no_type = torch.tensor_static_info_cast %dynamic : !torch.vtensor<[?],f32> to !torch.vtensor
%dynamic_copy = torch.copy.to_tensor %dynamic_no_type : !torch.tensor %dynamic_copy = torch.copy.to_tensor %dynamic_no_type : !torch.tensor
@ -175,23 +175,23 @@ func @torch.overwrite.tensor.contents$static_overwrites_dynamic(%static: !torch.
} }
// ----- // -----
// CHECK-LABEL: func @bf16_result_type( // CHECK-LABEL: func.func @bf16_result_type(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,bf16>) -> !torch.vtensor<[2],bf16> { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,bf16>) -> !torch.vtensor<[2],bf16> {
// CHECK: %[[SQRT:.*]] = torch.aten.sqrt %[[ARG0]] : !torch.vtensor<*,bf16> -> !torch.vtensor<[2],bf16> // CHECK: %[[SQRT:.*]] = torch.aten.sqrt %[[ARG0]] : !torch.vtensor<*,bf16> -> !torch.vtensor<[2],bf16>
// CHECK: return %[[SQRT]] : !torch.vtensor<[2],bf16> // CHECK: return %[[SQRT]] : !torch.vtensor<[2],bf16>
func @bf16_result_type(%arg0: !torch.vtensor<*,bf16>) -> !torch.vtensor<[2],bf16> { func.func @bf16_result_type(%arg0: !torch.vtensor<*,bf16>) -> !torch.vtensor<[2],bf16> {
%1 = torch.aten.sqrt %arg0 : !torch.vtensor<*,bf16> -> !torch.vtensor<[2], bf16> %1 = torch.aten.sqrt %arg0 : !torch.vtensor<*,bf16> -> !torch.vtensor<[2], bf16>
return %1 : !torch.vtensor<[2],bf16> return %1 : !torch.vtensor<[2],bf16>
} }
// ----- // -----
// CHECK-LABEL: func @propagate_scalar_type( // CHECK-LABEL: func.func @propagate_scalar_type(
// CHECK-SAME: %[[INT:.*]]: !torch.int) -> !torch.number { // CHECK-SAME: %[[INT:.*]]: !torch.int) -> !torch.number {
// CHECK: %[[NUM:.*]] = torch.derefine %[[INT]] : !torch.int to !torch.number // CHECK: %[[NUM:.*]] = torch.derefine %[[INT]] : !torch.int to !torch.number
// CHECK: %[[ABS:.*]] = torch.prim.abs.Scalar %[[INT]] : !torch.int -> !torch.int // CHECK: %[[ABS:.*]] = torch.prim.abs.Scalar %[[INT]] : !torch.int -> !torch.int
// CHECK: %[[RET:.*]] = torch.derefine %[[ABS]] : !torch.int to !torch.number // CHECK: %[[RET:.*]] = torch.derefine %[[ABS]] : !torch.int to !torch.number
// CHECK: return %[[RET]] : !torch.number // CHECK: return %[[RET]] : !torch.number
func @propagate_scalar_type(%arg0: !torch.int) -> !torch.number { func.func @propagate_scalar_type(%arg0: !torch.int) -> !torch.number {
%num = torch.derefine %arg0 : !torch.int to !torch.number %num = torch.derefine %arg0 : !torch.int to !torch.number
%1 = torch.prim.abs.Scalar %num: !torch.number -> !torch.number %1 = torch.prim.abs.Scalar %num: !torch.number -> !torch.number
return %1 : !torch.number return %1 : !torch.number

View File

@ -1,20 +1,20 @@
// RUN: torch-mlir-opt -torch-reify-shape-calculations -split-input-file %s | FileCheck %s // RUN: torch-mlir-opt -torch-reify-shape-calculations -split-input-file %s | FileCheck %s
// CHECK: module { // CHECK: module {
// CHECK: func private @__torch_mlir_shape_fn.aten.tanh( // CHECK: func.func private @__torch_mlir_shape_fn.aten.tanh(
// CHECK-LABEL: func @basic( // CHECK-LABEL: func.func @basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: %[[RESULT:.*]] = torch.shape.calculate { // CHECK: %[[RESULT:.*]] = torch.shape.calculate {
// CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG]] : !torch.vtensor -> !torch.vtensor // CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG]] : !torch.vtensor -> !torch.vtensor
// CHECK: torch.shape.calculate.yield %[[TANH]] : !torch.vtensor // CHECK: torch.shape.calculate.yield %[[TANH]] : !torch.vtensor
// CHECK: } shapes { // CHECK: } shapes {
// CHECK: %[[SHAPE:.*]] = torch.aten.size %[[ARG]] : !torch.vtensor -> !torch.list<int> // CHECK: %[[SHAPE:.*]] = torch.aten.size %[[ARG]] : !torch.vtensor -> !torch.list<int>
// CHECK: %[[RESULT_SHAPE:.*]] = call @__torch_mlir_shape_fn.aten.tanh(%[[SHAPE]]) : (!torch.list<int>) -> !torch.list<int> // CHECK: %[[RESULT_SHAPE:.*]] = func.call @__torch_mlir_shape_fn.aten.tanh(%[[SHAPE]]) : (!torch.list<int>) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[RESULT_SHAPE]] : !torch.list<int> // CHECK: torch.shape.calculate.yield.shapes %[[RESULT_SHAPE]] : !torch.list<int>
// CHECK: } : !torch.vtensor // CHECK: } : !torch.vtensor
// CHECK: return %[[RESULT:.*]] : !torch.vtensor // CHECK: return %[[RESULT:.*]] : !torch.vtensor
func @basic(%arg0: !torch.vtensor) -> !torch.vtensor { func.func @basic(%arg0: !torch.vtensor) -> !torch.vtensor {
%0 = torch.aten.tanh %arg0 : !torch.vtensor -> !torch.vtensor %0 = torch.aten.tanh %arg0 : !torch.vtensor -> !torch.vtensor
return %0 : !torch.vtensor return %0 : !torch.vtensor
} }
@ -22,9 +22,9 @@ func @basic(%arg0: !torch.vtensor) -> !torch.vtensor {
// ----- // -----
// CHECK: module { // CHECK: module {
// CHECK: func private @__torch_mlir_shape_fn.aten.fill.Scalar( // CHECK: func.func private @__torch_mlir_shape_fn.aten.fill.Scalar(
// CHECK-LABEL: func @valsem_ops( // CHECK-LABEL: func.func @valsem_ops(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.int) -> !torch.vtensor { // CHECK-SAME: %[[ARG1:.*]]: !torch.int) -> !torch.vtensor {
// CHECK: %[[RESULT:.*]] = torch.shape.calculate { // CHECK: %[[RESULT:.*]] = torch.shape.calculate {
@ -32,11 +32,11 @@ func @basic(%arg0: !torch.vtensor) -> !torch.vtensor {
// CHECK: torch.shape.calculate.yield %[[VALUE]] : !torch.vtensor // CHECK: torch.shape.calculate.yield %[[VALUE]] : !torch.vtensor
// CHECK: } shapes { // CHECK: } shapes {
// CHECK: %[[SHAPE:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<int> // CHECK: %[[SHAPE:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<int>
// CHECK: %[[RESULT_SHAPE:.*]] = call @__torch_mlir_shape_fn.aten.fill.Scalar(%[[SHAPE]], %{{.*}}) : (!torch.list<int>, !torch.float) -> !torch.list<int> // CHECK: %[[RESULT_SHAPE:.*]] = func.call @__torch_mlir_shape_fn.aten.fill.Scalar(%[[SHAPE]], %{{.*}}) : (!torch.list<int>, !torch.float) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[RESULT_SHAPE]] : !torch.list<int> // CHECK: torch.shape.calculate.yield.shapes %[[RESULT_SHAPE]] : !torch.list<int>
// CHECK: } : !torch.vtensor // CHECK: } : !torch.vtensor
// CHECK: return %[[RESULT:.*]] : !torch.vtensor // CHECK: return %[[RESULT:.*]] : !torch.vtensor
func @valsem_ops(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor { func.func @valsem_ops(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
%0 = torch.valsem.aten.fill.Scalar %arg0, %arg1 : !torch.vtensor, !torch.int -> !torch.vtensor %0 = torch.valsem.aten.fill.Scalar %arg0, %arg1 : !torch.vtensor, !torch.int -> !torch.vtensor
return %0 : !torch.vtensor return %0 : !torch.vtensor
} }
@ -44,10 +44,10 @@ func @valsem_ops(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
// ----- // -----
// CHECK: module { // CHECK: module {
// CHECK-LABEL: func private @__torch_mlir_shape_fn.aten.uniform( // CHECK-LABEL: func.func private @__torch_mlir_shape_fn.aten.uniform(
// CHECK-SAME: {{.*}}!torch.any) // CHECK-SAME: {{.*}}!torch.any)
// CHECK-LABEL: func @adjust_shape_function_arg$torch.any( // CHECK-LABEL: func.func @adjust_shape_function_arg$torch.any(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.float) -> !torch.vtensor { // CHECK-SAME: %[[ARG1:.*]]: !torch.float) -> !torch.vtensor {
// CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[NONE:.*]] = torch.constant.none
@ -57,11 +57,11 @@ func @valsem_ops(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
// CHECK: } shapes { // CHECK: } shapes {
// CHECK: %[[ARG0_SHAPE:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<int> // CHECK: %[[ARG0_SHAPE:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<int>
// CHECK: %[[ANY:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.any // CHECK: %[[ANY:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.any
// CHECK: %[[SHAPE:.*]] = call @__torch_mlir_shape_fn.aten.uniform(%[[ARG0_SHAPE]], %[[ARG1]], %[[ARG1]], %[[ANY]]) : (!torch.list<int>, !torch.float, !torch.float, !torch.any) -> !torch.list<int> // CHECK: %[[SHAPE:.*]] = func.call @__torch_mlir_shape_fn.aten.uniform(%[[ARG0_SHAPE]], %[[ARG1]], %[[ARG1]], %[[ANY]]) : (!torch.list<int>, !torch.float, !torch.float, !torch.any) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int> // CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
// CHECK: } : !torch.vtensor // CHECK: } : !torch.vtensor
// CHECK: return %[[RESULT:.*]] : !torch.vtensor // CHECK: return %[[RESULT:.*]] : !torch.vtensor
func @adjust_shape_function_arg$torch.any(%arg0: !torch.vtensor, %arg1: !torch.float) -> !torch.vtensor { func.func @adjust_shape_function_arg$torch.any(%arg0: !torch.vtensor, %arg1: !torch.float) -> !torch.vtensor {
%none = torch.constant.none %none = torch.constant.none
%0 = torch.valsem.aten.uniform %arg0, %arg1, %arg1, %none : !torch.vtensor, !torch.float, !torch.float, !torch.none -> !torch.vtensor %0 = torch.valsem.aten.uniform %arg0, %arg1, %arg1, %none : !torch.vtensor, !torch.float, !torch.float, !torch.none -> !torch.vtensor
return %0 : !torch.vtensor return %0 : !torch.vtensor
@ -74,9 +74,9 @@ func @adjust_shape_function_arg$torch.any(%arg0: !torch.vtensor, %arg1: !torch.f
// callees of the shape functions. // callees of the shape functions.
// CHECK: module { // CHECK: module {
// CHECK: func private @__torch_mlir_shape_fn.aten.add.Tensor( // CHECK: func.func private @__torch_mlir_shape_fn.aten.add.Tensor(
// CHECK-LABEL: func @adjust_shape_function_arg$scalar( // CHECK-LABEL: func.func @adjust_shape_function_arg$scalar(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> !torch.vtensor { // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[INT1:.*]] = torch.constant.int 1
@ -87,11 +87,11 @@ func @adjust_shape_function_arg$torch.any(%arg0: !torch.vtensor, %arg1: !torch.f
// CHECK: %[[ARG0_SHAPE:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<int> // CHECK: %[[ARG0_SHAPE:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<int>
// CHECK: %[[ARG1_SHAPE:.*]] = torch.aten.size %[[ARG1]] : !torch.vtensor -> !torch.list<int> // CHECK: %[[ARG1_SHAPE:.*]] = torch.aten.size %[[ARG1]] : !torch.vtensor -> !torch.list<int>
// CHECK: %[[SCALAR_CONVERTED:.*]] = torch.aten.Float.Scalar %[[INT1]] : !torch.int -> !torch.float // CHECK: %[[SCALAR_CONVERTED:.*]] = torch.aten.Float.Scalar %[[INT1]] : !torch.int -> !torch.float
// CHECK: %[[RESULT_SHAPE:.*]] = call @__torch_mlir_shape_fn.aten.add.Tensor(%[[ARG0_SHAPE]], %[[ARG1_SHAPE]], %[[SCALAR_CONVERTED]]) : (!torch.list<int>, !torch.list<int>, !torch.float) -> !torch.list<int> // CHECK: %[[RESULT_SHAPE:.*]] = func.call @__torch_mlir_shape_fn.aten.add.Tensor(%[[ARG0_SHAPE]], %[[ARG1_SHAPE]], %[[SCALAR_CONVERTED]]) : (!torch.list<int>, !torch.list<int>, !torch.float) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[RESULT_SHAPE]] : !torch.list<int> // CHECK: torch.shape.calculate.yield.shapes %[[RESULT_SHAPE]] : !torch.list<int>
// CHECK: } : !torch.vtensor // CHECK: } : !torch.vtensor
// CHECK: return %[[RESULT:.*]] : !torch.vtensor // CHECK: return %[[RESULT:.*]] : !torch.vtensor
func @adjust_shape_function_arg$scalar(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor { func.func @adjust_shape_function_arg$scalar(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
%int1 = torch.constant.int 1 %int1 = torch.constant.int 1
%0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor, !torch.vtensor, !torch.int -> !torch.vtensor %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor, !torch.vtensor, !torch.int -> !torch.vtensor
return %0 : !torch.vtensor return %0 : !torch.vtensor
@ -100,9 +100,9 @@ func @adjust_shape_function_arg$scalar(%arg0: !torch.vtensor, %arg1: !torch.vten
// ----- // -----
// CHECK: module { // CHECK: module {
// CHECK: func private @__torch_mlir_shape_fn.aten.topk( // CHECK: func.func private @__torch_mlir_shape_fn.aten.topk(
// CHECK-LABEL: func @multiple_results( // CHECK-LABEL: func.func @multiple_results(
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> (!torch.tensor, !torch.tensor) { // CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> (!torch.tensor, !torch.tensor) {
// CHECK: %[[TRUE:.*]] = torch.constant.bool true // CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[INT3:.*]] = torch.constant.int 3 // CHECK: %[[INT3:.*]] = torch.constant.int 3
@ -112,13 +112,13 @@ func @adjust_shape_function_arg$scalar(%arg0: !torch.vtensor, %arg1: !torch.vten
// CHECK: torch.shape.calculate.yield %[[TOP_VALUES]], %[[TOPK_INDICES]] : !torch.tensor, !torch.tensor // CHECK: torch.shape.calculate.yield %[[TOP_VALUES]], %[[TOPK_INDICES]] : !torch.tensor, !torch.tensor
// CHECK: } shapes { // CHECK: } shapes {
// CHECK: %[[ARG_SHAPE:.*]] = torch.aten.size %[[ARG]] : !torch.tensor -> !torch.list<int> // CHECK: %[[ARG_SHAPE:.*]] = torch.aten.size %[[ARG]] : !torch.tensor -> !torch.list<int>
// CHECK: %[[TOPK_SHAPE_TUPLE:.*]] = call @__torch_mlir_shape_fn.aten.topk(%[[ARG_SHAPE]], %[[INT3]], %[[INT1]], %[[TRUE]], %[[TRUE]]) : (!torch.list<int>, !torch.int, !torch.int, !torch.bool, !torch.bool) -> !torch.tuple<list<int>, list<int>> // CHECK: %[[TOPK_SHAPE_TUPLE:.*]] = func.call @__torch_mlir_shape_fn.aten.topk(%[[ARG_SHAPE]], %[[INT3]], %[[INT1]], %[[TRUE]], %[[TRUE]]) : (!torch.list<int>, !torch.int, !torch.int, !torch.bool, !torch.bool) -> !torch.tuple<list<int>, list<int>>
// CHECK: %[[TOPK_SHAPE:.*]]:2 = torch.prim.TupleUnpack %[[TOPK_SHAPE_TUPLE]] : !torch.tuple<list<int>, list<int>> -> !torch.list<int>, !torch.list<int> // CHECK: %[[TOPK_SHAPE:.*]]:2 = torch.prim.TupleUnpack %[[TOPK_SHAPE_TUPLE]] : !torch.tuple<list<int>, list<int>> -> !torch.list<int>, !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[TOPK_SHAPE]]#0, %[[TOPK_SHAPE]]#1 : !torch.list<int>, !torch.list<int> // CHECK: torch.shape.calculate.yield.shapes %[[TOPK_SHAPE]]#0, %[[TOPK_SHAPE]]#1 : !torch.list<int>, !torch.list<int>
// CHECK: } : !torch.tensor, !torch.tensor // CHECK: } : !torch.tensor, !torch.tensor
// CHECK: return %[[RESULTS:.*]]#0, %[[RESULTS]]#1 : !torch.tensor, !torch.tensor // CHECK: return %[[RESULTS:.*]]#0, %[[RESULTS]]#1 : !torch.tensor, !torch.tensor
func @multiple_results(%arg0: !torch.tensor) -> (!torch.tensor, !torch.tensor) { func.func @multiple_results(%arg0: !torch.tensor) -> (!torch.tensor, !torch.tensor) {
%true = torch.constant.bool true %true = torch.constant.bool true
%int3 = torch.constant.int 3 %int3 = torch.constant.int 3
%int1 = torch.constant.int 1 %int1 = torch.constant.int 1
@ -128,7 +128,7 @@ func @multiple_results(%arg0: !torch.tensor) -> (!torch.tensor, !torch.tensor) {
// ----- // -----
// CHECK-LABEL: func @adjust_shape_function_arg$optional( // CHECK-LABEL: func.func @adjust_shape_function_arg$optional(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> !torch.vtensor { // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: %[[RESULT:.*]] = torch.shape.calculate { // CHECK: %[[RESULT:.*]] = torch.shape.calculate {
@ -138,11 +138,11 @@ func @multiple_results(%arg0: !torch.tensor) -> (!torch.tensor, !torch.tensor) {
// CHECK: %[[SHAPE0:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<int> // CHECK: %[[SHAPE0:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<int>
// CHECK: %[[SHAPE1:.*]] = torch.aten.size %[[ARG1]] : !torch.vtensor -> !torch.list<int> // CHECK: %[[SHAPE1:.*]] = torch.aten.size %[[ARG1]] : !torch.vtensor -> !torch.list<int>
// CHECK: %[[DEREFINED:.*]] = torch.derefine %{{.*}} : !torch.none to !torch.optional<list<int>> // CHECK: %[[DEREFINED:.*]] = torch.derefine %{{.*}} : !torch.none to !torch.optional<list<int>>
// CHECK: %[[SHAPE:.*]] = call @__torch_mlir_shape_fn.aten.conv2d(%[[SHAPE0]], %[[SHAPE1]], %[[DEREFINED]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.list<int> // CHECK: %[[SHAPE:.*]] = func.call @__torch_mlir_shape_fn.aten.conv2d(%[[SHAPE0]], %[[SHAPE1]], %[[DEREFINED]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int> // CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
// CHECK: } : !torch.vtensor // CHECK: } : !torch.vtensor
// CHECK: return %[[RESULT:.*]] : !torch.vtensor // CHECK: return %[[RESULT:.*]] : !torch.vtensor
func @adjust_shape_function_arg$optional(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor { func.func @adjust_shape_function_arg$optional(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
%int3 = torch.constant.int 3 %int3 = torch.constant.int 3
%int4 = torch.constant.int 4 %int4 = torch.constant.int 4
%int2 = torch.constant.int 2 %int2 = torch.constant.int 2
@ -157,7 +157,7 @@ func @adjust_shape_function_arg$optional(%arg0: !torch.vtensor, %arg1: !torch.vt
// ----- // -----
// CHECK-LABEL: func @adjust_shape_function_arg$optional_tensor( // CHECK-LABEL: func.func @adjust_shape_function_arg$optional_tensor(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: %[[FALSE:.*]] = torch.constant.bool false // CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[TRUE:.*]] = torch.constant.bool true // CHECK: %[[TRUE:.*]] = torch.constant.bool true
@ -184,11 +184,11 @@ func @adjust_shape_function_arg$optional(%arg0: !torch.vtensor, %arg1: !torch.vt
// CHECK: %[[DEREFINED_NONE1:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<list<int>> // CHECK: %[[DEREFINED_NONE1:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<list<int>>
// CHECK: %[[DEREFINED_NONE2:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<list<int>> // CHECK: %[[DEREFINED_NONE2:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<list<int>>
// CHECK: %[[DEREFINED_NONE3:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<list<int>> // CHECK: %[[DEREFINED_NONE3:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<list<int>>
// CHECK: %[[BN_SHAPE:.*]] = call @__torch_mlir_shape_fn.aten.batch_norm(%[[ARG_SIZE]], %[[DEREFINED_OPTIONAL_SIZE:.*]], %[[DEREFINED_NONE1]], %[[DEREFINED_NONE2]], %[[DEREFINED_NONE3]], %[[FALSE]], %[[C1EM1]], %[[C1EM5]], %[[TRUE]]) : (!torch.list<int>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.bool, !torch.float, !torch.float, !torch.bool) -> !torch.list<int> // CHECK: %[[BN_SHAPE:.*]] = func.call @__torch_mlir_shape_fn.aten.batch_norm(%[[ARG_SIZE]], %[[DEREFINED_OPTIONAL_SIZE:.*]], %[[DEREFINED_NONE1]], %[[DEREFINED_NONE2]], %[[DEREFINED_NONE3]], %[[FALSE]], %[[C1EM1]], %[[C1EM5]], %[[TRUE]]) : (!torch.list<int>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.bool, !torch.float, !torch.float, !torch.bool) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[BN_SHAPE]] : !torch.list<int> // CHECK: torch.shape.calculate.yield.shapes %[[BN_SHAPE]] : !torch.list<int>
// CHECK: } : !torch.vtensor // CHECK: } : !torch.vtensor
// CHECK: return %[[RESULT:.*]] : !torch.vtensor // CHECK: return %[[RESULT:.*]] : !torch.vtensor
func @adjust_shape_function_arg$optional_tensor(%arg0: !torch.vtensor) -> !torch.vtensor { func.func @adjust_shape_function_arg$optional_tensor(%arg0: !torch.vtensor) -> !torch.vtensor {
%false = torch.constant.bool false %false = torch.constant.bool false
%true = torch.constant.bool true %true = torch.constant.bool true
%float1.000000e-05 = torch.constant.float 1.000000e-05 %float1.000000e-05 = torch.constant.float 1.000000e-05
@ -201,7 +201,7 @@ func @adjust_shape_function_arg$optional_tensor(%arg0: !torch.vtensor) -> !torch
// ----- // -----
// CHECK-LABEL: func @adjust_shape_function_arg$list( // CHECK-LABEL: func.func @adjust_shape_function_arg$list(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> !torch.vtensor { // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[ARG1]] : (!torch.vtensor) -> !torch.list<vtensor> // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[ARG1]] : (!torch.vtensor) -> !torch.list<vtensor>
@ -221,11 +221,11 @@ func @adjust_shape_function_arg$optional_tensor(%arg0: !torch.vtensor) -> !torch
// CHECK: %{{.*}} = torch.aten.append.t %[[ADJUSTED_LIST]], %[[ADJUSTED_ELEMENT]] : !torch.list<optional<list<int>>>, !torch.optional<list<int>> -> !torch.list<optional<list<int>>> // CHECK: %{{.*}} = torch.aten.append.t %[[ADJUSTED_LIST]], %[[ADJUSTED_ELEMENT]] : !torch.list<optional<list<int>>>, !torch.optional<list<int>> -> !torch.list<optional<list<int>>>
// CHECK: torch.prim.Loop.condition %[[CTRUE]], iter() // CHECK: torch.prim.Loop.condition %[[CTRUE]], iter()
// CHECK: } : (!torch.int, !torch.bool) -> () // CHECK: } : (!torch.int, !torch.bool) -> ()
// CHECK: %[[RESULT_SHAPE:.*]] = call @__torch_mlir_shape_fn.aten.index.Tensor(%[[ARG0_SHAPE]], %[[ADJUSTED_LIST]]) : (!torch.list<int>, !torch.list<optional<list<int>>>) -> !torch.list<int> // CHECK: %[[RESULT_SHAPE:.*]] = func.call @__torch_mlir_shape_fn.aten.index.Tensor(%[[ARG0_SHAPE]], %[[ADJUSTED_LIST]]) : (!torch.list<int>, !torch.list<optional<list<int>>>) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[RESULT_SHAPE]] : !torch.list<int> // CHECK: torch.shape.calculate.yield.shapes %[[RESULT_SHAPE]] : !torch.list<int>
// CHECK: } : !torch.vtensor // CHECK: } : !torch.vtensor
// CHECK: return %[[VAL_15:.*]] : !torch.vtensor // CHECK: return %[[VAL_15:.*]] : !torch.vtensor
func @adjust_shape_function_arg$list(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor { func.func @adjust_shape_function_arg$list(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
%0 = torch.prim.ListConstruct %arg1 : (!torch.vtensor) -> !torch.list<vtensor> %0 = torch.prim.ListConstruct %arg1 : (!torch.vtensor) -> !torch.list<vtensor>
%1 = torch.aten.index.Tensor %arg0, %0 : !torch.vtensor, !torch.list<vtensor> -> !torch.vtensor %1 = torch.aten.index.Tensor %arg0, %0 : !torch.vtensor, !torch.list<vtensor> -> !torch.vtensor
return %1 : !torch.vtensor return %1 : !torch.vtensor

View File

@ -1,7 +1,7 @@
// RUN: torch-mlir-opt -torch-simplify-shape-calculations -split-input-file %s | FileCheck %s // RUN: torch-mlir-opt -torch-simplify-shape-calculations -split-input-file %s | FileCheck %s
// CHECK-LABEL: func @refine_shape_calculate_result$basic( // CHECK-LABEL: func.func @refine_shape_calculate_result$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.int) -> !torch.vtensor { // CHECK-SAME: %[[ARG1:.*]]: !torch.int) -> !torch.vtensor {
// CHECK: %[[INT2:.*]] = torch.constant.int 2 // CHECK: %[[INT2:.*]] = torch.constant.int 2
@ -14,7 +14,7 @@
// CHECK: } : !torch.vtensor<[2,?],unk> // CHECK: } : !torch.vtensor<[2,?],unk>
// CHECK: %[[RESULT_ERASED:.*]] = torch.tensor_static_info_cast %[[RESULT:.*]] : !torch.vtensor<[2,?],unk> to !torch.vtensor // CHECK: %[[RESULT_ERASED:.*]] = torch.tensor_static_info_cast %[[RESULT:.*]] : !torch.vtensor<[2,?],unk> to !torch.vtensor
// CHECK: return %[[RESULT_ERASED]] : !torch.vtensor // CHECK: return %[[RESULT_ERASED]] : !torch.vtensor
func @refine_shape_calculate_result$basic(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor { func.func @refine_shape_calculate_result$basic(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
%int2 = torch.constant.int 2 %int2 = torch.constant.int 2
%0 = torch.shape.calculate { %0 = torch.shape.calculate {
torch.shape.calculate.yield %arg0 : !torch.vtensor torch.shape.calculate.yield %arg0 : !torch.vtensor
@ -25,10 +25,10 @@ func @refine_shape_calculate_result$basic(%arg0: !torch.vtensor, %arg1: !torch.i
return %0 : !torch.vtensor return %0 : !torch.vtensor
} }
// CHECK-LABEL: func @refine_shape_calculate_result$clobber_one_element( // CHECK-LABEL: func.func @refine_shape_calculate_result$clobber_one_element(
// CHECK: %[[RESULT_ERASED:.*]] = torch.tensor_static_info_cast %{{.*}} : !torch.vtensor<[?,2],unk> to !torch.vtensor // CHECK: %[[RESULT_ERASED:.*]] = torch.tensor_static_info_cast %{{.*}} : !torch.vtensor<[?,2],unk> to !torch.vtensor
// CHECK: return %[[RESULT_ERASED]] : !torch.vtensor // CHECK: return %[[RESULT_ERASED]] : !torch.vtensor
func @refine_shape_calculate_result$clobber_one_element(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.vtensor { func.func @refine_shape_calculate_result$clobber_one_element(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.vtensor {
%int0 = torch.constant.int 0 %int0 = torch.constant.int 0
%int2 = torch.constant.int 2 %int2 = torch.constant.int 2
%0 = torch.shape.calculate { %0 = torch.shape.calculate {
@ -47,10 +47,10 @@ func @refine_shape_calculate_result$clobber_one_element(%arg0: !torch.vtensor, %
return %0 : !torch.vtensor return %0 : !torch.vtensor
} }
// CHECK-LABEL: func @refine_shape_calculate_result$clobber_all_elements( // CHECK-LABEL: func.func @refine_shape_calculate_result$clobber_all_elements(
// CHECK: %[[RESULT_ERASED:.*]] = torch.tensor_static_info_cast %{{.*}} : !torch.vtensor<[?,?],unk> to !torch.vtensor // CHECK: %[[RESULT_ERASED:.*]] = torch.tensor_static_info_cast %{{.*}} : !torch.vtensor<[?,?],unk> to !torch.vtensor
// CHECK: return %[[RESULT_ERASED]] : !torch.vtensor // CHECK: return %[[RESULT_ERASED]] : !torch.vtensor
func @refine_shape_calculate_result$clobber_all_elements(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.vtensor { func.func @refine_shape_calculate_result$clobber_all_elements(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.vtensor {
%int0 = torch.constant.int 0 %int0 = torch.constant.int 0
%int2 = torch.constant.int 2 %int2 = torch.constant.int 2
%0 = torch.shape.calculate { %0 = torch.shape.calculate {
@ -71,10 +71,10 @@ func @refine_shape_calculate_result$clobber_all_elements(%arg0: !torch.vtensor,
} }
// Make sure that information previously in the IR is not lost. // Make sure that information previously in the IR is not lost.
// CHECK-LABEL: func @refine_shape_calculate_result$meet_with_existing_information( // CHECK-LABEL: func.func @refine_shape_calculate_result$meet_with_existing_information(
// CHECK: %[[RESULT_ERASED:.*]] = torch.tensor_static_info_cast %{{.*}} : !torch.vtensor<[2,3],f32> to !torch.vtensor<[?,3],f32> // CHECK: %[[RESULT_ERASED:.*]] = torch.tensor_static_info_cast %{{.*}} : !torch.vtensor<[2,3],f32> to !torch.vtensor<[?,3],f32>
// CHECK: return %[[RESULT_ERASED]] : !torch.vtensor<[?,3],f32> // CHECK: return %[[RESULT_ERASED]] : !torch.vtensor<[?,3],f32>
func @refine_shape_calculate_result$meet_with_existing_information(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.int) -> !torch.vtensor<[?,3],f32> { func.func @refine_shape_calculate_result$meet_with_existing_information(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.int) -> !torch.vtensor<[?,3],f32> {
%int0 = torch.constant.int 0 %int0 = torch.constant.int 0
%int2 = torch.constant.int 2 %int2 = torch.constant.int 2
%0 = torch.shape.calculate { %0 = torch.shape.calculate {
@ -87,9 +87,9 @@ func @refine_shape_calculate_result$meet_with_existing_information(%arg0: !torch
} }
// Don't insert static info casts if not needed. // Don't insert static info casts if not needed.
// CHECK-LABEL: func @refine_shape_calculate_result$user_allows_type_refinement( // CHECK-LABEL: func.func @refine_shape_calculate_result$user_allows_type_refinement(
// CHECK-NOT: torch.tensor_static_info_cast // CHECK-NOT: torch.tensor_static_info_cast
func @refine_shape_calculate_result$user_allows_type_refinement(%arg0: !torch.vtensor) -> !torch.vtensor { func.func @refine_shape_calculate_result$user_allows_type_refinement(%arg0: !torch.vtensor) -> !torch.vtensor {
%int2 = torch.constant.int 2 %int2 = torch.constant.int 2
%0 = torch.aten.tanh %arg0 : !torch.vtensor -> !torch.vtensor %0 = torch.aten.tanh %arg0 : !torch.vtensor -> !torch.vtensor
%1 = torch.shape.calculate { %1 = torch.shape.calculate {
@ -102,7 +102,7 @@ func @refine_shape_calculate_result$user_allows_type_refinement(%arg0: !torch.vt
return %2 : !torch.vtensor return %2 : !torch.vtensor
} }
// CHECK-LABEL: func @fully_unroll_prim_loop$unroll( // CHECK-LABEL: func.func @fully_unroll_prim_loop$unroll(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.list<int>) -> !torch.vtensor { // CHECK-SAME: %[[ARG1:.*]]: !torch.list<int>) -> !torch.vtensor {
// CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[INT1:.*]] = torch.constant.int 1
@ -117,7 +117,7 @@ func @refine_shape_calculate_result$user_allows_type_refinement(%arg0: !torch.vt
// CHECK: torch.shape.calculate.yield.shapes %[[ARG1]] : !torch.list<int> // CHECK: torch.shape.calculate.yield.shapes %[[ARG1]] : !torch.list<int>
// CHECK: } : !torch.vtensor // CHECK: } : !torch.vtensor
// CHECK: return %[[RESULT:.*]] : !torch.vtensor // CHECK: return %[[RESULT:.*]] : !torch.vtensor
func @fully_unroll_prim_loop$unroll(%arg0: !torch.vtensor, %arg1: !torch.list<int>) -> !torch.vtensor { func.func @fully_unroll_prim_loop$unroll(%arg0: !torch.vtensor, %arg1: !torch.list<int>) -> !torch.vtensor {
%true = torch.constant.bool true %true = torch.constant.bool true
%int0 = torch.constant.int 0 %int0 = torch.constant.int 0
%int3 = torch.constant.int 3 %int3 = torch.constant.int 3
@ -134,9 +134,9 @@ func @fully_unroll_prim_loop$unroll(%arg0: !torch.vtensor, %arg1: !torch.list<in
return %0 : !torch.vtensor return %0 : !torch.vtensor
} }
// CHECK-LABEL: func @fully_unroll_prim_loop$no_unroll( // CHECK-LABEL: func.func @fully_unroll_prim_loop$no_unroll(
// CHECK: torch.prim.Loop // CHECK: torch.prim.Loop
func @fully_unroll_prim_loop$no_unroll(%arg0: !torch.vtensor, %arg1: !torch.list<int>, %arg2: !torch.int) -> !torch.vtensor { func.func @fully_unroll_prim_loop$no_unroll(%arg0: !torch.vtensor, %arg1: !torch.list<int>, %arg2: !torch.int) -> !torch.vtensor {
%true = torch.constant.bool true %true = torch.constant.bool true
%int3 = torch.constant.int 3 %int3 = torch.constant.int 3
%0 = torch.shape.calculate { %0 = torch.shape.calculate {
@ -152,13 +152,13 @@ func @fully_unroll_prim_loop$no_unroll(%arg0: !torch.vtensor, %arg1: !torch.list
return %0 : !torch.vtensor return %0 : !torch.vtensor
} }
// CHECK-LABEL: func @abstractly_interpret_list_ops$basic( // CHECK-LABEL: func.func @abstractly_interpret_list_ops$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.int, // CHECK-SAME: %[[ARG1:.*]]: !torch.int,
// CHECK-SAME: %[[ARG2:.*]]: !torch.int) -> !torch.vtensor { // CHECK-SAME: %[[ARG2:.*]]: !torch.int) -> !torch.vtensor {
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[ARG1]], %[[ARG2]] : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[ARG1]], %[[ARG2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int> // CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
func @abstractly_interpret_list_ops$basic(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.int) -> !torch.vtensor { func.func @abstractly_interpret_list_ops$basic(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.int) -> !torch.vtensor {
%0 = torch.shape.calculate { %0 = torch.shape.calculate {
torch.shape.calculate.yield %arg0 : !torch.vtensor torch.shape.calculate.yield %arg0 : !torch.vtensor
} shapes { } shapes {
@ -171,10 +171,10 @@ func @abstractly_interpret_list_ops$basic(%arg0: !torch.vtensor, %arg1: !torch.i
} }
// Test the different supported mutation ops. // Test the different supported mutation ops.
// CHECK-LABEL: func @abstractly_interpret_list_ops$mutation_ops( // CHECK-LABEL: func.func @abstractly_interpret_list_ops$mutation_ops(
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %int1, %arg1, %arg2, %arg3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %int1, %arg1, %arg2, %arg3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int> // CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
func @abstractly_interpret_list_ops$mutation_ops(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.vtensor { func.func @abstractly_interpret_list_ops$mutation_ops(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.vtensor {
%int0 = torch.constant.int 0 %int0 = torch.constant.int 0
%int1 = torch.constant.int 1 %int1 = torch.constant.int 1
%int2 = torch.constant.int 2 %int2 = torch.constant.int 2
@ -192,10 +192,10 @@ func @abstractly_interpret_list_ops$mutation_ops(%arg0: !torch.vtensor, %arg1: !
} }
// Test negative indexes with set_item op. // Test negative indexes with set_item op.
// CHECK-LABEL: func @abstractly_interpret_list_ops$neg_index_set_item( // CHECK-LABEL: func.func @abstractly_interpret_list_ops$neg_index_set_item(
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %arg1, %arg2 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %arg1, %arg2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int> // CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
func @abstractly_interpret_list_ops$neg_index_set_item(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.vtensor { func.func @abstractly_interpret_list_ops$neg_index_set_item(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.vtensor {
%int1 = torch.constant.int 1 %int1 = torch.constant.int 1
%int-1 = torch.constant.int -1 %int-1 = torch.constant.int -1
%int-2 = torch.constant.int -2 %int-2 = torch.constant.int -2
@ -211,10 +211,10 @@ func @abstractly_interpret_list_ops$neg_index_set_item(%arg0: !torch.vtensor, %a
} }
// Test interspersed mutation and evaluation ops. // Test interspersed mutation and evaluation ops.
// CHECK-LABEL: func @abstractly_interpret_list_ops$mix_mutation_and_evaluation_ops( // CHECK-LABEL: func.func @abstractly_interpret_list_ops$mix_mutation_and_evaluation_ops(
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %int0, %int1, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %int0, %int1, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int> // CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
func @abstractly_interpret_list_ops$mix_mutation_and_evaluation_ops(%arg0: !torch.vtensor) -> !torch.vtensor { func.func @abstractly_interpret_list_ops$mix_mutation_and_evaluation_ops(%arg0: !torch.vtensor) -> !torch.vtensor {
%0 = torch.shape.calculate { %0 = torch.shape.calculate {
torch.shape.calculate.yield %arg0 : !torch.vtensor torch.shape.calculate.yield %arg0 : !torch.vtensor
} shapes { } shapes {
@ -230,10 +230,10 @@ func @abstractly_interpret_list_ops$mix_mutation_and_evaluation_ops(%arg0: !torc
return %0 : !torch.vtensor return %0 : !torch.vtensor
} }
// CHECK-LABEL: func @abstractly_interpret_list_ops$use_of_alias$not_yet_handled( // CHECK-LABEL: func.func @abstractly_interpret_list_ops$use_of_alias$not_yet_handled(
// CHECK: torch.aten.append.t // CHECK: torch.aten.append.t
// CHECK: torch.aten.append.t // CHECK: torch.aten.append.t
func @abstractly_interpret_list_ops$use_of_alias$not_yet_handled(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.int) -> !torch.vtensor { func.func @abstractly_interpret_list_ops$use_of_alias$not_yet_handled(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.int) -> !torch.vtensor {
%0 = torch.shape.calculate { %0 = torch.shape.calculate {
torch.shape.calculate.yield %arg0 : !torch.vtensor torch.shape.calculate.yield %arg0 : !torch.vtensor
} shapes { } shapes {
@ -247,13 +247,13 @@ func @abstractly_interpret_list_ops$use_of_alias$not_yet_handled(%arg0: !torch.v
return %0 : !torch.vtensor return %0 : !torch.vtensor
} }
// CHECK-LABEL: func @abstractly_interpret_list_ops$readonly_op_in_child_region( // CHECK-LABEL: func.func @abstractly_interpret_list_ops$readonly_op_in_child_region(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor, // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.int) -> !torch.vtensor { // CHECK-SAME: %[[VAL_1:.*]]: !torch.int) -> !torch.vtensor {
// CHECK: %[[INT3:.*]] = torch.constant.int 3 // CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[INT3]] : (!torch.int) -> !torch.list<int> // CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[INT3]] : (!torch.int) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int> // CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
func @abstractly_interpret_list_ops$readonly_op_in_child_region(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor { func.func @abstractly_interpret_list_ops$readonly_op_in_child_region(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
%true = torch.constant.bool true %true = torch.constant.bool true
%int3 = torch.constant.int 3 %int3 = torch.constant.int 3
%int0 = torch.constant.int 0 %int0 = torch.constant.int 0
@ -276,9 +276,9 @@ func @abstractly_interpret_list_ops$readonly_op_in_child_region(%arg0: !torch.vt
} }
// The mutation in the child region prevents us from abstractly interpreting. // The mutation in the child region prevents us from abstractly interpreting.
// CHECK-LABEL: func @abstractly_interpret_list_ops$mutation_in_child_region( // CHECK-LABEL: func.func @abstractly_interpret_list_ops$mutation_in_child_region(
// CHECK: torch.aten.append.t // CHECK: torch.aten.append.t
func @abstractly_interpret_list_ops$mutation_in_child_region(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor { func.func @abstractly_interpret_list_ops$mutation_in_child_region(%arg0: !torch.vtensor, %arg1: !torch.int) -> !torch.vtensor {
%true = torch.constant.bool true %true = torch.constant.bool true
%int3 = torch.constant.int 3 %int3 = torch.constant.int 3
%int0 = torch.constant.int 0 %int0 = torch.constant.int 0
@ -300,7 +300,7 @@ func @abstractly_interpret_list_ops$mutation_in_child_region(%arg0: !torch.vtens
return %0 : !torch.vtensor return %0 : !torch.vtensor
} }
// CHECK-LABEL: func @abstractly_interpret_list_ops$miscompile$list_identity( // CHECK-LABEL: func.func @abstractly_interpret_list_ops$miscompile$list_identity(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.list<int>, // CHECK-SAME: %[[ARG1:.*]]: !torch.list<int>,
// CHECK-SAME: %[[ARG2:.*]]: !torch.bool) -> !torch.vtensor { // CHECK-SAME: %[[ARG2:.*]]: !torch.bool) -> !torch.vtensor {
@ -329,7 +329,7 @@ func @abstractly_interpret_list_ops$mutation_in_child_region(%arg0: !torch.vtens
// CHECK: } : !torch.vtensor<[3,3],unk> // CHECK: } : !torch.vtensor<[3,3],unk>
// CHECK: %[[VAL_13:.*]] = torch.tensor_static_info_cast %[[VAL_14:.*]] : !torch.vtensor<[3,3],unk> to !torch.vtensor // CHECK: %[[VAL_13:.*]] = torch.tensor_static_info_cast %[[VAL_14:.*]] : !torch.vtensor<[3,3],unk> to !torch.vtensor
// CHECK: return %[[VAL_13]] : !torch.vtensor // CHECK: return %[[VAL_13]] : !torch.vtensor
func @abstractly_interpret_list_ops$miscompile$list_identity(%arg0: !torch.vtensor, %arg1: !torch.list<int>, %arg2: !torch.bool) -> !torch.vtensor { func.func @abstractly_interpret_list_ops$miscompile$list_identity(%arg0: !torch.vtensor, %arg1: !torch.list<int>, %arg2: !torch.bool) -> !torch.vtensor {
%true = torch.constant.bool true %true = torch.constant.bool true
%int3 = torch.constant.int 3 %int3 = torch.constant.int 3
%int0 = torch.constant.int 0 %int0 = torch.constant.int 0
@ -373,7 +373,7 @@ func @abstractly_interpret_list_ops$miscompile$list_identity(%arg0: !torch.vtens
// This test should usually not be the one to catch an issue. // This test should usually not be the one to catch an issue.
// If it does catch an issue then it indicates a more precise unit test that is // If it does catch an issue then it indicates a more precise unit test that is
// missing. // missing.
// CHECK-LABEL: func @basic_integration( // CHECK-LABEL: func.func @basic_integration(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],unk>) -> !torch.vtensor { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],unk>) -> !torch.vtensor {
// CHECK: %[[INT0:.*]] = torch.constant.int 0 // CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[INT1:.*]] = torch.constant.int 1
@ -388,7 +388,7 @@ func @abstractly_interpret_list_ops$miscompile$list_identity(%arg0: !torch.vtens
// CHECK: } : !torch.vtensor<[?,?],unk> // CHECK: } : !torch.vtensor<[?,?],unk>
// CHECK: %[[RESULT_ERASED:.*]] = torch.tensor_static_info_cast %[[RESULT:.*]] : !torch.vtensor<[?,?],unk> to !torch.vtensor // CHECK: %[[RESULT_ERASED:.*]] = torch.tensor_static_info_cast %[[RESULT:.*]] : !torch.vtensor<[?,?],unk> to !torch.vtensor
// CHECK: return %[[RESULT_ERASED]] : !torch.vtensor // CHECK: return %[[RESULT_ERASED]] : !torch.vtensor
func @basic_integration(%arg0: !torch.vtensor<[?,?],unk>) -> !torch.vtensor { func.func @basic_integration(%arg0: !torch.vtensor<[?,?],unk>) -> !torch.vtensor {
%true = torch.constant.bool true %true = torch.constant.bool true
%0 = torch.shape.calculate { %0 = torch.shape.calculate {
%1 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?],unk> -> !torch.vtensor %1 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?],unk> -> !torch.vtensor

View File

@ -3,10 +3,10 @@
// This test is largely copied from `finalizing-bufferize` upstream, as it // This test is largely copied from `finalizing-bufferize` upstream, as it
// covers the same scope. // covers the same scope.
// CHECK-LABEL: func @eliminate_materializations( // CHECK-LABEL: func.func @eliminate_materializations(
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>) -> tensor<f32> { // CHECK-SAME: %[[ARG:.*]]: tensor<f32>) -> tensor<f32> {
// CHECK: return %[[ARG]] : tensor<f32> // CHECK: return %[[ARG]] : tensor<f32>
func @eliminate_materializations(%arg0: tensor<f32>) -> tensor<f32> { func.func @eliminate_materializations(%arg0: tensor<f32>) -> tensor<f32> {
%0 = torch_c.from_builtin_tensor %arg0 : tensor<f32> -> !torch.vtensor<[],f32> %0 = torch_c.from_builtin_tensor %arg0 : tensor<f32> -> !torch.vtensor<[],f32>
%1 = torch_c.to_builtin_tensor %0 : !torch.vtensor<[],f32> -> tensor<f32> %1 = torch_c.to_builtin_tensor %0 : !torch.vtensor<[],f32> -> tensor<f32>
return %1 : tensor<f32> return %1 : tensor<f32>
@ -15,38 +15,38 @@ func @eliminate_materializations(%arg0: tensor<f32>) -> tensor<f32> {
// Do a basic check of other types. Under the hood they all take the same // Do a basic check of other types. Under the hood they all take the same
// code paths as for !torch.vtensor, so we just spot-check them here. // code paths as for !torch.vtensor, so we just spot-check them here.
// CHECK-LABEL: func @eliminate_materializations$torch.bool( // CHECK-LABEL: func.func @eliminate_materializations$torch.bool(
// CHECK-SAME: %[[ARG:.*]]: i1) -> i1 { // CHECK-SAME: %[[ARG:.*]]: i1) -> i1 {
// CHECK: return %[[ARG]] : i1 // CHECK: return %[[ARG]] : i1
func @eliminate_materializations$torch.bool(%arg0: i1) -> i1 { func.func @eliminate_materializations$torch.bool(%arg0: i1) -> i1 {
%0 = torch_c.from_i1 %arg0 %0 = torch_c.from_i1 %arg0
%1 = torch_c.to_i1 %0 %1 = torch_c.to_i1 %0
return %1 : i1 return %1 : i1
} }
// CHECK-LABEL: func @eliminate_materializations$torch.int( // CHECK-LABEL: func.func @eliminate_materializations$torch.int(
// CHECK-SAME: %[[ARG:.*]]: i64) -> i64 { // CHECK-SAME: %[[ARG:.*]]: i64) -> i64 {
// CHECK: return %[[ARG]] : i64 // CHECK: return %[[ARG]] : i64
func @eliminate_materializations$torch.int(%arg0: i64) -> i64 { func.func @eliminate_materializations$torch.int(%arg0: i64) -> i64 {
%0 = torch_c.from_i64 %arg0 %0 = torch_c.from_i64 %arg0
%1 = torch_c.to_i64 %0 %1 = torch_c.to_i64 %0
return %1 : i64 return %1 : i64
} }
// CHECK-LABEL: func @eliminate_materializations$torch.float( // CHECK-LABEL: func.func @eliminate_materializations$torch.float(
// CHECK-SAME: %[[ARG:.*]]: f64) -> f64 { // CHECK-SAME: %[[ARG:.*]]: f64) -> f64 {
// CHECK: return %[[ARG]] : f64 // CHECK: return %[[ARG]] : f64
func @eliminate_materializations$torch.float(%arg0: f64) -> f64 { func.func @eliminate_materializations$torch.float(%arg0: f64) -> f64 {
%0 = torch_c.from_f64 %arg0 %0 = torch_c.from_f64 %arg0
%1 = torch_c.to_f64 %0 %1 = torch_c.to_f64 %0
return %1 : f64 return %1 : f64
} }
// CHECK-LABEL: func @eliminate_materializations$torch.Generator( // CHECK-LABEL: func.func @eliminate_materializations$torch.Generator(
// CHECK-SAME: %[[VAL_0:.*]]: i64) -> i64 { // CHECK-SAME: %[[VAL_0:.*]]: i64) -> i64 {
// CHECK: return %[[VAL_0]] : i64 // CHECK: return %[[VAL_0]] : i64
// CHECK: } // CHECK: }
func @eliminate_materializations$torch.Generator(%arg0: i64) -> i64 { func.func @eliminate_materializations$torch.Generator(%arg0: i64) -> i64 {
%0 = torch_c.i64_to_generator %arg0 %0 = torch_c.i64_to_generator %arg0
%1 = torch_c.generator_to_i64 %0 %1 = torch_c.generator_to_i64 %0
return %1 : i64 return %1 : i64
@ -54,7 +54,7 @@ func @eliminate_materializations$torch.Generator(%arg0: i64) -> i64 {
// ----- // -----
func @unable_to_convert_lone_buffer_cast() -> tensor<f32> { func.func @unable_to_convert_lone_buffer_cast() -> tensor<f32> {
// expected-error @+1 {{failed to legalize operation 'test.source'}} // expected-error @+1 {{failed to legalize operation 'test.source'}}
%0 = "test.source"() : () -> !torch.vtensor<[],f32> %0 = "test.source"() : () -> !torch.vtensor<[],f32>
%1 = torch_c.to_builtin_tensor %0 : !torch.vtensor<[],f32> -> tensor<f32> %1 = torch_c.to_builtin_tensor %0 : !torch.vtensor<[],f32> -> tensor<f32>
@ -63,7 +63,7 @@ func @unable_to_convert_lone_buffer_cast() -> tensor<f32> {
// ----- // -----
func @unable_to_convert_lone_tensor_load(%arg0: tensor<f32>) { func.func @unable_to_convert_lone_tensor_load(%arg0: tensor<f32>) {
%0 = torch_c.from_builtin_tensor %arg0 : tensor<f32> -> !torch.vtensor<[],f32> %0 = torch_c.from_builtin_tensor %arg0 : tensor<f32> -> !torch.vtensor<[],f32>
// expected-error @+1 {{failed to legalize operation 'test.sink'}} // expected-error @+1 {{failed to legalize operation 'test.sink'}}
"test.sink"(%0) : (!torch.vtensor<[],f32>) -> () "test.sink"(%0) : (!torch.vtensor<[],f32>) -> ()

View File

@ -3,48 +3,48 @@
// This test is largely copied from `func-bufferize` upstream, as it covers // This test is largely copied from `func-bufferize` upstream, as it covers
// the same scope. // the same scope.
// CHECK-LABEL: func @identity( // CHECK-LABEL: func.func @identity(
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>) -> tensor<f32> { // CHECK-SAME: %[[ARG:.*]]: tensor<f32>) -> tensor<f32> {
// CHECK: return %[[ARG]] : tensor<f32> // CHECK: return %[[ARG]] : tensor<f32>
func @identity(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> { func.func @identity(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
return %arg0 : !torch.vtensor<[],f32> return %arg0 : !torch.vtensor<[],f32>
} }
// CHECK-LABEL: func @block_arguments( // CHECK-LABEL: func.func @block_arguments(
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>) -> tensor<f32> { // CHECK-SAME: %[[ARG:.*]]: tensor<f32>) -> tensor<f32> {
// CHECK: cf.br ^bb1(%[[ARG]] : tensor<f32>) // CHECK: cf.br ^bb1(%[[ARG]] : tensor<f32>)
// CHECK: ^bb1(%[[BBARG:.*]]: tensor<f32>): // CHECK: ^bb1(%[[BBARG:.*]]: tensor<f32>):
// CHECK: return %[[BBARG]] : tensor<f32> // CHECK: return %[[BBARG]] : tensor<f32>
func @block_arguments(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> { func.func @block_arguments(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
cf.br ^bb1(%arg0: !torch.vtensor<[],f32>) cf.br ^bb1(%arg0: !torch.vtensor<[],f32>)
^bb1(%bbarg: !torch.vtensor<[],f32>): ^bb1(%bbarg: !torch.vtensor<[],f32>):
return %bbarg : !torch.vtensor<[],f32> return %bbarg : !torch.vtensor<[],f32>
} }
// CHECK-LABEL: func private @source() -> tensor<f32> // CHECK-LABEL: func.func private @source() -> tensor<f32>
// CHECK-LABEL: func @call_source() -> tensor<f32> { // CHECK-LABEL: func.func @call_source() -> tensor<f32> {
// CHECK: %[[RET:.*]] = call @source() : () -> tensor<f32> // CHECK: %[[RET:.*]] = call @source() : () -> tensor<f32>
// CHECK: return %[[RET]] : tensor<f32> // CHECK: return %[[RET]] : tensor<f32>
func private @source() -> !torch.vtensor<[],f32> func.func private @source() -> !torch.vtensor<[],f32>
func @call_source() -> !torch.vtensor<[],f32> { func.func @call_source() -> !torch.vtensor<[],f32> {
%0 = call @source() : () -> !torch.vtensor<[],f32> %0 = call @source() : () -> !torch.vtensor<[],f32>
return %0 : !torch.vtensor<[],f32> return %0 : !torch.vtensor<[],f32>
} }
// CHECK-LABEL: func @call_sink( // CHECK-LABEL: func.func @call_sink(
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>) { // CHECK-SAME: %[[ARG:.*]]: tensor<f32>) {
// CHECK: call @sink(%[[ARG]]) : (tensor<f32>) -> () // CHECK: call @sink(%[[ARG]]) : (tensor<f32>) -> ()
// CHECK: return // CHECK: return
func private @sink(!torch.vtensor<[],f32>) func.func private @sink(!torch.vtensor<[],f32>)
func @call_sink(%arg0: !torch.vtensor<[],f32>) { func.func @call_sink(%arg0: !torch.vtensor<[],f32>) {
call @sink(%arg0) : (!torch.vtensor<[],f32>) -> () call @sink(%arg0) : (!torch.vtensor<[],f32>) -> ()
return return
} }
// CHECK-LABEL: func @unconverted_op_in_body() -> tensor<f32> { // CHECK-LABEL: func.func @unconverted_op_in_body() -> tensor<f32> {
// CHECK: %[[TENSOR:.*]] = "test.source"() : () -> !torch.vtensor<[],f32> // CHECK: %[[TENSOR:.*]] = "test.source"() : () -> !torch.vtensor<[],f32>
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor<f32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: return %[[BUILTIN_TENSOR]] : tensor<f32> // CHECK: return %[[BUILTIN_TENSOR]] : tensor<f32>
func @unconverted_op_in_body() -> !torch.vtensor<[],f32> { func.func @unconverted_op_in_body() -> !torch.vtensor<[],f32> {
%0 = "test.source"() : () -> !torch.vtensor<[],f32> %0 = "test.source"() : () -> !torch.vtensor<[],f32>
return %0 : !torch.vtensor<[],f32> return %0 : !torch.vtensor<[],f32>
} }
@ -53,7 +53,7 @@ func @unconverted_op_in_body() -> !torch.vtensor<[],f32> {
// Because this pass updates block arguments, it needs to also atomically // Because this pass updates block arguments, it needs to also atomically
// update all terminators and issue an error if that is not possible. // update all terminators and issue an error if that is not possible.
func @unable_to_update_terminator(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> { func.func @unable_to_update_terminator(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
%0 = arith.constant true %0 = arith.constant true
cf.cond_br %0, ^bb1(%arg0: !torch.vtensor<[],f32>), ^bb2(%arg0: !torch.vtensor<[],f32>) cf.cond_br %0, ^bb1(%arg0: !torch.vtensor<[],f32>), ^bb2(%arg0: !torch.vtensor<[],f32>)
^bb1(%bbarg0: !torch.vtensor<[],f32>): ^bb1(%bbarg0: !torch.vtensor<[],f32>):
@ -72,7 +72,7 @@ func @unable_to_update_terminator(%arg0: !torch.vtensor<[],f32>) -> !torch.vtens
// CHECK: while // CHECK: while
// CHECK: scf.while // CHECK: scf.while
// CHECK: scf.condition // CHECK: scf.condition
func @bwhile(%arg0: i64, %arg1: i64) -> i64 { func.func @bwhile(%arg0: i64, %arg1: i64) -> i64 {
%c2_i64 = arith.constant 2 : i64 %c2_i64 = arith.constant 2 : i64
%0:2 = scf.while (%arg2 = %arg0) : (i64) -> (i64, i64) { %0:2 = scf.while (%arg2 = %arg0) : (i64) -> (i64, i64) {
%1 = arith.cmpi slt, %arg2, %arg1 : i64 %1 = arith.cmpi slt, %arg2, %arg1 : i64
@ -88,30 +88,30 @@ func @bwhile(%arg0: i64, %arg1: i64) -> i64 {
// Do a basic check of other types. Under the hood they all take the same // Do a basic check of other types. Under the hood they all take the same
// code paths as for !torch.vtensor, so we just spot-check them here. // code paths as for !torch.vtensor, so we just spot-check them here.
// CHECK-LABEL: func @identity$torch.bool( // CHECK-LABEL: func.func @identity$torch.bool(
// CHECK-SAME: %[[ARG:.*]]: i1) -> i1 { // CHECK-SAME: %[[ARG:.*]]: i1) -> i1 {
// CHECK: return %[[ARG]] : i1 // CHECK: return %[[ARG]] : i1
func @identity$torch.bool(%arg0: !torch.bool) -> !torch.bool { func.func @identity$torch.bool(%arg0: !torch.bool) -> !torch.bool {
return %arg0 : !torch.bool return %arg0 : !torch.bool
} }
// CHECK-LABEL: func @identity$torch.int( // CHECK-LABEL: func.func @identity$torch.int(
// CHECK-SAME: %[[ARG:.*]]: i64) -> i64 { // CHECK-SAME: %[[ARG:.*]]: i64) -> i64 {
// CHECK: return %[[ARG]] : i64 // CHECK: return %[[ARG]] : i64
func @identity$torch.int(%arg0: !torch.int) -> !torch.int { func.func @identity$torch.int(%arg0: !torch.int) -> !torch.int {
return %arg0 : !torch.int return %arg0 : !torch.int
} }
// CHECK-LABEL: func @identity$torch.float( // CHECK-LABEL: func.func @identity$torch.float(
// CHECK-SAME: %[[ARG:.*]]: f64) -> f64 { // CHECK-SAME: %[[ARG:.*]]: f64) -> f64 {
// CHECK: return %[[ARG]] : f64 // CHECK: return %[[ARG]] : f64
func @identity$torch.float(%arg0: !torch.float) -> !torch.float { func.func @identity$torch.float(%arg0: !torch.float) -> !torch.float {
return %arg0 : !torch.float return %arg0 : !torch.float
} }
// CHECK-LABEL: func @identity$torch.Generator( // CHECK-LABEL: func.func @identity$torch.Generator(
// CHECK-SAME: %[[VAL_0:.*]]: i64) -> i64 { // CHECK-SAME: %[[VAL_0:.*]]: i64) -> i64 {
// CHECK: return %[[VAL_0]] : i64 // CHECK: return %[[VAL_0]] : i64
func @identity$torch.Generator(%arg0: !torch.Generator) -> !torch.Generator { func.func @identity$torch.Generator(%arg0: !torch.Generator) -> !torch.Generator {
return %arg0 : !torch.Generator return %arg0 : !torch.Generator
} }

View File

@ -1,7 +1,7 @@
// RUN: torch-mlir-opt %s | torch-mlir-opt | FileCheck %s // RUN: torch-mlir-opt %s | torch-mlir-opt | FileCheck %s
// CHECK-LABEL: func @builtin_tensor_interop( // CHECK-LABEL: func.func @builtin_tensor_interop(
func @builtin_tensor_interop(%arg0: tensor<*xf32>, %arg1: tensor<3x?xi8>, %arg2: !torch.vtensor<*,f32>, %arg3: !torch.vtensor<[3,?],si8>) { func.func @builtin_tensor_interop(%arg0: tensor<*xf32>, %arg1: tensor<3x?xi8>, %arg2: !torch.vtensor<*,f32>, %arg3: !torch.vtensor<[3,?],si8>) {
// CHECK: torch_c.from_builtin_tensor %arg0 : tensor<*xf32> -> !torch.vtensor<*,f32> // CHECK: torch_c.from_builtin_tensor %arg0 : tensor<*xf32> -> !torch.vtensor<*,f32>
%0 = torch_c.from_builtin_tensor %arg0 : tensor<*xf32> -> !torch.vtensor<*,f32> %0 = torch_c.from_builtin_tensor %arg0 : tensor<*xf32> -> !torch.vtensor<*,f32>
// CHECK: torch_c.from_builtin_tensor %arg1 : tensor<3x?xi8> -> !torch.vtensor<[3,?],si8> // CHECK: torch_c.from_builtin_tensor %arg1 : tensor<3x?xi8> -> !torch.vtensor<[3,?],si8>

View File

@ -2,7 +2,7 @@
// ----- // -----
func @unknown_rank(%arg0: !torch.vtensor<[],f32>) { func.func @unknown_rank(%arg0: !torch.vtensor<[],f32>) {
// expected-error@+2 {{unsupported by backend lowering: tensor with unknown rank or dtype}} // expected-error@+2 {{unsupported by backend lowering: tensor with unknown rank or dtype}}
// expected-note@+1 {{this is likely due to a missing shape transfer function in shape_lib_gen.py}} // expected-note@+1 {{this is likely due to a missing shape transfer function in shape_lib_gen.py}}
%0 = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<*,f32> %0 = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<*,f32>
@ -11,7 +11,7 @@ func @unknown_rank(%arg0: !torch.vtensor<[],f32>) {
// ----- // -----
func @unknown_dtype(%arg0: !torch.vtensor<[],f32>) { func.func @unknown_dtype(%arg0: !torch.vtensor<[],f32>) {
// expected-error@+2 {{unsupported by backend lowering: tensor with unknown rank or dtype}} // expected-error@+2 {{unsupported by backend lowering: tensor with unknown rank or dtype}}
// expected-note@+1 {{this is likely due to a missing shape transfer function in shape_lib_gen.py}} // expected-note@+1 {{this is likely due to a missing shape transfer function in shape_lib_gen.py}}
%0 = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[],unk> %0 = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[],unk>
@ -20,7 +20,7 @@ func @unknown_dtype(%arg0: !torch.vtensor<[],f32>) {
// ----- // -----
func @unresolved_operator(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.int) { func.func @unresolved_operator(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.int) {
// expected-error@+2 {{unsupported by backend lowering: `torch.operator` op}} // expected-error@+2 {{unsupported by backend lowering: `torch.operator` op}}
// expected-note@+1 {{this is likely due to a missing op that needs to be generated by torch_ods_gen.py}} // expected-note@+1 {{this is likely due to a missing op that needs to be generated by torch_ods_gen.py}}
torch.operator "aten.mul.Scalar"(%arg0, %arg1) : (!torch.vtensor<[],f32>, !torch.int) -> !torch.vtensor<[],f32> torch.operator "aten.mul.Scalar"(%arg0, %arg1) : (!torch.vtensor<[],f32>, !torch.int) -> !torch.vtensor<[],f32>

View File

@ -1,7 +1,7 @@
// RUN: torch-mlir-opt -torch-verify-linalg-on-tensors-backend-contract -split-input-file -verify-diagnostics -allow-unregistered-dialect %s | FileCheck %s // RUN: torch-mlir-opt -torch-verify-linalg-on-tensors-backend-contract -split-input-file -verify-diagnostics -allow-unregistered-dialect %s | FileCheck %s
// CHECK: func @mm // CHECK: func.func @mm
func @mm(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> { func.func @mm(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index %c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index %c1 = arith.constant 1 : index
%cst = arith.constant 0.000000e+00 : f32 %cst = arith.constant 0.000000e+00 : f32
@ -23,7 +23,7 @@ func @mm(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
// expected-error@+1 {{Module does not conform to the linalg-on-tensors backend contract.}} // expected-error@+1 {{Module does not conform to the linalg-on-tensors backend contract.}}
module { module {
func @disallowed() { func.func @disallowed() {
// expected-error@+1 {{failed to legalize operation 'unknown_dialect.unknown_op'}} // expected-error@+1 {{failed to legalize operation 'unknown_dialect.unknown_op'}}
"unknown_dialect.unknown_op"() : () -> () "unknown_dialect.unknown_op"() : () -> ()
return return
@ -46,7 +46,7 @@ module {
// expected-error@+1 {{Module does not conform to the linalg-on-tensors backend contract.}} // expected-error@+1 {{Module does not conform to the linalg-on-tensors backend contract.}}
module { module {
func @disallowed(%arg0: !torch.tensor) -> !torch.tensor { func.func @disallowed(%arg0: !torch.tensor) -> !torch.tensor {
// expected-error@+1 {{failed to legalize operation 'func.return'}} // expected-error@+1 {{failed to legalize operation 'func.return'}}
return %arg0 : !torch.tensor return %arg0 : !torch.tensor
} }

View File

@ -1,7 +1,7 @@
// RUN: torch-mlir-opt -torch-verify-tosa-backend-contract -split-input-file -verify-diagnostics -allow-unregistered-dialect %s | FileCheck %s // RUN: torch-mlir-opt -torch-verify-tosa-backend-contract -split-input-file -verify-diagnostics -allow-unregistered-dialect %s | FileCheck %s
// CHECK: func @tanh // CHECK: func.func @tanh
func @tanh(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> { func.func @tanh(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = "tosa.tanh"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32> %0 = "tosa.tanh"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32> return %0 : tensor<?x?xf32>
} }
@ -12,7 +12,7 @@ func @tanh(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
// expected-error@+1 {{Module does not conform to the TOSA backend contract.}} // expected-error@+1 {{Module does not conform to the TOSA backend contract.}}
module { module {
func @disallowed() { func.func @disallowed() {
// expected-error@+1 {{failed to legalize operation 'unknown_dialect.unknown_op'}} // expected-error@+1 {{failed to legalize operation 'unknown_dialect.unknown_op'}}
"unknown_dialect.unknown_op"() : () -> () "unknown_dialect.unknown_op"() : () -> ()
return return
@ -35,7 +35,7 @@ module {
// expected-error@+1 {{Module does not conform to the TOSA backend contract.}} // expected-error@+1 {{Module does not conform to the TOSA backend contract.}}
module { module {
func @disallowed(%arg0: !torch.tensor) -> !torch.tensor { func.func @disallowed(%arg0: !torch.tensor) -> !torch.tensor {
// expected-error@+1 {{failed to legalize operation 'func.return'}} // expected-error@+1 {{failed to legalize operation 'func.return'}}
return %arg0 : !torch.tensor return %arg0 : !torch.tensor
} }

View File

@ -1,7 +1,7 @@
// RUN: torch-mlir-opt %s -refback-insert-rng-globals -split-input-file | FileCheck %s // RUN: torch-mlir-opt %s -refback-insert-rng-globals -split-input-file | FileCheck %s
// CHECK-LABEL: memref.global "private" @global_seed : memref<i64> = dense<0> // CHECK-LABEL: memref.global "private" @global_seed : memref<i64> = dense<0>
// CHECK-LABEL: func @f() -> i64 { // CHECK-LABEL: func.func @f() -> i64 {
// CHECK: %[[MEMREF:.*]] = memref.get_global @global_seed : memref<i64> // CHECK: %[[MEMREF:.*]] = memref.get_global @global_seed : memref<i64>
// CHECK: %[[SEED:.*]] = memref.load %[[MEMREF]][] : memref<i64> // CHECK: %[[SEED:.*]] = memref.load %[[MEMREF]][] : memref<i64>
// CHECK: %[[MULTIPLIER:.*]] = arith.constant 6364136223846793005 : i64 // CHECK: %[[MULTIPLIER:.*]] = arith.constant 6364136223846793005 : i64
@ -13,7 +13,7 @@
// CHECK: memref.store %[[NEXT_SEED]], %[[MEMREF]][] : memref<i64> // CHECK: memref.store %[[NEXT_SEED]], %[[MEMREF]][] : memref<i64>
// CHECK: return %[[NEXT_SEED]] : i64 // CHECK: return %[[NEXT_SEED]] : i64
module { module {
func @f() -> i64 { func.func @f() -> i64 {
%seed = torch_c.get_next_seed : () -> i64 %seed = torch_c.get_next_seed : () -> i64
return %seed : i64 return %seed : i64
} }

View File

@ -1,43 +1,43 @@
// RUN: torch-mlir-opt %s -refback-munge-calling-conventions -split-input-file | FileCheck %s // RUN: torch-mlir-opt %s -refback-munge-calling-conventions -split-input-file | FileCheck %s
// CHECK-LABEL: func @f( // CHECK-LABEL: func.func @f(
// CHECK-SAME: %[[ARG0:.*]]: memref<*xf32>) attributes {llvm.emit_c_interface} { // CHECK-SAME: %[[ARG0:.*]]: memref<*xf32>) attributes {llvm.emit_c_interface} {
// CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xf32> to memref<?xf32> // CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xf32> to memref<?xf32>
// CHECK: %[[RESULT:.*]] = memref.cast %[[VAL]] : memref<?xf32> to memref<*xf32> // CHECK: %[[RESULT:.*]] = memref.cast %[[VAL]] : memref<?xf32> to memref<*xf32>
// CHECK: call @refbackend_consume_func_return_mrf32(%[[RESULT]]) : (memref<*xf32>) -> () // CHECK: call @refbackend_consume_func_return_mrf32(%[[RESULT]]) : (memref<*xf32>) -> ()
// CHECK: return // CHECK: return
func @f(%arg0: memref<?xf32>) -> memref<?xf32> { func.func @f(%arg0: memref<?xf32>) -> memref<?xf32> {
return %arg0 : memref<?xf32> return %arg0 : memref<?xf32>
} }
// ----- // -----
// CHECK-LABEL: func @i( // CHECK-LABEL: func.func @i(
// CHECK-SAME: %[[ARG0:.*]]: memref<*xi64>) attributes {llvm.emit_c_interface} { // CHECK-SAME: %[[ARG0:.*]]: memref<*xi64>) attributes {llvm.emit_c_interface} {
// CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xi64> to memref<?xi64> // CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xi64> to memref<?xi64>
// CHECK: %[[RESULT:.*]] = memref.cast %[[VAL]] : memref<?xi64> to memref<*xi64> // CHECK: %[[RESULT:.*]] = memref.cast %[[VAL]] : memref<?xi64> to memref<*xi64>
// CHECK: call @refbackend_consume_func_return_mri64(%[[RESULT]]) : (memref<*xi64>) -> () // CHECK: call @refbackend_consume_func_return_mri64(%[[RESULT]]) : (memref<*xi64>) -> ()
// CHECK: return // CHECK: return
func @i(%arg0: memref<?xi64>) -> memref<?xi64> { func.func @i(%arg0: memref<?xi64>) -> memref<?xi64> {
return %arg0 : memref<?xi64> return %arg0 : memref<?xi64>
} }
// ----- // -----
// CHECK-LABEL: func @elemental_type( // CHECK-LABEL: func.func @elemental_type(
// CHECK-SAME: %[[ARG0:.*]]: memref<*xi64>) attributes {llvm.emit_c_interface} { // CHECK-SAME: %[[ARG0:.*]]: memref<*xi64>) attributes {llvm.emit_c_interface} {
// CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xi64> to memref<i64> // CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xi64> to memref<i64>
// CHECK: %[[RESULT:.*]] = memref.load %[[VAL]][] : memref<i64> // CHECK: %[[RESULT:.*]] = memref.load %[[VAL]][] : memref<i64>
// CHECK: call @refbackend_consume_func_return_i64(%[[RESULT]]) : (i64) -> () // CHECK: call @refbackend_consume_func_return_i64(%[[RESULT]]) : (i64) -> ()
// CHECK: return // CHECK: return
func @elemental_type(%arg0: memref<i64>) -> i64 { func.func @elemental_type(%arg0: memref<i64>) -> i64 {
%0 = memref.load %arg0[] : memref<i64> %0 = memref.load %arg0[] : memref<i64>
return %0 : i64 return %0 : i64
} }
// ----- // -----
// CHECK-LABEL: func @multiple_return_values( // CHECK-LABEL: func.func @multiple_return_values(
// CHECK-SAME: %[[ARG0:.*]]: memref<*xf32>, %[[ARG1:.*]]: memref<*xf32>, // CHECK-SAME: %[[ARG0:.*]]: memref<*xf32>, %[[ARG1:.*]]: memref<*xf32>,
// CHECK-SAME: %[[ARG2:.*]]: memref<*xf32>) attributes {llvm.emit_c_interface} { // CHECK-SAME: %[[ARG2:.*]]: memref<*xf32>) attributes {llvm.emit_c_interface} {
// CHECK: %[[VAL0:.*]] = memref.cast %[[ARG0]] : memref<*xf32> to memref<?xf32> // CHECK: %[[VAL0:.*]] = memref.cast %[[ARG0]] : memref<*xf32> to memref<?xf32>
@ -50,13 +50,13 @@ func @elemental_type(%arg0: memref<i64>) -> i64 {
// CHECK-SAME: : (memref<*xf32>, memref<*xf32>, memref<*xf32>) -> () // CHECK-SAME: : (memref<*xf32>, memref<*xf32>, memref<*xf32>) -> ()
// CHECK: return // CHECK: return
func @multiple_return_values(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) -> (memref<?xf32>, memref<?xf32>, memref<?xf32>) { func.func @multiple_return_values(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) -> (memref<?xf32>, memref<?xf32>, memref<?xf32>) {
return %arg0 ,%arg1, %arg2 : memref<?xf32>, memref<?xf32>, memref<?xf32> return %arg0 ,%arg1, %arg2 : memref<?xf32>, memref<?xf32>, memref<?xf32>
} }
// ----- // -----
// CHECK-LABEL: func @two_return_values( // CHECK-LABEL: func.func @two_return_values(
// CHECK-SAME: %[[ARG0:.*]]: memref<*xf32>, %[[ARG1:.*]]: memref<*xi64>) // CHECK-SAME: %[[ARG0:.*]]: memref<*xf32>, %[[ARG1:.*]]: memref<*xi64>)
// CHECK-SAME: attributes {llvm.emit_c_interface} { // CHECK-SAME: attributes {llvm.emit_c_interface} {
// CHECK: %[[VAL0:.*]] = memref.cast %[[ARG0]] : memref<*xf32> to memref<?xf32> // CHECK: %[[VAL0:.*]] = memref.cast %[[ARG0]] : memref<*xf32> to memref<?xf32>
@ -67,6 +67,6 @@ func @multiple_return_values(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2:
// CHECK-SAME: : (memref<*xf32>, memref<*xi64>) -> () // CHECK-SAME: : (memref<*xf32>, memref<*xi64>) -> ()
// CHECK: return // CHECK: return
func @two_return_values(%arg0: memref<?xf32>, %arg1: memref<?xi64>) -> (memref<?xf32>, memref<?xi64>) { func.func @two_return_values(%arg0: memref<?xf32>, %arg1: memref<?xi64>) -> (memref<?xf32>, memref<?xi64>) {
return %arg0 ,%arg1 : memref<?xf32>, memref<?xi64> return %arg0 ,%arg1 : memref<?xf32>, memref<?xi64>
} }

View File

@ -21,7 +21,7 @@ recursivescriptmodule = torch.jit.script(test_module)
annotator = ClassAnnotator() annotator = ClassAnnotator()
class_type = recursivescriptmodule._c._type() class_type = recursivescriptmodule._c._type()
# CHECK: func private @__torch__.TestModule.forward( # CHECK: func.func private @__torch__.TestModule.forward(
# CHECK-SAME: %arg0: !torch.nn.Module<"__torch__.TestModule">, # CHECK-SAME: %arg0: !torch.nn.Module<"__torch__.TestModule">,
# CHECK-SAME: %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?,1024],si8>}, # CHECK-SAME: %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?,1024],si8>},
# CHECK-SAME: %arg2: !torch.tensor {torch.type_bound = !torch.vtensor<[],f32>} # CHECK-SAME: %arg2: !torch.tensor {torch.type_bound = !torch.vtensor<[],f32>}

View File

@ -13,18 +13,18 @@ mb = ModuleBuilder()
# Interesting test case, where a function calls a method. # Interesting test case, where a function calls a method.
# CHECK-LABEL: func private @__torch__.TestModule.forward # CHECK-LABEL: func.func private @__torch__.TestModule.forward
# CHECK-SAME: (%[[ARG0:.*]]: !torch.nn.Module<"__torch__.TestModule">, %[[ARG1:.*]]: !torch.tensor) -> !torch.none { # CHECK-SAME: (%[[ARG0:.*]]: !torch.nn.Module<"__torch__.TestModule">, %[[ARG1:.*]]: !torch.tensor) -> !torch.none {
# CHECK: %[[F:.*]] = constant @__torch__.calls_method : (!torch.nn.Module<"__torch__.TestModule">, !torch.tensor) -> !torch.none # CHECK: %[[F:.*]] = constant @__torch__.calls_method : (!torch.nn.Module<"__torch__.TestModule">, !torch.tensor) -> !torch.none
# CHECK: %[[RET:.*]] = call_indirect %[[F]](%[[ARG0]], %[[ARG1]]) : (!torch.nn.Module<"__torch__.TestModule">, !torch.tensor) -> !torch.none # CHECK: %[[RET:.*]] = call_indirect %[[F]](%[[ARG0]], %[[ARG1]]) : (!torch.nn.Module<"__torch__.TestModule">, !torch.tensor) -> !torch.none
# CHECK: return %[[RET]] : !torch.none # CHECK: return %[[RET]] : !torch.none
# CHECK: } # CHECK: }
# CHECK-LABEL: func private @__torch__.TestModule.method # CHECK-LABEL: func.func private @__torch__.TestModule.method
# CHECK-SAME: (%[[ARG0:.*]]: !torch.nn.Module<"__torch__.TestModule">, %[[ARG1:.*]]: !torch.tensor) -> !torch.none { # CHECK-SAME: (%[[ARG0:.*]]: !torch.nn.Module<"__torch__.TestModule">, %[[ARG1:.*]]: !torch.tensor) -> !torch.none {
# CHECK: %[[RET:.*]] = torch.constant.none # CHECK: %[[RET:.*]] = torch.constant.none
# CHECK: return %[[RET]] : !torch.none # CHECK: return %[[RET]] : !torch.none
# CHECK: } # CHECK: }
# CHECK-LABEL: func private @__torch__.calls_method # CHECK-LABEL: func.func private @__torch__.calls_method
# CHECK-SAME: (%[[ARG0:.*]]: !torch.nn.Module<"__torch__.TestModule">, %[[ARG1:.*]]: !torch.tensor) -> !torch.none { # CHECK-SAME: (%[[ARG0:.*]]: !torch.nn.Module<"__torch__.TestModule">, %[[ARG1:.*]]: !torch.tensor) -> !torch.none {
# CHECK: %[[RET:.*]] = torch.prim.CallMethod %[[ARG0]]["method"] (%[[ARG1]]) : !torch.nn.Module<"__torch__.TestModule">, (!torch.tensor) -> !torch.none # CHECK: %[[RET:.*]] = torch.prim.CallMethod %[[ARG0]]["method"] (%[[ARG1]]) : !torch.nn.Module<"__torch__.TestModule">, (!torch.tensor) -> !torch.none
# CHECK: return %[[RET]] : !torch.none # CHECK: return %[[RET]] : !torch.none

View File

@ -11,13 +11,13 @@ from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
mb = ModuleBuilder() mb = ModuleBuilder()
# CHECK-LABEL: func private @__torch__.TestModule.forward # CHECK-LABEL: func.func private @__torch__.TestModule.forward
# CHECK-SAME: (%[[ARG0:.*]]: !torch.nn.Module<"__torch__.TestModule">, %[[ARG1:.*]]: !torch.tensor) -> !torch.tensor { # CHECK-SAME: (%[[ARG0:.*]]: !torch.nn.Module<"__torch__.TestModule">, %[[ARG1:.*]]: !torch.tensor) -> !torch.tensor {
# CHECK: %[[VAL_2:.*]] = constant @__torch__.identity : (!torch.tensor) -> !torch.tensor # CHECK: %[[VAL_2:.*]] = constant @__torch__.identity : (!torch.tensor) -> !torch.tensor
# CHECK: %[[VAL_3:.*]] = call_indirect %[[VAL_2]](%[[ARG1]]) : (!torch.tensor) -> !torch.tensor # CHECK: %[[VAL_3:.*]] = call_indirect %[[VAL_2]](%[[ARG1]]) : (!torch.tensor) -> !torch.tensor
# CHECK: return %[[VAL_3]] : !torch.tensor # CHECK: return %[[VAL_3]] : !torch.tensor
# CHECK: } # CHECK: }
# CHECK-LABEL: func private @__torch__.identity # CHECK-LABEL: func.func private @__torch__.identity
# CHECK-SAME: (%[[ARG:.*]]: !torch.tensor) -> !torch.tensor { # CHECK-SAME: (%[[ARG:.*]]: !torch.tensor) -> !torch.tensor {
# CHECK: return %[[ARG]] : !torch.tensor # CHECK: return %[[ARG]] : !torch.tensor
# CHECK: } # CHECK: }

View File

@ -17,7 +17,7 @@ class TestModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
# CHECK-LABEL: func private @__torch__.TestModule.forward( # CHECK-LABEL: func.func private @__torch__.TestModule.forward(
# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module<"__torch__.TestModule">) -> !torch.optional<int> { # CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module<"__torch__.TestModule">) -> !torch.optional<int> {
# CHECK: %[[NONE:.*]] = torch.constant.none # CHECK: %[[NONE:.*]] = torch.constant.none
# CHECK: %[[DEREFINED:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<int> # CHECK: %[[DEREFINED:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<int>

View File

@ -21,7 +21,7 @@ mb = ModuleBuilder()
# Given how systematic this is, we don't treat the symbol names as opaque (i.e. # Given how systematic this is, we don't treat the symbol names as opaque (i.e.
# we don't need to capture their names when FileCheck testing). # we don't need to capture their names when FileCheck testing).
# CHECK-LABEL: func private @__torch__.TestModule.forward # CHECK-LABEL: func.func private @__torch__.TestModule.forward
# CHECK-SAME: (%[[SELF:.*]]: !torch.nn.Module<"__torch__.TestModule">, %[[X:.*]]: !torch.tensor) -> !torch.tensor { # CHECK-SAME: (%[[SELF:.*]]: !torch.nn.Module<"__torch__.TestModule">, %[[X:.*]]: !torch.tensor) -> !torch.tensor {
# CHECK: return %[[X]] : !torch.tensor # CHECK: return %[[X]] : !torch.tensor
# CHECK: } # CHECK: }

View File

@ -17,7 +17,7 @@ class TestModule(torch.nn.Module):
self.t1 = torch.ones(1) self.t1 = torch.ones(1)
self.t2 = torch.ones(1) self.t2 = torch.ones(1)
# CHECK-LABEL: func private @__torch__.TestModule.forward( # CHECK-LABEL: func.func private @__torch__.TestModule.forward(
# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module<"{{.*}}">) -> !torch.none { # CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module<"{{.*}}">) -> !torch.none {
def forward(self): def forward(self):
# CHECK: %[[T2:.*]] = torch.prim.GetAttr %[[SELF]]["t2"] # CHECK: %[[T2:.*]] = torch.prim.GetAttr %[[SELF]]["t2"]
@ -25,7 +25,7 @@ class TestModule(torch.nn.Module):
self.t1 = self.t2 self.t1 = self.t2
# CHECK: torch.prim.CallMethod %[[SELF]]["callee"] (%{{.*}}, %{{.*}}) # CHECK: torch.prim.CallMethod %[[SELF]]["callee"] (%{{.*}}, %{{.*}})
self.callee(self.t1, self.t2) self.callee(self.t1, self.t2)
# CHECK-LABEL: func private @__torch__.TestModule.callee( # CHECK-LABEL: func.func private @__torch__.TestModule.callee(
# CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module<"{{.*}}">, # CHECK-SAME: %[[SELF:.*]]: !torch.nn.Module<"{{.*}}">,
# CHECK-SAME: %[[X:.*]]: !torch.tensor, # CHECK-SAME: %[[X:.*]]: !torch.tensor,
# CHECK-SAME: %[[Y:.*]]: !torch.tensor # CHECK-SAME: %[[Y:.*]]: !torch.tensor

View File

@ -24,7 +24,7 @@ class TestModule(torch.nn.Module):
self.s1 = Submodule(1) self.s1 = Submodule(1)
self.s2 = Submodule(2) self.s2 = Submodule(2)
# CHECK-LABEL: func private @{{.*}}TestModule.forward # CHECK-LABEL: func.func private @{{.*}}TestModule.forward
def forward(self, b: bool): def forward(self, b: bool):
# Modules with the same class can be selected between. # Modules with the same class can be selected between.
# CHECK: %[[MOD:.*]] = torch.prim.If # CHECK: %[[MOD:.*]] = torch.prim.If

View File

@ -18,7 +18,7 @@ class BasicClass:
def __init__(self, x: int): def __init__(self, x: int):
self.x = x self.x = x
# CHECK-LABEL: func @__torch__.prim_CreateObject( # CHECK-LABEL: func.func @__torch__.prim_CreateObject(
# CHECK-SAME: %[[ARG0:.*]]: !torch.int) -> !torch.nn.Module<"__torch__.BasicClass"> { # CHECK-SAME: %[[ARG0:.*]]: !torch.int) -> !torch.nn.Module<"__torch__.BasicClass"> {
# CHECK: %[[OBJECT:.*]] = torch.prim.CreateObject !torch.nn.Module<"__torch__.BasicClass"> # CHECK: %[[OBJECT:.*]] = torch.prim.CreateObject !torch.nn.Module<"__torch__.BasicClass">
# CHECK: %[[NONE:.*]] = torch.prim.CallMethod %[[OBJECT]]["__init__"] (%[[ARG0]]) : !torch.nn.Module<"__torch__.BasicClass">, (!torch.int) -> !torch.none # CHECK: %[[NONE:.*]] = torch.prim.CallMethod %[[OBJECT]]["__init__"] (%[[ARG0]]) : !torch.nn.Module<"__torch__.BasicClass">, (!torch.int) -> !torch.none

View File

@ -9,7 +9,7 @@ from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
mb = ModuleBuilder() mb = ModuleBuilder()
# CHECK-LABEL: func @__torch__.add3 # CHECK-LABEL: func.func @__torch__.add3
# Note that line-level debug information for parts unannotated in the Torch # Note that line-level debug information for parts unannotated in the Torch
# graph are ascribed to the first op that carries source information. Presently # graph are ascribed to the first op that carries source information. Presently
# this includes naked constants, return and the function itself. This heuristic # this includes naked constants, return and the function itself. This heuristic

View File

@ -12,7 +12,7 @@ from typing import Tuple, Optional, List, NamedTuple, Dict
mb = ModuleBuilder() mb = ModuleBuilder()
# CHECK-LABEL: func @__torch__.dict_literal_empty() -> !torch.dict<str, tensor> { # CHECK-LABEL: func.func @__torch__.dict_literal_empty() -> !torch.dict<str, tensor> {
# CHECK: %[[DICT:.*]] = torch.prim.DictConstruct keys() values() -> !torch.dict<str, tensor> # CHECK: %[[DICT:.*]] = torch.prim.DictConstruct keys() values() -> !torch.dict<str, tensor>
# CHECK: return %[[DICT]] : !torch.dict<str, tensor> # CHECK: return %[[DICT]] : !torch.dict<str, tensor>
@mb.import_function @mb.import_function
@ -21,7 +21,7 @@ def dict_literal_empty() -> Dict[str, torch.Tensor]:
return {} return {}
# CHECK-LABEL: func @__torch__.dict_literal( # CHECK-LABEL: func.func @__torch__.dict_literal(
# CHECK-SAME: %[[K0:.*]]: !torch.str, %[[V0:.*]]: !torch.tensor, # CHECK-SAME: %[[K0:.*]]: !torch.str, %[[V0:.*]]: !torch.tensor,
# CHECK-SAME: %[[K1:.*]]: !torch.str, %[[V1:.*]]: !torch.tensor) # CHECK-SAME: %[[K1:.*]]: !torch.str, %[[V1:.*]]: !torch.tensor)
# CHECK-SAME: -> !torch.dict<str, optional<tensor>> { # CHECK-SAME: -> !torch.dict<str, optional<tensor>> {

View File

@ -10,7 +10,7 @@ from utils import create_script_function
mb = ModuleBuilder() mb = ModuleBuilder()
# CHECK-LABEL: func @__torch__.refined_block_arg( # CHECK-LABEL: func.func @__torch__.refined_block_arg(
# CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.tensor { # CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.tensor {
# CHECK: %[[REFINED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.tensor to !torch.tensor<[1,384],f32> # CHECK: %[[REFINED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.tensor to !torch.tensor<[1,384],f32>
# CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[REFINED]] : !torch.tensor<[1,384],f32> to !torch.tensor # CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[REFINED]] : !torch.tensor<[1,384],f32> to !torch.tensor

View File

@ -11,7 +11,7 @@ import typing
mb = ModuleBuilder() mb = ModuleBuilder()
# CHECK-LABEL: func @__torch__.optional_return( # CHECK-LABEL: func.func @__torch__.optional_return(
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.optional<int> { # CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.optional<int> {
# CHECK: %[[RET:.*]] = torch.derefine %[[ARG]] : !torch.int to !torch.optional<int> # CHECK: %[[RET:.*]] = torch.derefine %[[ARG]] : !torch.int to !torch.optional<int>
# CHECK: return %[[RET]] : !torch.optional<int> # CHECK: return %[[RET]] : !torch.optional<int>
@ -20,14 +20,14 @@ mb = ModuleBuilder()
def optional_return(i: int) -> typing.Optional[int]: def optional_return(i: int) -> typing.Optional[int]:
return i return i
# CHECK-LABEL: func @__torch__.optional_arg( # CHECK-LABEL: func.func @__torch__.optional_arg(
# CHECK-SAME: %[[ARG:.*]]: !torch.optional<int>) -> !torch.none { # CHECK-SAME: %[[ARG:.*]]: !torch.optional<int>) -> !torch.none {
@mb.import_function @mb.import_function
@torch.jit.script @torch.jit.script
def optional_arg(i: typing.Optional[int]) -> None: def optional_arg(i: typing.Optional[int]) -> None:
return return
# CHECK-LABEL: func @__torch__.calls_optional_arg( # CHECK-LABEL: func.func @__torch__.calls_optional_arg(
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.none { # CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.none {
# CHECK: %[[CALLEE:.*]] = constant @__torch__.optional_arg : (!torch.optional<int>) -> !torch.none # CHECK: %[[CALLEE:.*]] = constant @__torch__.optional_arg : (!torch.optional<int>) -> !torch.none
# CHECK: %[[DEREFINED:.*]] = torch.derefine %[[ARG]] : !torch.int to !torch.optional<int> # CHECK: %[[DEREFINED:.*]] = torch.derefine %[[ARG]] : !torch.int to !torch.optional<int>

View File

@ -32,7 +32,7 @@ def prim_If(b: bool, i: int):
else: else:
return i * i return i * i
# CHECK-LABEL: func @__torch__.prim_If_derefine( # CHECK-LABEL: func.func @__torch__.prim_If_derefine(
# CHECK-SAME: %[[B:.*]]: !torch.bool, # CHECK-SAME: %[[B:.*]]: !torch.bool,
# CHECK-SAME: %[[I:.*]]: !torch.int) -> !torch.optional<int> { # CHECK-SAME: %[[I:.*]]: !torch.int) -> !torch.optional<int> {
# CHECK: %[[NONE:.*]] = torch.constant.none # CHECK: %[[NONE:.*]] = torch.constant.none

View File

@ -9,7 +9,7 @@ from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
mb = ModuleBuilder() mb = ModuleBuilder()
# CHECK-LABEL: func @__torch__.f( # CHECK-LABEL: func.func @__torch__.f(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor, # CHECK-SAME: %[[T0:.*]]: !torch.tensor,
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> !torch.list<tensor> { # CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> !torch.list<tensor> {
# CHECK: %[[RET:.*]] = torch.prim.ListConstruct %[[T0]], %[[T1]] : (!torch.tensor, !torch.tensor) -> !torch.list<tensor> # CHECK: %[[RET:.*]] = torch.prim.ListConstruct %[[T0]], %[[T1]] : (!torch.tensor, !torch.tensor) -> !torch.list<tensor>

View File

@ -11,7 +11,7 @@ import typing
mb = ModuleBuilder() mb = ModuleBuilder()
# CHECK-LABEL: func @__torch__.prim_Loop_forlike( # CHECK-LABEL: func.func @__torch__.prim_Loop_forlike(
# CHECK-SAME: %[[MAX_ITERATIONS:.*]]: !torch.int) -> !torch.float { # CHECK-SAME: %[[MAX_ITERATIONS:.*]]: !torch.int) -> !torch.float {
# CHECK: %[[BOOL_TRUE:.*]] = torch.constant.bool true # CHECK: %[[BOOL_TRUE:.*]] = torch.constant.bool true
# CHECK: %[[F_INIT:.*]] = torch.constant.float 0.000000e+00 # CHECK: %[[F_INIT:.*]] = torch.constant.float 0.000000e+00
@ -29,7 +29,7 @@ def prim_Loop_forlike(n: int):
f += i f += i
return f return f
# CHECK-LABEL: func @__torch__.prim_Loop_whilelike( # CHECK-LABEL: func.func @__torch__.prim_Loop_whilelike(
# CHECK-SAME: %[[VAL_0:.*]]: !torch.int) -> !torch.float { # CHECK-SAME: %[[VAL_0:.*]]: !torch.int) -> !torch.float {
# CHECK: %[[F_INIT:.*]] = torch.constant.float 3.200000e+00 # CHECK: %[[F_INIT:.*]] = torch.constant.float 3.200000e+00
# CHECK: %[[MAX_ITERATIONS:.*]] = torch.constant.int 9223372036854775807 # CHECK: %[[MAX_ITERATIONS:.*]] = torch.constant.int 9223372036854775807
@ -49,7 +49,7 @@ def prim_Loop_whilelike(n: int):
f = f * f f = f * f
return f return f
# CHECK-LABEL: func @__torch__.prim_Loop_derefine( # CHECK-LABEL: func.func @__torch__.prim_Loop_derefine(
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.optional<int> { # CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.optional<int> {
# CHECK: %[[TRUE:.*]] = torch.constant.bool true # CHECK: %[[TRUE:.*]] = torch.constant.bool true
# CHECK: %[[NONE:.*]] = torch.constant.none # CHECK: %[[NONE:.*]] = torch.constant.none

View File

@ -15,7 +15,7 @@ import typing
mb = ModuleBuilder() mb = ModuleBuilder()
# CHECK-LABEL: func @__torch__.prim_NumToTensor( # CHECK-LABEL: func.func @__torch__.prim_NumToTensor(
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tensor { # CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tensor {
# CHECK: %[[RET:.*]] = torch.prim.NumToTensor.Scalar %[[ARG]] : !torch.int -> !torch.tensor # CHECK: %[[RET:.*]] = torch.prim.NumToTensor.Scalar %[[ARG]] : !torch.int -> !torch.tensor
# CHECK: return %[[RET]] : !torch.tensor # CHECK: return %[[RET]] : !torch.tensor
@ -25,7 +25,7 @@ mb = ModuleBuilder()
def prim_NumToTensor(i: int): def prim_NumToTensor(i: int):
return _to_tensor(i) return _to_tensor(i)
# CHECK-LABEL: func @__torch__.prim_Print( # CHECK-LABEL: func.func @__torch__.prim_Print(
# CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.none { # CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.none {
# CHECK: %[[STR:.*]] = torch.constant.str "x" # CHECK: %[[STR:.*]] = torch.constant.str "x"
# CHECK: torch.prim.Print(%[[STR]], %[[ARG]]) : !torch.str, !torch.tensor # CHECK: torch.prim.Print(%[[STR]], %[[ARG]]) : !torch.str, !torch.tensor
@ -34,7 +34,7 @@ def prim_NumToTensor(i: int):
def prim_Print(x): def prim_Print(x):
print("x", x) print("x", x)
# CHECK-LABEL: func @__torch__.prim_RaiseException() -> !torch.none { # CHECK-LABEL: func.func @__torch__.prim_RaiseException() -> !torch.none {
# CHECK: %[[ERRORSTR:.*]] = torch.constant.str "Error" # CHECK: %[[ERRORSTR:.*]] = torch.constant.str "Error"
# CHECK: %[[NONE:.*]] = torch.prim.Uninitialized : !torch.none # CHECK: %[[NONE:.*]] = torch.prim.Uninitialized : !torch.none
# CHECK: torch.prim.RaiseException %[[ERRORSTR]] # CHECK: torch.prim.RaiseException %[[ERRORSTR]]
@ -44,7 +44,7 @@ def prim_Print(x):
def prim_RaiseException(): def prim_RaiseException():
raise Exception("Error") raise Exception("Error")
# CHECK-LABEL: func @__torch__.prim_unchecked_cast( # CHECK-LABEL: func.func @__torch__.prim_unchecked_cast(
# CHECK-SAME: %[[ARG:.*]]: !torch.optional<int>) -> !torch.int { # CHECK-SAME: %[[ARG:.*]]: !torch.optional<int>) -> !torch.int {
# CHECK: %[[NONE:.*]] = torch.constant.none # CHECK: %[[NONE:.*]] = torch.constant.none
# CHECK: %[[C3:.*]] = torch.constant.int 3 # CHECK: %[[C3:.*]] = torch.constant.int 3
@ -63,7 +63,7 @@ def prim_unchecked_cast(i: typing.Optional[int]):
return 3 return 3
return i return i
# CHECK-LABEL: func @__torch__.prim_TupleUnpack( # CHECK-LABEL: func.func @__torch__.prim_TupleUnpack(
# CHECK-SAME: %[[ARG:.*]]: !torch.tuple<int, int>) -> !torch.int { # CHECK-SAME: %[[ARG:.*]]: !torch.tuple<int, int>) -> !torch.int {
# CHECK: %[[RET:.*]]:2 = torch.prim.TupleUnpack %[[ARG]] : !torch.tuple<int, int> -> !torch.int, !torch.int # CHECK: %[[RET:.*]]:2 = torch.prim.TupleUnpack %[[ARG]] : !torch.tuple<int, int> -> !torch.int, !torch.int
# CHECK: return %[[RET]]#0 : !torch.int # CHECK: return %[[RET]]#0 : !torch.int
@ -73,7 +73,7 @@ def prim_TupleUnpack(tup: typing.Tuple[int, int]):
val, _ = tup val, _ = tup
return val return val
# CHECK-LABEL: func @__torch__.prim_TupleIndex( # CHECK-LABEL: func.func @__torch__.prim_TupleIndex(
# CHECK-SAME: %[[ARG:.*]]: !torch.tuple<tensor, tensor>) -> !torch.tensor { # CHECK-SAME: %[[ARG:.*]]: !torch.tuple<tensor, tensor>) -> !torch.tensor {
# CHECK: %[[RET:.*]] = torch.prim.TupleIndex %[[ARG]], %[[IDX:.*]] : !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor # CHECK: %[[RET:.*]] = torch.prim.TupleIndex %[[ARG]], %[[IDX:.*]] : !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
# CHECK: return %[[RET]] : !torch.tensor # CHECK: return %[[RET]] : !torch.tensor
@ -82,7 +82,7 @@ def prim_TupleUnpack(tup: typing.Tuple[int, int]):
def prim_TupleIndex(tup: typing.Tuple[torch.Tensor, torch.Tensor]): def prim_TupleIndex(tup: typing.Tuple[torch.Tensor, torch.Tensor]):
return tup[0] return tup[0]
# CHECK-LABEL: func @__torch__.prim_ListUnpack( # CHECK-LABEL: func.func @__torch__.prim_ListUnpack(
# CHECK-SAME: %[[ARG:.*]]: !torch.list<int>) -> !torch.int { # CHECK-SAME: %[[ARG:.*]]: !torch.list<int>) -> !torch.int {
# CHECK: %[[RET:.*]]:3 = torch.prim.ListUnpack %[[ARG]] : !torch.list<int> -> !torch.int, !torch.int # CHECK: %[[RET:.*]]:3 = torch.prim.ListUnpack %[[ARG]] : !torch.list<int> -> !torch.int, !torch.int
# CHECK: return %[[RET]]#1 : !torch.int # CHECK: return %[[RET]]#1 : !torch.int
@ -92,7 +92,7 @@ def prim_ListUnpack(l: typing.List[int]):
_, val, _ = l _, val, _ = l
return val return val
# CHECK-LABEL: func @__torch__.prim_dtype( # CHECK-LABEL: func.func @__torch__.prim_dtype(
# CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.int { # CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.int {
# CHECK: %[[RET:.*]] = torch.prim.dtype %[[ARG]] : !torch.tensor -> !torch.int # CHECK: %[[RET:.*]] = torch.prim.dtype %[[ARG]] : !torch.tensor -> !torch.int
# CHECK: return %[[RET]] : !torch.int # CHECK: return %[[RET]] : !torch.int
@ -101,7 +101,7 @@ def prim_ListUnpack(l: typing.List[int]):
def prim_dtype(x): def prim_dtype(x):
return x.dtype return x.dtype
# CHECK-LABEL: func @__torch__.prim_layout( # CHECK-LABEL: func.func @__torch__.prim_layout(
# CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.int { # CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.int {
# CHECK: %[[RET:.*]] = torch.prim.layout %[[ARG]] : !torch.tensor -> !torch.int # CHECK: %[[RET:.*]] = torch.prim.layout %[[ARG]] : !torch.tensor -> !torch.int
# CHECK: return %[[RET]] : !torch.int # CHECK: return %[[RET]] : !torch.int
@ -110,7 +110,7 @@ def prim_dtype(x):
def prim_layout(x): def prim_layout(x):
return x.layout return x.layout
# CHECK-LABEL: func @__torch__.prim_device( # CHECK-LABEL: func.func @__torch__.prim_device(
# CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.Device { # CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.Device {
# CHECK: %[[RET:.*]] = torch.prim.device %[[ARG]] : !torch.tensor -> !torch.Device # CHECK: %[[RET:.*]] = torch.prim.device %[[ARG]] : !torch.tensor -> !torch.Device
# CHECK: return %[[RET]] : !torch.Device # CHECK: return %[[RET]] : !torch.Device
@ -119,7 +119,7 @@ def prim_layout(x):
def prim_device(x): def prim_device(x):
return x.device return x.device
# CHECK-LABEL: func @__torch__.prim_min( # CHECK-LABEL: func.func @__torch__.prim_min(
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tuple<int, int, int> { # CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tuple<int, int, int> {
# CHECK: %[[SINGLETON:.*]] = torch.prim.ListConstruct %[[ARG]] : (!torch.int) -> !torch.list<int> # CHECK: %[[SINGLETON:.*]] = torch.prim.ListConstruct %[[ARG]] : (!torch.int) -> !torch.list<int>
# CHECK: %[[MIN1:.*]] = torch.prim.min.self_int %[[SINGLETON]] : !torch.list<int> -> !torch.int # CHECK: %[[MIN1:.*]] = torch.prim.min.self_int %[[SINGLETON]] : !torch.list<int> -> !torch.int
@ -133,7 +133,7 @@ def prim_device(x):
def prim_min(x: int): def prim_min(x: int):
return min(x), min(x,x), min(x, x, x) return min(x), min(x,x), min(x, x, x)
# CHECK-LABEL: func @__torch__.prim_max( # CHECK-LABEL: func.func @__torch__.prim_max(
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tuple<int, int, int> { # CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tuple<int, int, int> {
# CHECK: %[[SINGLETON:.*]] = torch.prim.ListConstruct %[[ARG]] : (!torch.int) -> !torch.list<int> # CHECK: %[[SINGLETON:.*]] = torch.prim.ListConstruct %[[ARG]] : (!torch.int) -> !torch.list<int>
# CHECK: %[[MAX1:.*]] = torch.prim.max.self_int %[[SINGLETON]] : !torch.list<int> -> !torch.int # CHECK: %[[MAX1:.*]] = torch.prim.max.self_int %[[SINGLETON]] : !torch.list<int> -> !torch.int
@ -147,7 +147,7 @@ def prim_min(x: int):
def prim_max(x: int): def prim_max(x: int):
return max(x), max(x,x), max(x, x, x) return max(x), max(x,x), max(x, x, x)
# CHECK-LABEL: func @__torch__.prim_Constant_list() -> !torch.list<int> { # CHECK-LABEL: func.func @__torch__.prim_Constant_list() -> !torch.list<int> {
# CHECK: %[[A:.*]] = torch.constant.int 1 # CHECK: %[[A:.*]] = torch.constant.int 1
# CHECK: %[[B:.*]] = torch.constant.int 2 # CHECK: %[[B:.*]] = torch.constant.int 2
# CHECK: %[[C:.*]] = torch.constant.int 3 # CHECK: %[[C:.*]] = torch.constant.int 3

View File

@ -14,7 +14,7 @@ mb = ModuleBuilder()
NT = NamedTuple('NT', [('f1', Optional[torch.Tensor]), NT = NamedTuple('NT', [('f1', Optional[torch.Tensor]),
('f2', Optional[torch.Tensor])]) ('f2', Optional[torch.Tensor])])
# CHECK-LABEL: func @__torch__.tuple( # CHECK-LABEL: func.func @__torch__.tuple(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor, # CHECK-SAME: %[[T0:.*]]: !torch.tensor,
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> # CHECK-SAME: %[[T1:.*]]: !torch.tensor) ->
# CHECK-SAME: !torch.tuple<tensor, tensor> { # CHECK-SAME: !torch.tuple<tensor, tensor> {
@ -27,7 +27,7 @@ def tuple(t0, t1):
return t0, t1 return t0, t1
# CHECK-LABEL: func @__torch__.tuple_optional( # CHECK-LABEL: func.func @__torch__.tuple_optional(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor, # CHECK-SAME: %[[T0:.*]]: !torch.tensor,
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> # CHECK-SAME: %[[T1:.*]]: !torch.tensor) ->
# CHECK-SAME: !torch.tuple<optional<tensor>, optional<tensor>> { # CHECK-SAME: !torch.tuple<optional<tensor>, optional<tensor>> {
@ -44,7 +44,7 @@ def tuple_optional(
return t0, t1 return t0, t1
# CHECK-LABEL: func @__torch__.namedtuple_optional( # CHECK-LABEL: func.func @__torch__.namedtuple_optional(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor, # CHECK-SAME: %[[T0:.*]]: !torch.tensor,
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> # CHECK-SAME: %[[T1:.*]]: !torch.tensor) ->
# CHECK-SAME: !torch.tuple<optional<tensor>, optional<tensor>> { # CHECK-SAME: !torch.tuple<optional<tensor>, optional<tensor>> {
@ -59,7 +59,7 @@ def namedtuple_optional(
return NT(t0, t1) return NT(t0, t1)
# CHECK-LABEL: func @__torch__.tuple_construct_arg_needs_refinement( # CHECK-LABEL: func.func @__torch__.tuple_construct_arg_needs_refinement(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor, # CHECK-SAME: %[[T0:.*]]: !torch.tensor,
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> !torch.tuple<tensor, tensor> { # CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> !torch.tuple<tensor, tensor> {
# CHECK: %[[T0_REFINED:.*]] = torch.tensor_static_info_cast %[[T1]] : !torch.tensor to !torch.tensor<[4],f32> # CHECK: %[[T0_REFINED:.*]] = torch.tensor_static_info_cast %[[T1]] : !torch.tensor to !torch.tensor<[4],f32>

View File

@ -11,7 +11,7 @@ from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
mb = ModuleBuilder() mb = ModuleBuilder()
# CHECK-LABEL: func @__torch__.f( # CHECK-LABEL: func.func @__torch__.f(
# CHECK-SAME: %{{.*}}: !torch.union<float, int>) -> !torch.none { # CHECK-SAME: %{{.*}}: !torch.union<float, int>) -> !torch.none {
@mb.import_function @mb.import_function