// RUN: torch-mlir-opt -split-input-file -tm-tensor-to-loops %s | FileCheck %s

// Lowering tests for tm_tensor.scan and tm_tensor.scatter on memrefs: each op
// is expected to become nested scf.for loops over memref.load/memref.store.

// Inclusive scan along the only dimension: out[0] = in[0]; for i > 0,
// out[i] = out[i-1] + in[i], and that value is also stored to the 0-d
// accumulator.
func.func @scan_1d_inclusive(%0: memref<128xi32>, %1: memref<128xi32>) {
  %c0 = memref.alloc() : memref<i32>
  tm_tensor.scan dimension(0) inclusive(true)
    ins(%0 : memref<128xi32>) outs(%1, %c0 : memref<128xi32>, memref<i32>) {
    ^bb0(%arg0 : i32, %arg1 : i32):
      %sum = arith.addi %arg0, %arg1 : i32
      tm_tensor.yield %sum : i32
  }
  return
}
// CHECK-LABEL: func.func @scan_1d_inclusive
// CHECK-SAME:    %[[BUFI:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[BUFO:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C128:.+]] = arith.constant 128 : index
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[ACC:.+]] = memref.alloc() : memref<i32>
// CHECK:         scf.for %[[ARG1:.+]] = %[[C0]] to %[[C128]] step %[[C1]]
// CHECK:           %[[COND:.+]] = arith.cmpi eq, %[[ARG1]], %[[C0]] : index
// CHECK:           scf.if %[[COND]] {
// CHECK:             %[[V1:.+]] = memref.load %[[BUFI]][%[[ARG1]]]
// CHECK:             memref.store %[[V1]], %[[BUFO]][%[[ARG1]]]
// CHECK:           } else {
// CHECK:             %[[T1:.+]] = arith.subi %[[ARG1]], %[[C1]] : index
// CHECK:             %[[V2:.+]] = memref.load %[[BUFO]][%[[T1]]]
// CHECK:             %[[V3:.+]] = memref.load %[[BUFI]][%[[ARG1]]]
// CHECK:             %[[V4:.+]] = arith.addi %[[V2]], %[[V3]] : i32
// CHECK:             memref.store %[[V4]], %[[BUFO]][%[[ARG1]]]
// CHECK:             memref.store %[[V4]], %[[ACC]][]
// CHECK:           }

// -----

// Exclusive scan: out[0] is read from the 0-d accumulator; for i > 0,
// out[i] = out[i-1] + in[i-1], and that value is also stored to the
// accumulator.
func.func @scan_1d_exclusive(%0: memref<128xi32>, %1: memref<128xi32>) {
  %c0 = memref.alloc() : memref<i32>
  tm_tensor.scan dimension(0) inclusive(false)
    ins(%0 : memref<128xi32>) outs(%1, %c0 : memref<128xi32>, memref<i32>) {
    ^bb0(%arg0 : i32, %arg1 : i32):
      %sum = arith.addi %arg0, %arg1 : i32
      tm_tensor.yield %sum : i32
  }
  return
}
// CHECK-LABEL: func.func @scan_1d_exclusive
// CHECK-SAME:    %[[BUFI:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[BUFO:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C128:.+]] = arith.constant 128 : index
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[ACC:.+]] = memref.alloc() : memref<i32>
// CHECK:         scf.for %[[ARG1:.+]] = %[[C0]] to %[[C128]] step %[[C1]]
// CHECK:           %[[COND:.+]] = arith.cmpi eq, %[[ARG1]], %[[C0]] : index
// CHECK:           scf.if %[[COND]] {
// CHECK:             %[[V0:.+]] = memref.load %[[ACC]][] : memref<i32>
// CHECK:             memref.store %[[V0]], %[[BUFO]][%[[ARG1]]]
// CHECK:           } else {
// CHECK:             %[[T1:.+]] = arith.subi %[[ARG1]], %[[C1]] : index
// CHECK:             %[[V2:.+]] = memref.load %[[BUFO]][%[[T1]]]
// CHECK:             %[[V3:.+]] = memref.load %[[BUFI]][%[[T1]]]
// CHECK:             %[[V4:.+]] = arith.addi %[[V2]], %[[V3]] : i32
// CHECK:             memref.store %[[V4]], %[[BUFO]][%[[ARG1]]]
// CHECK:             memref.store %[[V4]], %[[ACC]][]
// CHECK:           }

