// Source: torch-mlir/test/Dialect/TMTensor/convert_to_loops.mlir (MLIR, 368 lines, 17 KiB)

// RUN: torch-mlir-opt -split-input-file -tm-tensor-to-loops %s | FileCheck %s
// Inclusive scan along dimension 0 of a 1-D buffer with an additive combiner.
// Expected lowering: one scf.for over the 128 elements; iteration 0 copies the
// input element to the output, every later iteration adds the previous output
// element to the current input element, stores the sum to the output, and
// also writes it to the rank-0 accumulator buffer.
func.func @scan_1d_inclusive(%input: memref<128xi32>, %output: memref<128xi32>) {
  %acc = memref.alloc() : memref<i32>
  tm_tensor.scan dimension(0) inclusive(true)
      ins(%input : memref<128xi32>)
      outs(%output, %acc : memref<128xi32>, memref<i32>) {
    ^bb0(%lhs : i32, %rhs : i32):
      %sum = arith.addi %lhs, %rhs : i32
      tm_tensor.yield %sum : i32
  }
  return
}
// CHECK-LABEL: func.func @scan_1d_inclusive
// CHECK-SAME:    %[[IN:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[OUT:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C128:.+]] = arith.constant 128 : index
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[ACC:.+]] = memref.alloc() : memref<i32>
// CHECK:         scf.for %[[IV:.+]] = %[[C0]] to %[[C128]] step %[[C1]]
// CHECK:           %[[COND:.+]] = arith.cmpi eq, %[[IV]], %[[C0]] : index
// CHECK:           scf.if %[[COND]] {
// CHECK:             %[[FIRST:.+]] = memref.load %[[IN]][%[[IV]]]
// CHECK:             memref.store %[[FIRST]], %[[OUT]][%[[IV]]]
// CHECK:           } else {
// CHECK:             %[[PREV:.+]] = arith.subi %[[IV]], %[[C1]] : index
// CHECK:             %[[CARRY:.+]] = memref.load %[[OUT]][%[[PREV]]]
// CHECK:             %[[ELEM:.+]] = memref.load %[[IN]][%[[IV]]]
// CHECK:             %[[SUM:.+]] = arith.addi %[[CARRY]], %[[ELEM]] : i32
// CHECK:             memref.store %[[SUM]], %[[OUT]][%[[IV]]]
// CHECK:             memref.store %[[SUM]], %[[ACC]][]
// CHECK:           }
// -----
// Exclusive scan along dimension 0 of a 1-D buffer with an additive combiner.
// Expected lowering: iteration 0 seeds the output from the rank-0 accumulator
// buffer; every later iteration combines the previous output element with the
// previous input element (both read at index - 1), stores the sum at the
// current index, and also writes it to the accumulator.
func.func @scan_1d_exclusive(%input: memref<128xi32>, %output: memref<128xi32>) {
  %acc = memref.alloc() : memref<i32>
  tm_tensor.scan dimension(0) inclusive(false)
      ins(%input : memref<128xi32>)
      outs(%output, %acc : memref<128xi32>, memref<i32>) {
    ^bb0(%lhs : i32, %rhs : i32):
      %sum = arith.addi %lhs, %rhs : i32
      tm_tensor.yield %sum : i32
  }
  return
}
// CHECK-LABEL: func.func @scan_1d_exclusive
// CHECK-SAME:    %[[IN:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[OUT:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C128:.+]] = arith.constant 128 : index
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[ACC:.+]] = memref.alloc() : memref<i32>
// CHECK:         scf.for %[[IV:.+]] = %[[C0]] to %[[C128]] step %[[C1]]
// CHECK:           %[[COND:.+]] = arith.cmpi eq, %[[IV]], %[[C0]] : index
// CHECK:           scf.if %[[COND]] {
// CHECK:             %[[SEED:.+]] = memref.load %[[ACC]][] : memref<i32>
// CHECK:             memref.store %[[SEED]], %[[OUT]][%[[IV]]]
// CHECK:           } else {
// CHECK:             %[[PREV:.+]] = arith.subi %[[IV]], %[[C1]] : index
// CHECK:             %[[CARRY:.+]] = memref.load %[[OUT]][%[[PREV]]]
// CHECK:             %[[ELEM:.+]] = memref.load %[[IN]][%[[PREV]]]
// CHECK:             %[[SUM:.+]] = arith.addi %[[CARRY]], %[[ELEM]] : i32
// CHECK:             memref.store %[[SUM]], %[[OUT]][%[[IV]]]
// CHECK:             memref.store %[[SUM]], %[[ACC]][]
// CHECK:           }
// -----
// Inclusive scan along dimension 0 of a 16x32 buffer with an additive
// combiner. Expected lowering: a 16 x 32 scf.for nest; the "first element"
// test is on the outer (scan-dimension) induction variable, and the 1-D
// accumulator buffer is indexed by the inner induction variable only.
func.func @scan_2d(%input: memref<16x32xi32>, %output: memref<16x32xi32>) {
  %acc = memref.alloc() : memref<32xi32>
  tm_tensor.scan dimension(0) inclusive(true)
      ins(%input : memref<16x32xi32>)
      outs(%output, %acc : memref<16x32xi32>, memref<32xi32>) {
    ^bb0(%lhs : i32, %rhs : i32):
      %sum = arith.addi %lhs, %rhs : i32
      tm_tensor.yield %sum : i32
  }
  return
}
// CHECK-LABEL: func.func @scan_2d
// CHECK-SAME:    %[[IN:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[OUT:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C16:.+]] = arith.constant 16 : index
// CHECK-DAG:     %[[C32:.+]] = arith.constant 32 : index
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[ACC:.+]] = memref.alloc() : memref<32xi32>
// CHECK:         scf.for %[[ROW:.+]] = %[[C0]] to %[[C16]] step %[[C1]]
// CHECK:           scf.for %[[COL:.+]] = %[[C0]] to %[[C32]] step %[[C1]]
// CHECK:             %[[COND:.+]] = arith.cmpi eq, %[[ROW]], %[[C0]] : index
// CHECK:             scf.if %[[COND]] {
// CHECK:               %[[FIRST:.+]] = memref.load %[[IN]][%[[ROW]], %[[COL]]]
// CHECK:               memref.store %[[FIRST]], %[[OUT]][%[[ROW]], %[[COL]]]
// CHECK:             } else {
// CHECK:               %[[PREV:.+]] = arith.subi %[[ROW]], %[[C1]] : index
// CHECK:               %[[CARRY:.+]] = memref.load %[[OUT]][%[[PREV]], %[[COL]]]
// CHECK:               %[[ELEM:.+]] = memref.load %[[IN]][%[[ROW]], %[[COL]]]
// CHECK:               %[[SUM:.+]] = arith.addi %[[CARRY]], %[[ELEM]] : i32
// CHECK:               memref.store %[[SUM]], %[[OUT]][%[[ROW]], %[[COL]]]
// CHECK:               memref.store %[[SUM]], %[[ACC]][%[[COL]]]
// CHECK:             }
// -----
// Scalar scatter with a pass-through region (yields the update value) into a
// 1-D destination. Expected lowering: one scf.for over the 3 updates; each
// iteration loads the update, loads column 0 of the matching index row, casts
// it to index, and stores the update at that position.
func.func @scatter_update_scalar_1D(
    %original: memref<8xi32>, %indices: memref<3x1xi32>,
    %updates: memref<3xi32>) {
  tm_tensor.scatter unique_indices(true)
      ins(%updates, %indices : memref<3xi32>, memref<3x1xi32>)
      outs(%original : memref<8xi32>) {
    ^bb0(%update: i32, %current: i32):
      tm_tensor.yield %update : i32
  }
  return
}
// CHECK-LABEL: func.func @scatter_update_scalar_1D
// CHECK-SAME:    %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[C3:.+]] = arith.constant 3 : index
// CHECK:         scf.for %[[ROW:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK:           %[[UPD:.+]] = memref.load %[[UPDATES]][%[[ROW]]] : memref<3xi32>
// CHECK:           %[[RAW:.+]] = memref.load %[[INDICES]][%[[ROW]], %[[C0]]] : memref<3x1xi32>
// CHECK:           %[[POS:.+]] = arith.index_cast %[[RAW]] : i32 to index
// CHECK:           memref.store %[[UPD]], %[[ORIGINAL]][%[[POS]]]
// -----
// Scalar scatter with an additive region into a 2-D destination. Each index
// row supplies both destination coordinates. Expected lowering: load the
// update, load and cast both coordinates, load the current destination
// element, add the update to it, and store the sum back.
func.func @scatter_add_scalar_2D(
    %original: memref<4x3xi32>, %indices: memref<3x2xi32>,
    %updates: memref<3xi32>) {
  tm_tensor.scatter unique_indices(true)
      ins(%updates, %indices : memref<3xi32>, memref<3x2xi32>)
      outs(%original : memref<4x3xi32>) {
    ^bb0(%update: i32, %current: i32):
      %sum = arith.addi %current, %update : i32
      tm_tensor.yield %sum : i32
  }
  return
}
// CHECK-LABEL: func.func @scatter_add_scalar_2D
// CHECK-SAME:    %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[C3:.+]] = arith.constant 3 : index
// CHECK:         scf.for %[[ROW:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK:           %[[UPD:.+]] = memref.load %[[UPDATES]][%[[ROW]]] : memref<3xi32>
// CHECK:           %[[RAW0:.+]] = memref.load %[[INDICES]][%[[ROW]], %[[C0]]] : memref<3x2xi32>
// CHECK:           %[[POS0:.+]] = arith.index_cast %[[RAW0]] : i32 to index
// CHECK:           %[[RAW1:.+]] = memref.load %[[INDICES]][%[[ROW]], %[[C1]]] : memref<3x2xi32>
// CHECK:           %[[POS1:.+]] = arith.index_cast %[[RAW1]] : i32 to index
// CHECK:           %[[CUR:.+]] = memref.load %[[ORIGINAL]][%[[POS0]], %[[POS1]]] : memref<4x3xi32>
// CHECK:           %[[SUM:.+]] = arith.addi %[[CUR]], %[[UPD]] : i32
// CHECK:           memref.store %[[SUM]], %[[ORIGINAL]][%[[POS0]], %[[POS1]]]
// -----
// Slice scatter with a pass-through region (yields the update value): each
// index row selects a destination row and the update row is written across
// it. Expected lowering: a 2 x 3 scf.for nest (update rows x slice elements)
// that loads the update element, loads and casts the row index, and stores
// the update at [row index, inner induction variable].
func.func @scatter_update_slice_2D(
    %original: memref<4x3xi32>, %indices: memref<2x1xi32>,
    %updates: memref<2x3xi32>) {
  tm_tensor.scatter unique_indices(true)
      ins(%updates, %indices : memref<2x3xi32>, memref<2x1xi32>)
      outs(%original : memref<4x3xi32>) {
    ^bb0(%update: i32, %current: i32):
      tm_tensor.yield %update : i32
  }
  return
}
// The function anchor uses the -LABEL form, like the other tests in this
// file, so the directives below are confined to this function's output.
// CHECK-LABEL: func.func @scatter_update_slice_2D
// CHECK-SAME:    %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG:     %[[C3:.+]] = arith.constant 3 : index
// CHECK:         scf.for %[[I:.+]] = %[[C0]] to %[[C2]] step %[[C1]] {
// CHECK:           scf.for %[[J:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK:             %[[UPDATE:.+]] = memref.load %[[UPDATES]][%[[I]], %[[J]]]
// CHECK:             %[[INDEX:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]]
// CHECK:             %[[LOC:.+]] = arith.index_cast %[[INDEX]] : i32 to index
// CHECK:             memref.store %[[UPDATE]], %[[ORIGINAL]][%[[LOC]], %[[J]]]
// CHECK:           }
// CHECK:         }
// -----
// Scalar scatter with an additive region into a 1-D destination. Expected
// lowering: one scf.for over the 3 updates; each iteration loads the update,
// loads and casts the destination index, loads the current destination
// element, adds the update to it, and stores the sum back.
func.func @scatter_add_scalar_1D(
    %original: memref<8xi32>, %indices: memref<3x1xi32>,
    %updates: memref<3xi32>) {
  tm_tensor.scatter unique_indices(true)
      ins(%updates, %indices : memref<3xi32>, memref<3x1xi32>)
      outs(%original : memref<8xi32>) {
    ^bb0(%update: i32, %current: i32):
      %sum = arith.addi %current, %update : i32
      tm_tensor.yield %sum : i32
  }
  return
}
// CHECK-LABEL: func.func @scatter_add_scalar_1D
// CHECK-SAME:    %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[C3:.+]] = arith.constant 3 : index
// CHECK:         scf.for %[[ROW:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK:           %[[UPD:.+]] = memref.load %[[UPDATES]][%[[ROW]]] : memref<3xi32>
// CHECK:           %[[RAW:.+]] = memref.load %[[INDICES]][%[[ROW]], %[[C0]]] : memref<3x1xi32>
// CHECK:           %[[POS:.+]] = arith.index_cast %[[RAW]] : i32 to index
// CHECK:           %[[CUR:.+]] = memref.load %[[ORIGINAL]][%[[POS]]] : memref<8xi32>
// CHECK:           %[[SUM:.+]] = arith.addi %[[CUR]], %[[UPD]] : i32
// CHECK:           memref.store %[[SUM]], %[[ORIGINAL]][%[[POS]]]
// -----
// Slice scatter with an additive region: each index row selects a destination
// row and the update row is accumulated into it. Expected lowering: a 2 x 3
// scf.for nest (update rows x slice elements) that loads the update element,
// loads and casts the row index, loads the current destination element, adds
// the update to it, and stores the sum back.
func.func @scatter_add_slice_2D(
    %original: memref<4x3xi32>, %indices: memref<2x1xi32>,
    %updates: memref<2x3xi32>) {
  tm_tensor.scatter unique_indices(true)
      ins(%updates, %indices : memref<2x3xi32>, memref<2x1xi32>)
      outs(%original : memref<4x3xi32>) {
    ^bb0(%update: i32, %current: i32):
      %sum = arith.addi %current, %update : i32
      tm_tensor.yield %sum : i32
  }
  return
}
// The function anchor uses the -LABEL form, like the other tests in this
// file. C3 (the inner loop bound) is bound here explicitly; without it the
// inner-loop directive would silently reuse a capture left over from an
// earlier test case in the file.
// CHECK-LABEL: func.func @scatter_add_slice_2D
// CHECK-SAME:    %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG:     %[[C3:.+]] = arith.constant 3 : index
// CHECK:         scf.for %[[I:.+]] = %[[C0]] to %[[C2]] step %[[C1]] {
// CHECK:           scf.for %[[J:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK:             %[[UPDATEVAL:.+]] = memref.load %[[UPDATES]][%[[I]], %[[J]]]
// CHECK:             %[[INDEXVAL:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]]
// CHECK:             %[[INDEX:.+]] = arith.index_cast %[[INDEXVAL]] : i32 to index
// CHECK:             %[[ORIGINALVAL:.+]] = memref.load %[[ORIGINAL]][%[[INDEX]], %[[J]]]
// CHECK:             %[[STOREVAL:.+]] = arith.addi %[[ORIGINALVAL]], %[[UPDATEVAL]]
// CHECK:             memref.store %[[STOREVAL]], %[[ORIGINAL]][%[[INDEX]], %[[J]]]
// -----
// Dynamic-shape variant of the 1-D scalar scatter with a pass-through region.
// Expected lowering: the loop upper bound is a memref.dim of the updates
// buffer along dimension 0 instead of a constant; the loop body loads the
// update, loads and casts the destination index, and stores the update.
func.func @scatter_update_scalar_dynamic_1D(
    %original: memref<?xi32>, %indices: memref<?x1xi32>,
    %updates: memref<?xi32>) {
  tm_tensor.scatter unique_indices(true)
      ins(%updates, %indices : memref<?xi32>, memref<?x1xi32>)
      outs(%original : memref<?xi32>) {
    ^bb0(%update: i32, %current: i32):
      tm_tensor.yield %update : i32
  }
  return
}
// CHECK-LABEL: func.func @scatter_update_scalar_dynamic_1D
// CHECK-SAME:    %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[COUNT:.+]] = memref.dim %[[UPDATES]], %[[C0]] : memref<?xi32>
// CHECK:         scf.for %[[ROW:.+]] = %[[C0]] to %[[COUNT]] step %[[C1]] {
// CHECK:           %[[UPD:.+]] = memref.load %[[UPDATES]][%[[ROW]]] : memref<?xi32>
// CHECK:           %[[RAW:.+]] = memref.load %[[INDICES]][%[[ROW]], %[[C0]]] : memref<?x1xi32>
// CHECK:           %[[POS:.+]] = arith.index_cast %[[RAW]] : i32 to index
// CHECK:           memref.store %[[UPD]], %[[ORIGINAL]][%[[POS]]]
// -----
// Dynamic-shape variant of the 2-D scalar scatter with an additive region.
// Expected lowering: the loop upper bound is a memref.dim of the updates
// buffer along dimension 0; the loop body loads the update, loads and casts
// both destination coordinates, loads the current destination element, adds
// the update to it, and stores the sum back.
func.func @scatter_add_scalar_dynamic_2D(
    %original: memref<?x?xi32>, %indices: memref<?x2xi32>,
    %updates: memref<?xi32>) {
  tm_tensor.scatter unique_indices(true)
      ins(%updates, %indices : memref<?xi32>, memref<?x2xi32>)
      outs(%original : memref<?x?xi32>) {
    ^bb0(%update: i32, %current: i32):
      %sum = arith.addi %current, %update : i32
      tm_tensor.yield %sum : i32
  }
  return
}
// CHECK-LABEL: func.func @scatter_add_scalar_dynamic_2D
// CHECK-SAME:    %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[COUNT:.+]] = memref.dim %[[UPDATES]], %[[C0]] : memref<?xi32>
// CHECK:         scf.for %[[ROW:.+]] = %[[C0]] to %[[COUNT]] step %[[C1]] {
// CHECK:           %[[UPD:.+]] = memref.load %[[UPDATES]][%[[ROW]]] : memref<?xi32>
// CHECK:           %[[RAW0:.+]] = memref.load %[[INDICES]][%[[ROW]], %[[C0]]] : memref<?x2xi32>
// CHECK:           %[[POS0:.+]] = arith.index_cast %[[RAW0]] : i32 to index
// CHECK:           %[[RAW1:.+]] = memref.load %[[INDICES]][%[[ROW]], %[[C1]]] : memref<?x2xi32>
// CHECK:           %[[POS1:.+]] = arith.index_cast %[[RAW1]] : i32 to index
// CHECK:           %[[CUR:.+]] = memref.load %[[ORIGINAL]][%[[POS0]], %[[POS1]]] : memref<?x?xi32>
// CHECK:           %[[SUM:.+]] = arith.addi %[[CUR]], %[[UPD]] : i32
// CHECK:           memref.store %[[SUM]], %[[ORIGINAL]][%[[POS0]], %[[POS1]]]
// -----
// Dynamic-shape variant of the slice scatter with a pass-through region.
// Expected lowering: both loop upper bounds are memref.dim queries on the
// updates buffer (dimensions 0 and 1); the loop body loads the update
// element, loads and casts the row index, and stores the update at
// [row index, inner induction variable].
func.func @scatter_update_slice_dynamic_2D(
    %original: memref<?x?xi32>, %indices: memref<?x1xi32>,
    %updates: memref<?x?xi32>) {
  tm_tensor.scatter unique_indices(true)
      ins(%updates, %indices : memref<?x?xi32>, memref<?x1xi32>)
      outs(%original : memref<?x?xi32>) {
    ^bb0(%update: i32, %current: i32):
      tm_tensor.yield %update : i32
  }
  return
}
// The function anchor uses the -LABEL form, like the other tests in this
// file, so the directives below are confined to this function's output.
// CHECK-LABEL: func.func @scatter_update_slice_dynamic_2D
// CHECK-SAME:    %[[ORIGINAL:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[INDICES:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:     %[[UB1:.+]] = memref.dim %[[UPDATES]], %[[C0]] : memref<?x?xi32>
// CHECK-DAG:     %[[UB2:.+]] = memref.dim %[[UPDATES]], %[[C1]] : memref<?x?xi32>
// CHECK:         scf.for %[[I:.+]] = %[[C0]] to %[[UB1]] step %[[C1]] {
// CHECK:           scf.for %[[J:.+]] = %[[C0]] to %[[UB2]] step %[[C1]] {
// CHECK:             %[[UPDATEVAL:.+]] = memref.load %[[UPDATES]][%[[I]], %[[J]]]
// CHECK:             %[[INDEXVAL:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]]
// CHECK:             %[[INDEX:.+]] = arith.index_cast %[[INDEXVAL]] : i32 to index
// CHECK:             memref.store %[[UPDATEVAL]], %[[ORIGINAL]][%[[INDEX]], %[[J]]]
// -----
// Scatter of 1x12 update slices into a 2x64x12 destination using 3-column
// index rows. The region yields its second argument, so the expected loop
// body reloads the destination element and stores it back unchanged.
// Destination coordinates: column 0 of the index row is used directly, while
// columns 1 and 2 are offset by the two inner induction variables.
func.func @scatter_partial_slices(
    %original: memref<2x64x12xf32>, %indices: memref<2x3xi32>,
    %updates: memref<2x1x12xf32>) {
  tm_tensor.scatter
      unique_indices(true)
      ins(%updates, %indices : memref<2x1x12xf32>, memref<2x3xi32>)
      outs(%original : memref<2x64x12xf32>) {
    ^bb0(%update: f32, %current: f32):
      tm_tensor.yield %current : f32
  }
  return
}
// CHECK-LABEL: func.func @scatter_partial_slices
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]
// CHECK-DAG:     %[[C0:.+]] = arith.constant
// CHECK-DAG:     %[[C1:.+]] = arith.constant
// CHECK-DAG:     %[[C2:.+]] = arith.constant
// CHECK-DAG:     %[[C12:.+]] = arith.constant
// CHECK:         scf.for %[[ARG3:.+]] = %[[C0]] to %[[C2]] step %[[C1]] {
// CHECK-NEXT:      scf.for %[[ARG4:.+]] = %[[C0]] to %[[C1]] step %[[C1]] {
// CHECK-NEXT:        scf.for %[[ARG5:.+]] = %[[C0]] to %[[C12]] step %[[C1]] {
// CHECK-NEXT:          %[[LOAD0:.+]] = memref.load %[[ARG1]][%[[ARG3]], %[[C0]]] : memref<2x3xi32>
// CHECK-NEXT:          %[[CAST0:.+]] = arith.index_cast %[[LOAD0]] : i32 to index
// CHECK-NEXT:          %[[LOAD1:.+]] = memref.load %[[ARG1]][%[[ARG3]], %[[C1]]] : memref<2x3xi32>
// CHECK-NEXT:          %[[CAST1:.+]] = arith.index_cast %[[LOAD1]] : i32 to index
// CHECK-NEXT:          %[[ADD1:.+]] = arith.addi %[[CAST1]], %[[ARG4]] : index
// CHECK-NEXT:          %[[LOAD2:.+]] = memref.load %[[ARG1]][%[[ARG3]], %[[C2]]] : memref<2x3xi32>
// CHECK-NEXT:          %[[CAST2:.+]] = arith.index_cast %[[LOAD2]] : i32 to index
// CHECK-NEXT:          %[[ADD2:.+]] = arith.addi %[[CAST2]], %[[ARG5]] : index
// CHECK-NEXT:          %[[LOAD3:.+]] = memref.load %[[ARG0]][%[[CAST0]], %[[ADD1]], %[[ADD2]]] : memref<2x64x12xf32>
// CHECK-NEXT:          memref.store %[[LOAD3]], %[[ARG0]][%[[CAST0]], %[[ADD1]], %[[ADD2]]] : memref<2x64x12xf32>