Support select_last_index attribute of onnx argmax op (#3192)

The tests listed in https://github.com/nod-ai/SHARK-Turbine/issues/635
all compile now, but at runtime they hit a dtype mismatch issue between i and si integer types.
pull/3212/head
jinchen 2024-04-23 10:16:08 -07:00 committed by GitHub
parent ddb29c2c02
commit 61e6312c87
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 59 additions and 19 deletions

View File

@ -101,17 +101,11 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.s64BoolAttr(selectLastIndex, "select_last_index", false))
return failure();
if (selectLastIndex) {
// TODO: Figure out how to support this case. Need to add a reverse
// or something.
return rewriter.notifyMatchFailure(
binder.op, "unsupported conversion: select_last_index=true");
}
// ONNX allows negative axis.
auto operandSizes =
cast<Torch::ValueTensorType>(operand.getType()).getSizes();
if (axis < 0)
axis +=
cast<Torch::ValueTensorType>(operand.getType()).getSizes().size();
axis += operandSizes.size();
Value constAxis = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
@ -119,6 +113,26 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
Value constKeepDims = rewriter.create<Torch::ConstantBoolOp>(
binder.getLoc(), rewriter.getType<Torch::BoolType>(),
rewriter.getBoolAttr(keepDims));
if (selectLastIndex) {
Value dims = createConstantIntList(binder, rewriter, {axis});
auto operandTy = dyn_cast<Torch::ValueTensorType>(operand.getType());
operand = rewriter.create<Torch::AtenFlipOp>(
binder.getLoc(), operandTy, operand, dims);
Value argmax = rewriter.create<Torch::AtenArgmaxOp>(
binder.getLoc(), resultType, operand, constAxis, constKeepDims);
Value offset = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(),
rewriter.getI64IntegerAttr(operandSizes[axis] - 1));
Value alpha = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
Value sub = rewriter.create<Torch::AtenSubScalarOp>(
binder.getLoc(), resultType, argmax, offset, alpha);
rewriter.replaceOpWithNewOp<Torch::AtenAbsOp>(binder.op, resultType,
sub);
return success();
}
rewriter.replaceOpWithNewOp<Torch::AtenArgmaxOp>(
binder.op, resultType, operand, constAxis, constKeepDims);
return success();

View File

@ -74,6 +74,24 @@ func.func @test_argmax_negative_axis_keepdims_example(%arg0: !torch.vtensor<[2,2
// -----
// CHECK-LABEL: @test_argmax_negative_axis_keepdims_random_select_last_index
func.func @test_argmax_negative_axis_keepdims_random_select_last_index(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[C2:.*]] = torch.constant.int 2
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[C2_0:.*]] = torch.constant.int 2
// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[C2_0]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[FLIP:.*]] = torch.aten.flip %arg0, %[[DIMS]] : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[2,3,4],f32>
// CHECK: %[[ARGMAX:.*]] = torch.aten.argmax %[[FLIP]], %[[C2]], %[[TRUE]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,3,1],si64>
// CHECK: %[[C3:.*]] = torch.constant.int 3
// CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: %[[SUB:.*]] = torch.aten.sub.Scalar %[[ARGMAX]], %[[C3]], %[[C1]] : !torch.vtensor<[2,3,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,3,1],si64>
// CHECK: %[[ABS:.*]] = torch.aten.abs %[[SUB]] : !torch.vtensor<[2,3,1],si64> -> !torch.vtensor<[2,3,1],si64>
%0 = torch.operator "onnx.ArgMax"(%arg0) {torch.onnx.axis = -1 : si64, torch.onnx.keepdims = 1 : si64, torch.onnx.select_last_index = 1 : si64} : (!torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1],si64>
return %0 : !torch.vtensor<[2,3,1],si64>
}
// -----
// CHECK-LABEL: @test_argmax_no_keepdims_example
func.func @test_argmax_no_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT:.*]] = torch.constant.int 1
@ -85,6 +103,24 @@ func.func @test_argmax_no_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) ->
// -----
// CHECK-LABEL: @test_argmax_no_keepdims_random_select_last_index
func.func @test_argmax_no_keepdims_random_select_last_index(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,4],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[C1_0:.*]] = torch.constant.int 1
// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[C1_0]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[FLIP:.*]] = torch.aten.flip %arg0, %[[DIMS]] : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[2,3,4],f32>
// CHECK: %[[ARGMAX:.*]] = torch.aten.argmax %[[FLIP]], %[[C1]], %[[FALSE]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,4],si64>
// CHECK: %[[C2:.*]] = torch.constant.int 2
// CHECK: %[[C1_1:.*]] = torch.constant.int 1
// CHECK: %[[SUB:.*]] = torch.aten.sub.Scalar %[[ARGMAX]], %[[C2]], %[[C1_1]] : !torch.vtensor<[2,4],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4],si64>
// CHECK: %[[ABS:.*]] = torch.aten.abs %[[SUB]] : !torch.vtensor<[2,4],si64> -> !torch.vtensor<[2,4],si64>
%0 = torch.operator "onnx.ArgMax"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64, torch.onnx.select_last_index = 1 : si64} : (!torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,4],si64>
return %0 : !torch.vtensor<[2,4],si64>
}
// -----
// CHECK-LABEL: @test_argmin_default_axis_example
func.func @test_argmin_default_axis_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[1,2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT:.*]] = torch.constant.int 0

View File

@ -1,15 +1,5 @@
// RUN: torch-mlir-opt <%s -split-input-file -verify-diagnostics -convert-torch-onnx-to-torch
module {
func.func @test_argmax_no_keepdims_random_select_last_index(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,4],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// TODO: Unsupported torch.onnx.select_last_index
// expected-error @+1 {{failed to legalize operation 'torch.operator'}}
%0 = torch.operator "onnx.ArgMax"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64, torch.onnx.select_last_index = 1 : si64} : (!torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,4],si64>
return %0 : !torch.vtensor<[2,4],si64>
}
}
// -----
func.func @test_argmin_no_keepdims_example_select_last_index(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// TODO: Unsupported torch.onnx.select_last_index
// expected-error @+1 {{failed to legalize operation 'torch.operator'}}