// -----

// Inclusive scan along dimension 0 of a 2-D buffer; the accumulator holds one
// running value per column and is indexed by the inner loop variable.
func.func @scan_2d(%0: memref<16x32xi32>, %1: memref<16x32xi32>) {
  %t0 = memref.alloc() : memref<32xi32>
  tm_tensor.scan dimension(0) inclusive(true)
    ins(%0 : memref<16x32xi32>) outs(%1, %t0 : memref<16x32xi32>, memref<32xi32>) {
    ^bb0(%arg0 : i32, %arg1 : i32):
      %sum = arith.addi %arg0, %arg1 : i32
      tm_tensor.yield %sum : i32
  }
  return
}
// CHECK-LABEL: func.func @scan_2d
// CHECK-SAME:    %[[BUFI:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[BUFO:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C16:.+]] = arith.constant 16 : index
// CHECK-DAG:     %[[C32:.+]] = arith.constant 32 : index
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[ACC:.+]] = memref.alloc() : memref<32xi32>
// CHECK:         scf.for %[[ARG1:.+]] = %[[C0]] to %[[C16]] step %[[C1]]
// CHECK:           scf.for %[[ARG2:.+]] = %[[C0]] to %[[C32]] step %[[C1]]
// CHECK:             %[[COND:.+]] = arith.cmpi eq, %[[ARG1]], %[[C0]] : index
// CHECK:             scf.if %[[COND]] {
// CHECK:               %[[V1:.+]] = memref.load %[[BUFI]][%[[ARG1]], %[[ARG2]]]
// CHECK:               memref.store %[[V1]], %[[BUFO]][%[[ARG1]], %[[ARG2]]]
// CHECK:             } else {
// CHECK:               %[[T1:.+]] = arith.subi %[[ARG1]], %[[C1]] : index
// CHECK:               %[[V2:.+]] = memref.load %[[BUFO]][%[[T1]], %[[ARG2]]]
// CHECK:               %[[V3:.+]] = memref.load %[[BUFI]][%[[ARG1]], %[[ARG2]]]
// CHECK:               %[[V4:.+]] = arith.addi %[[V2]], %[[V3]] : i32
// CHECK:               memref.store %[[V4]], %[[BUFO]][%[[ARG1]], %[[ARG2]]]
// CHECK:               memref.store %[[V4]], %[[ACC]][%[[ARG2]]]
// CHECK:             }

// -----

// Scalar overwrite: original[indices[i, 0]] = updates[i].
func.func @scatter_update_scalar_1D(
    %original: memref<8xi32>, %indices: memref<3x1xi32>,
    %updates: memref<3xi32>) {
  tm_tensor.scatter unique_indices(true)
    ins(%updates, %indices : memref<3xi32>, memref<3x1xi32>)
    outs(%original : memref<8xi32>) {
  ^bb0(%arg0: i32, %arg1: i32):  // no predecessors
    tm_tensor.yield %arg0 : i32
  }
  return
}
// CHECK-LABEL: func.func @scatter_update_scalar_1D
// CHECK-SAME:    %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[C3:.+]] = arith.constant 3 : index
// CHECK:         scf.for %[[I:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK:           %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<3xi32>
// CHECK:           %[[T2:.+]] =  memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<3x1xi32>
// CHECK:           %[[IDX:.+]] = arith.index_cast %[[T2]] : i32 to index
// CHECK:           memref.store %[[T1]], %[[ORIGINAL]][%[[IDX]]]

// -----

