mirror of https://github.com/llvm/torch-mlir
[onnx] Simplify onnx.slice lowering (#2919)
Onnx slice lowering used arange needlessly instead of directly constructing the constant dimension values. This makes lowerings to linalg struggle as multiple folders are required to get what is a constant index value.pull/2924/merge
parent
fd08578bdb
commit
cea51897a5
|
@ -1540,15 +1540,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
if (binder.tensorOperandAtIndex(axes, 3)) {
|
||||
return failure();
|
||||
}
|
||||
} else {
|
||||
// The default axes value is the range from 0 to the size of first
|
||||
// dimension of `starts` and `ends`.
|
||||
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
|
||||
Value arangeLength = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), startSize));
|
||||
axes = rewriter.create<Torch::AtenArangeOp>(
|
||||
loc, startsTorchTy, arangeLength, none, none, none, none);
|
||||
}
|
||||
|
||||
// Binding `steps` from its arguments or through a default value
|
||||
|
@ -1579,6 +1570,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
binder.op, "Expected the rank of starts and ends tensors to be 1 "
|
||||
"and their dimensions to match");
|
||||
|
||||
if (axes) {
|
||||
auto axesTorchTy = axes.getType().cast<Torch::ValueTensorType>();
|
||||
auto axesTy =
|
||||
axesTorchTy.toBuiltinTensor().dyn_cast<RankedTensorType>();
|
||||
|
@ -1587,6 +1579,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
if (!(axesTy && numAxes == endSize))
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "Axes should be the same size of starts and ends");
|
||||
}
|
||||
|
||||
auto stepsTy = steps.getType()
|
||||
.cast<Torch::ValueTensorType>()
|
||||
|
@ -1622,7 +1615,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
}
|
||||
auto intermediateType = Torch::ValueTensorType::get(
|
||||
context, intermediateShape, resultTorchType.getOptionalDtype());
|
||||
for (int i = 0; i < numAxes; ++i) {
|
||||
for (int i = 0; i < endSize; ++i) {
|
||||
|
||||
Value k = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getType<Torch::IntType>(),
|
||||
|
@ -1636,12 +1629,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
|
||||
Value start = select(starts, kTensor);
|
||||
Value end = select(ends, kTensor);
|
||||
Value axis = select(axes, kTensor);
|
||||
Value axis = axes ? select(axes, kTensor) : k;
|
||||
Value step = select(steps, kTensor);
|
||||
|
||||
auto sliceType = intermediateType;
|
||||
if (i == numAxes - 1)
|
||||
sliceType = resultTorchType;
|
||||
sliceType = i == (endSize - 1) ? resultTorchType : sliceType;
|
||||
operand = rewriter.create<Torch::AtenSliceTensorOp>(
|
||||
loc, sliceType, operand, axis, start, end, step);
|
||||
}
|
||||
|
|
|
@ -2101,10 +2101,6 @@ ONNX_XFAIL_SET = {
|
|||
"ReduceMaxSignedIntModule_basic",
|
||||
"ReduceMaxUnsignedIntModule_basic",
|
||||
|
||||
# Failure - slice_lowering
|
||||
"ScaledDotProductAttentionDifferentModule_basic",
|
||||
"ScaledDotProductAttentionSameModule_basic",
|
||||
|
||||
# Failure - view_lowering
|
||||
"AddSizeIntModule_basic",
|
||||
"ElementwiseFlattenBroadcastModule_basic",
|
||||
|
|
|
@ -1157,9 +1157,6 @@ func.func @test_slice(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1: !torch.vtenso
|
|||
func.func @test_slice_default_axes_and_slices(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1: !torch.vtensor<[3],si64>, %arg2: !torch.vtensor<[3],si64>) -> !torch.vtensor<[20,10,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
//CHECK: %[[NONE_1:.*]] = torch.constant.none
|
||||
//CHECK: %[[AXES_DEFAULT_SIZE:.*]] = torch.constant.int 3
|
||||
//CHECK: %[[DEFAULT_AXES:.*]] = torch.aten.arange %[[AXES_DEFAULT_SIZE:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]] : !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64>
|
||||
//CHECK: %[[NONE_2:.*]] = torch.constant.none
|
||||
//CHECK: %[[DEFAULT_SIZE_AMOUNT:.*]] = torch.constant.int 3
|
||||
//CHECK: %[[DEFAULT_SIZE_INPUT:.*]] = torch.prim.ListConstruct %[[DEFAULT_SIZE_AMOUNT:.*]] : (!torch.int) -> !torch.list<int>
|
||||
//CHECK: %[[DEFAULT_SIZES:.*]] = torch.aten.ones %[[DEFAULT_SIZE_INPUT:.*]], %[[NONE_2:.*]], %[[NONE_2:.*]], %[[NONE_2:.*]], %[[NONE_2:.*]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64>
|
||||
//CHECK: %[[INDEX_TO_GRAB:.*]] = torch.constant.int 0
|
||||
|
@ -1170,11 +1167,9 @@ func.func @test_slice_default_axes_and_slices(%arg0: !torch.vtensor<[20,10,5],f3
|
|||
//CHECK: %[[STARTS_ELEMENT_0:.*]] = torch.aten.item %[[STARTS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
//CHECK: %[[ENDS_INDEX_VEC_0:.*]] = torch.aten.index_select %arg2, %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
//CHECK: %[[ENDS_ELEMENT_0:.*]] = torch.aten.item %[[ENDS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
//CHECK: %[[AXES_INDEX_VEC_0:.*]] = torch.aten.index_select %[[DEFAULT_AXES:.*]], %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
//CHECK: %[[AXES_ELEMENT_0:.*]] = torch.aten.item %[[AXES_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
//CHECK: %[[STEPS_INDEX_VEC_0:.*]] = torch.aten.index_select %[[DEFAULT_SIZES:.*]], %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
//CHECK: %[[STEPS_ELEMENT_0:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
//CHECK: %[[SLICE_0:.*]] = torch.aten.slice.Tensor %arg0, %[[AXES_ELEMENT_0:.*]], %[[STARTS_ELEMENT_0:.*]], %[[ENDS_ELEMENT_0:.*]], %[[STEPS_ELEMENT_0:.*]] : !torch.vtensor<[20,10,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,?],f32>
|
||||
//CHECK: %[[SLICE_0:.*]] = torch.aten.slice.Tensor %arg0, %[[CONST_0:.*]], %[[STARTS_ELEMENT_0:.*]], %[[ENDS_ELEMENT_0:.*]], %[[STEPS_ELEMENT_0:.*]] : !torch.vtensor<[20,10,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,?],f32>
|
||||
|
||||
//CHECK: %[[CONST_1:.*]] = torch.constant.int 1
|
||||
//CHECK: %[[ONE_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_1:.*]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
|
@ -1182,11 +1177,9 @@ func.func @test_slice_default_axes_and_slices(%arg0: !torch.vtensor<[20,10,5],f3
|
|||
//CHECK: %[[STARTS_ELEMENT_1:.*]] = torch.aten.item %[[STARTS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
//CHECK: %[[ENDS_INDEX_VEC_1:.*]] = torch.aten.index_select %arg2, %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
//CHECK: %[[ENDS_ELEMENT_1:.*]] = torch.aten.item %[[ENDS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
//CHECK: %[[AXES_INDEX_VEC_1:.*]] = torch.aten.index_select %[[DEFAULT_AXES:.*]], %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
//CHECK: %[[AXES_ELEMENT_1:.*]] = torch.aten.item %[[AXES_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
//CHECK: %[[STEPS_INDEX_VEC_1:.*]] = torch.aten.index_select %[[DEFAULT_SIZES:.*]], %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
//CHECK: %[[STEPS_ELEMENT_1:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
//CHECK: %[[TWO_INDEX_VEC:.*]] = torch.aten.slice.Tensor %[[SLICE_0:.*]], %[[AXES_ELEMENT_1:.*]], %[[STARTS_ELEMENT_1:.*]], %[[ENDS_ELEMENT_1:.*]], %[[STEPS_ELEMENT_1:.*]] : !torch.vtensor<[20,10,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,?],f32>
|
||||
//CHECK: %[[TWO_INDEX_VEC:.*]] = torch.aten.slice.Tensor %[[SLICE_0:.*]], %[[CONST_1:.*]], %[[STARTS_ELEMENT_1:.*]], %[[ENDS_ELEMENT_1:.*]], %[[STEPS_ELEMENT_1:.*]] : !torch.vtensor<[20,10,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,?],f32>
|
||||
|
||||
//CHECK: %[[CONST_2:.*]] = torch.constant.int 2
|
||||
//CHECK: %[[TWO_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_2:.*]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
|
@ -1194,11 +1187,9 @@ func.func @test_slice_default_axes_and_slices(%arg0: !torch.vtensor<[20,10,5],f3
|
|||
//CHECK: %[[STARTS_ELEMENT_2:.*]] = torch.aten.item %[[STARTS_INDEX_VEC_2:.*]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
//CHECK: %[[ENDS_INDEX_VEC_2:.*]] = torch.aten.index_select %arg2, %[[INDEX_TO_GRAB:.*]], %[[TWO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
//CHECK: %[[ENDS_ELEMENT_2:.*]] = torch.aten.item %[[ENDS_INDEX_VEC_2:.*]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
//CHECK: %[[AXES_INDEX_VEC_2:.*]] = torch.aten.index_select %[[DEFAULT_AXES:.*]], %[[INDEX_TO_GRAB:.*]], %[[TWO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
//CHECK: %[[AXES_ELEMENT_2:.*]] = torch.aten.item %[[AXES_INDEX_VEC_2:.*]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
//CHECK: %[[STEPS_INDEX_VEC_2:.*]] = torch.aten.index_select %[[DEFAULT_SIZES:.*]], %[[INDEX_TO_GRAB:.*]], %[[TWO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
//CHECK: %[[STEPS_ELEMENT_2:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_2:.*]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
//CHECK: torch.aten.slice.Tensor %[[TWO_INDEX_VEC:.*]], %[[AXES_ELEMENT_2:.*]], %[[STARTS_ELEMENT_2:.*]], %[[ENDS_ELEMENT_2:.*]], %[[STEPS_ELEMENT_2:.*]] : !torch.vtensor<[20,10,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,1],f32>
|
||||
//CHECK: torch.aten.slice.Tensor %[[TWO_INDEX_VEC:.*]], %[[CONST_2:.*]], %[[STARTS_ELEMENT_2:.*]], %[[ENDS_ELEMENT_2:.*]], %[[STEPS_ELEMENT_2:.*]] : !torch.vtensor<[20,10,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,1],f32>
|
||||
%0 = torch.operator "onnx.Slice"(%arg0, %arg1, %arg2) : (!torch.vtensor<[20,10,5],f32>, !torch.vtensor<[3],si64>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[20,10,1],f32>
|
||||
return %0 : !torch.vtensor<[20,10,1],f32>
|
||||
}
|
||||
|
@ -1211,17 +1202,15 @@ func.func @test_slice_default_axes_and_slices(%arg0: !torch.vtensor<[20,10,5],f3
|
|||
// CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor<[1],si64>
|
||||
|
||||
// CHECK: %[[ZERO0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[ZERO1:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[SCALAR:.*]] = torch.prim.NumToTensor.Scalar %[[ZERO1]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[SELECT0:.*]] = torch.aten.index_select %[[ARG1]], %[[ZERO]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[ITEM0:.*]] = torch.aten.item %[[SELECT0]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: %[[SELECT1:.*]] = torch.aten.index_select %[[ARG2]], %[[ZERO]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[ITEM1:.*]] = torch.aten.item %[[SELECT1]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: %[[SELECT2:.*]] = torch.aten.index_select %{{.*}}, %[[ZERO]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[ITEM2:.*]] = torch.aten.item %[[SELECT2]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: %[[SELECT3:.*]] = torch.aten.index_select %{{.*}}, %[[ZERO]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[ITEM3:.*]] = torch.aten.item %[[SELECT3]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.slice.Tensor %[[ARG0]], %[[ITEM2]], %[[ITEM0]], %[[ITEM1]], %[[ITEM3]] : !torch.vtensor<[20,10,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,1],f32>
|
||||
// CHECK-NEXT: %[[ZERO1:.*]] = torch.constant.int 0
|
||||
// CHECK-NEXT: %[[SCALAR:.*]] = torch.prim.NumToTensor.Scalar %[[ZERO1]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK-NEXT: %[[SELECT0:.*]] = torch.aten.index_select %[[ARG1]], %[[ZERO]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK-NEXT: %[[ITEM0:.*]] = torch.aten.item %[[SELECT0]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK-NEXT: %[[SELECT1:.*]] = torch.aten.index_select %[[ARG2]], %[[ZERO]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK-NEXT: %[[ITEM1:.*]] = torch.aten.item %[[SELECT1]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK-NEXT: %[[SELECT3:.*]] = torch.aten.index_select %{{.*}}, %[[ZERO]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK-NEXT: %[[ITEM3:.*]] = torch.aten.item %[[SELECT3]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.slice.Tensor %[[ARG0]], %[[ZERO1]], %[[ITEM0]], %[[ITEM1]], %[[ITEM3]] : !torch.vtensor<[20,10,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,1],f32>
|
||||
|
||||
func.func @test_slice_default_axes_and_steps(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1: !torch.vtensor<[1],si64>, %arg2: !torch.vtensor<[1],si64>) -> !torch.vtensor<[20,10,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
|
||||
%0 = torch.operator "onnx.Slice"(%arg0, %arg1, %arg2) : (!torch.vtensor<[20,10,5],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[20,10,1],f32>
|
||||
|
|
Loading…
Reference in New Issue