// Scalar accumulate into a 2-D buffer: two index columns select the element,
// which becomes original[idx0, idx1] + updates[i].
func.func @scatter_add_scalar_2D(
    %original: memref<4x3xi32>, %indices: memref<3x2xi32>,
    %updates: memref<3xi32>) {
  tm_tensor.scatter unique_indices(true)
    ins(%updates, %indices : memref<3xi32>, memref<3x2xi32>)
    outs(%original : memref<4x3xi32>) {
  ^bb0(%arg0: i32, %arg1: i32):  // no predecessors
    %0 = arith.addi %arg1, %arg0 : i32
    tm_tensor.yield %0 : i32
  }
  return
}
// CHECK-LABEL: func.func @scatter_add_scalar_2D
// CHECK-SAME:    %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[C3:.+]] = arith.constant 3 : index
// CHECK:         scf.for %[[I:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK:           %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<3xi32>
// CHECK:           %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<3x2xi32>
// CHECK:           %[[IDX1:.+]] = arith.index_cast %[[T2]] : i32 to index
// CHECK:           %[[T3:.+]] = memref.load %[[INDICES]][%[[I]], %[[C1]]] : memref<3x2xi32>
// CHECK:           %[[IDX2:.+]] = arith.index_cast %[[T3]] : i32 to index
// CHECK:           %[[ORI:.+]] = memref.load %[[ORIGINAL]][%[[IDX1]], %[[IDX2]]] : memref<4x3xi32>
// CHECK:           %[[ADD:.+]] = arith.addi %[[ORI]], %[[T1]] : i32
// CHECK:           memref.store %[[ADD]], %[[ORIGINAL]][%[[IDX1]], %[[IDX2]]]

// -----

// Row-slice overwrite: each update row i is copied into row indices[i, 0].
func.func @scatter_update_slice_2D(
    %original: memref<4x3xi32>, %indices: memref<2x1xi32>,
    %updates: memref<2x3xi32>) {
  tm_tensor.scatter unique_indices(true)
    ins(%updates, %indices : memref<2x3xi32>, memref<2x1xi32>)
    outs(%original : memref<4x3xi32>) {
  ^bb0(%arg0: i32, %arg1: i32):  // no predecessors
    tm_tensor.yield %arg0 : i32
  }
  return
}
// CHECK-LABEL: func.func @scatter_update_slice_2D
// CHECK-SAME:    %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG:     %[[C3:.+]] = arith.constant 3 : index
// CHECK:         scf.for %[[I:.+]] = %[[C0]] to %[[C2]] step %[[C1]] {
// CHECK:           scf.for %[[J:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK:             %[[UPDATE:.+]] = memref.load %[[UPDATES]][%[[I]], %[[J]]]
// CHECK:             %[[INDEX:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]]
// CHECK:             %[[LOC:.+]] = arith.index_cast %[[INDEX]] : i32 to index
// CHECK:             memref.store %[[UPDATE]], %[[ORIGINAL]][%[[LOC]], %[[J]]]
// CHECK:           }
// CHECK:         }

// -----

// Scalar accumulate: original[indices[i, 0]] += updates[i].
func.func @scatter_add_scalar_1D(
    %original: memref<8xi32>, %indices: memref<3x1xi32>,
    %updates: memref<3xi32>) {
  tm_tensor.scatter unique_indices(true)
    ins(%updates, %indices : memref<3xi32>, memref<3x1xi32>)
    outs(%original : memref<8xi32>) {
  ^bb0(%arg0: i32, %arg1: i32):  // no predecessors
    %0 = arith.addi %arg1, %arg0 : i32
    tm_tensor.yield %0 : i32
  }
  return
}
// CHECK-LABEL: func.func @scatter_add_scalar_1D
// CHECK-SAME:    %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[C3:.+]] = arith.constant 3 : index
// CHECK:         scf.for %[[I:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK:           %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<3xi32>
// CHECK:           %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<3x1xi32>
// CHECK:           %[[IDX:.+]] = arith.index_cast %[[T2]] : i32 to index
// CHECK:           %[[ORI:.+]] = memref.load %[[ORIGINAL]][%[[IDX]]] : memref<8xi32>
// CHECK:           %[[ADD:.+]] = arith.addi %[[ORI]], %[[T1]] : i32
// CHECK:           memref.store %[[ADD]], %[[ORIGINAL]][%[[IDX]]]

// -----

// Row-slice accumulate: row indices[i, 0] of the original gets update row i
// added element-wise.
func.func @scatter_add_slice_2D(
    %original: memref<4x3xi32>, %indices: memref<2x1xi32>,
    %updates: memref<2x3xi32>) {
  tm_tensor.scatter unique_indices(true)
    ins(%updates, %indices : memref<2x3xi32>, memref<2x1xi32>)
    outs(%original : memref<4x3xi32>) {
  ^bb0(%arg0: i32, %arg1: i32):  // no predecessors
    %0 = arith.addi %arg1, %arg0 : i32
    tm_tensor.yield %0 : i32
  }
  return
}
// CHECK-LABEL: func.func @scatter_add_slice_2D
// CHECK-SAME:    %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG:     %[[C3:.+]] = arith.constant 3 : index
// CHECK:         scf.for %[[I:.+]] = %[[C0]] to %[[C2]] step %[[C1]] {
// CHECK:           scf.for %[[J:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK:             %[[UPDATEVAL:.+]] = memref.load %[[UPDATES]][%[[I]], %[[J]]]
// CHECK:             %[[INDEXVAL:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]]
// CHECK:             %[[INDEX:.+]] = arith.index_cast %[[INDEXVAL]] : i32 to index
// CHECK:             %[[ORIGINALVAL:.+]] = memref.load %[[ORIGINAL]][%[[INDEX]], %[[J]]]
// CHECK:             %[[STOREVAL:.+]] = arith.addi %[[ORIGINALVAL]], %[[UPDATEVAL]]
// CHECK:             memref.store %[[STOREVAL]], %[[ORIGINAL]][%[[INDEX]], %[[J]]]

// -----

// Dynamic-shape variant of the scalar overwrite: the loop bound comes from
// memref.dim of the updates buffer.
func.func @scatter_update_scalar_dynamic_1D(
    %original: memref<?xi32>, %indices: memref<?x1xi32>,
    %updates: memref<?xi32>) {
  tm_tensor.scatter unique_indices(true)
    ins(%updates, %indices : memref<?xi32>, memref<?x1xi32>)
    outs(%original : memref<?xi32>) {
  ^bb0(%arg0: i32, %arg1: i32):  // no predecessors
    tm_tensor.yield %arg0 : i32
  }
  return
}
// CHECK-LABEL: func.func @scatter_update_scalar_dynamic_1D
// CHECK-SAME:    %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[UB:.+]] = memref.dim %[[UPDATES]], %[[C0]] : memref<?xi32>
// CHECK:         scf.for %[[I:.+]] = %[[C0]] to %[[UB]] step %[[C1]] {
// CHECK:           %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<?xi32>
// CHECK:           %[[T2:.+]] =  memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<?x1xi32>
// CHECK:           %[[IDX:.+]] = arith.index_cast %[[T2]] : i32 to index
// CHECK:           memref.store %[[T1]], %[[ORIGINAL]][%[[IDX]]]

// -----

// Dynamic-shape variant of the 2-D scalar accumulate.
func.func @scatter_add_scalar_dynamic_2D(
    %original: memref<?x?xi32>, %indices: memref<?x2xi32>,
    %updates: memref<?xi32>) {
  tm_tensor.scatter unique_indices(true)
    ins(%updates, %indices : memref<?xi32>, memref<?x2xi32>)
    outs(%original : memref<?x?xi32>) {
  ^bb0(%arg0: i32, %arg1: i32):  // no predecessors
    %0 = arith.addi %arg1, %arg0 : i32
    tm_tensor.yield %0 : i32
  }
  return
}
// CHECK-LABEL: func.func @scatter_add_scalar_dynamic_2D
// CHECK-SAME:    %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[UB:.+]] = memref.dim %[[UPDATES]], %[[C0]] : memref<?xi32>
// CHECK:         scf.for %[[I:.+]] = %[[C0]] to %[[UB]] step %[[C1]] {
// CHECK:           %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<?xi32>
// CHECK:           %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<?x2xi32>
// CHECK:           %[[IDX1:.+]] = arith.index_cast %[[T2]] : i32 to index
// CHECK:           %[[T3:.+]] = memref.load %[[INDICES]][%[[I]], %[[C1]]] : memref<?x2xi32>
// CHECK:           %[[IDX2:.+]] = arith.index_cast %[[T3]] : i32 to index
// CHECK:           %[[ORI:.+]] = memref.load %[[ORIGINAL]][%[[IDX1]], %[[IDX2]]] : memref<?x?xi32>
// CHECK:           %[[ADD:.+]] = arith.addi %[[ORI]], %[[T1]] : i32
// CHECK:           memref.store %[[ADD]], %[[ORIGINAL]][%[[IDX1]], %[[IDX2]]]

// -----

// Dynamic-shape variant of the row-slice overwrite; both loop bounds come from
// memref.dim of the updates buffer.
func.func @scatter_update_slice_dynamic_2D(
    %original: memref<?x?xi32>, %indices: memref<?x1xi32>,
    %updates: memref<?x?xi32>) {
  tm_tensor.scatter unique_indices(true)
    ins(%updates, %indices : memref<?x?xi32>, memref<?x1xi32>)
    outs(%original : memref<?x?xi32>) {
  ^bb0(%arg0: i32, %arg1: i32):  // no predecessors
    tm_tensor.yield %arg0 : i32
  }
  return
}
// CHECK-LABEL: func.func @scatter_update_slice_dynamic_2D
// CHECK-SAME:    %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[UB1:.+]] = memref.dim %[[UPDATES]], %[[C0]] : memref<?x?xi32>
// CHECK-DAG:     %[[UB2:.+]] = memref.dim %[[UPDATES]], %[[C1]] : memref<?x?xi32>
// CHECK:         scf.for %[[I:.+]] = %[[C0]] to %[[UB1]] step %[[C1]] {
// CHECK:           scf.for %[[J:.+]] = %[[C0]] to %[[UB2]] step %[[C1]] {
// CHECK:             %[[UPDATEVAL:.+]] = memref.load %[[UPDATES]][%[[I]], %[[J]]]
// CHECK:             %[[INDEXVAL:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]]
// CHECK:             %[[INDEX:.+]] = arith.index_cast %[[INDEXVAL]] : i32 to index
// CHECK:             memref.store %[[UPDATEVAL]], %[[ORIGINAL]][%[[INDEX]], %[[J]]]

// -----

// Partial slices: each index row holds three base coordinates. The inner loop
// variables are added as offsets to the second and third coordinates. The region
// yields %arg4 (the existing value), so the store writes back what was loaded.
func.func @scatter_partial_slices(%arg0: memref<2x64x12xf32>, %arg1: memref<2x3xi32>, %arg2: memref<2x1x12xf32>) {
  tm_tensor.scatter
    unique_indices(true)
    ins(%arg2, %arg1 : memref<2x1x12xf32>, memref<2x3xi32>)
    outs(%arg0 : memref<2x64x12xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):
    tm_tensor.yield %arg4 : f32
  }
  return
}

// CHECK-LABEL: func.func @scatter_partial_slices
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C0:.+]] = arith.constant
// CHECK-DAG:     %[[C1:.+]] = arith.constant
// CHECK-DAG:     %[[C2:.+]] = arith.constant
// CHECK-DAG:     %[[C12:.+]] = arith.constant
// CHECK:         scf.for %[[ARG3:.+]] = %[[C0]] to %[[C2]] step %[[C1]] {
// CHECK-NEXT:      scf.for %[[ARG4:.+]] = %[[C0]] to %[[C1]] step %[[C1]] {
// CHECK-NEXT:        scf.for %[[ARG5:.+]] = %[[C0]] to %[[C12]] step %[[C1]] {
// CHECK-NEXT:          %[[LOAD0:.+]] = memref.load %[[ARG1]][%[[ARG3]], %[[C0]]] : memref<2x3xi32>
// CHECK-NEXT:          %[[CAST0:.+]] = arith.index_cast %[[LOAD0]] : i32 to index
// CHECK-NEXT:          %[[LOAD1:.+]] = memref.load %[[ARG1]][%[[ARG3]], %[[C1]]] : memref<2x3xi32>
// CHECK-NEXT:          %[[CAST1:.+]] = arith.index_cast %[[LOAD1]] : i32 to index
// CHECK-NEXT:          %[[ADD1:.+]] = arith.addi %[[CAST1]], %[[ARG4]] : index
// CHECK-NEXT:          %[[LOAD2:.+]] = memref.load %[[ARG1]][%[[ARG3]], %[[C2]]] : memref<2x3xi32>
// CHECK-NEXT:          %[[CAST2:.+]] = arith.index_cast %[[LOAD2]] : i32 to index
// CHECK-NEXT:          %[[ADD2:.+]] = arith.addi %[[CAST2]], %[[ARG5]] : index
// CHECK-NEXT:          %[[LOAD3:.+]] = memref.load %[[ARG0]][%[[CAST0]], %[[ADD1]], %[[ADD2]]] : memref<2x64x12xf32>
// CHECK-NEXT:          memref.store %[[LOAD3]], %[[ARG0]][%[[CAST0]], %[[ADD1]], %[[ADD2]]] : memref<2x64x12xf32>