mirror of https://github.com/llvm/torch-mlir
Support select_last_index attribute of onnx argmax op (#3192)
The tests listed in https://github.com/nod-ai/SHARK-Turbine/issues/635 all compiled, but having run issue of dtype mismatch of i/si.pull/3212/head
parent
ddb29c2c02
commit
61e6312c87
|
@ -101,17 +101,11 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
binder.s64BoolAttr(selectLastIndex, "select_last_index", false))
|
||||
return failure();
|
||||
|
||||
if (selectLastIndex) {
|
||||
// TODO: Figure out how to support this case. Need to add a reverse
|
||||
// or something.
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "unsupported conversion: select_last_index=true");
|
||||
}
|
||||
|
||||
// ONNX allows negative axis.
|
||||
auto operandSizes =
|
||||
cast<Torch::ValueTensorType>(operand.getType()).getSizes();
|
||||
if (axis < 0)
|
||||
axis +=
|
||||
cast<Torch::ValueTensorType>(operand.getType()).getSizes().size();
|
||||
axis += operandSizes.size();
|
||||
|
||||
Value constAxis = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
|
@ -119,6 +113,26 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
Value constKeepDims = rewriter.create<Torch::ConstantBoolOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::BoolType>(),
|
||||
rewriter.getBoolAttr(keepDims));
|
||||
|
||||
if (selectLastIndex) {
|
||||
Value dims = createConstantIntList(binder, rewriter, {axis});
|
||||
auto operandTy = dyn_cast<Torch::ValueTensorType>(operand.getType());
|
||||
operand = rewriter.create<Torch::AtenFlipOp>(
|
||||
binder.getLoc(), operandTy, operand, dims);
|
||||
Value argmax = rewriter.create<Torch::AtenArgmaxOp>(
|
||||
binder.getLoc(), resultType, operand, constAxis, constKeepDims);
|
||||
Value offset = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getI64IntegerAttr(operandSizes[axis] - 1));
|
||||
Value alpha = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(1));
|
||||
Value sub = rewriter.create<Torch::AtenSubScalarOp>(
|
||||
binder.getLoc(), resultType, argmax, offset, alpha);
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenAbsOp>(binder.op, resultType,
|
||||
sub);
|
||||
return success();
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenArgmaxOp>(
|
||||
binder.op, resultType, operand, constAxis, constKeepDims);
|
||||
return success();
|
||||
|
|
|
@ -74,6 +74,24 @@ func.func @test_argmax_negative_axis_keepdims_example(%arg0: !torch.vtensor<[2,2
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_argmax_negative_axis_keepdims_random_select_last_index
|
||||
func.func @test_argmax_negative_axis_keepdims_random_select_last_index(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[C2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[C2_0:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[C2_0]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[FLIP:.*]] = torch.aten.flip %arg0, %[[DIMS]] : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[2,3,4],f32>
|
||||
// CHECK: %[[ARGMAX:.*]] = torch.aten.argmax %[[FLIP]], %[[C2]], %[[TRUE]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,3,1],si64>
|
||||
// CHECK: %[[C3:.*]] = torch.constant.int 3
|
||||
// CHECK: %[[C1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[SUB:.*]] = torch.aten.sub.Scalar %[[ARGMAX]], %[[C3]], %[[C1]] : !torch.vtensor<[2,3,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,3,1],si64>
|
||||
// CHECK: %[[ABS:.*]] = torch.aten.abs %[[SUB]] : !torch.vtensor<[2,3,1],si64> -> !torch.vtensor<[2,3,1],si64>
|
||||
%0 = torch.operator "onnx.ArgMax"(%arg0) {torch.onnx.axis = -1 : si64, torch.onnx.keepdims = 1 : si64, torch.onnx.select_last_index = 1 : si64} : (!torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,1],si64>
|
||||
return %0 : !torch.vtensor<[2,3,1],si64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_argmax_no_keepdims_example
|
||||
func.func @test_argmax_no_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT:.*]] = torch.constant.int 1
|
||||
|
@ -85,6 +103,24 @@ func.func @test_argmax_no_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) ->
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_argmax_no_keepdims_random_select_last_index
|
||||
func.func @test_argmax_no_keepdims_random_select_last_index(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,4],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[C1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[C1_0:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[C1_0]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[FLIP:.*]] = torch.aten.flip %arg0, %[[DIMS]] : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[2,3,4],f32>
|
||||
// CHECK: %[[ARGMAX:.*]] = torch.aten.argmax %[[FLIP]], %[[C1]], %[[FALSE]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,4],si64>
|
||||
// CHECK: %[[C2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[C1_1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[SUB:.*]] = torch.aten.sub.Scalar %[[ARGMAX]], %[[C2]], %[[C1_1]] : !torch.vtensor<[2,4],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4],si64>
|
||||
// CHECK: %[[ABS:.*]] = torch.aten.abs %[[SUB]] : !torch.vtensor<[2,4],si64> -> !torch.vtensor<[2,4],si64>
|
||||
%0 = torch.operator "onnx.ArgMax"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64, torch.onnx.select_last_index = 1 : si64} : (!torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,4],si64>
|
||||
return %0 : !torch.vtensor<[2,4],si64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_argmin_default_axis_example
|
||||
func.func @test_argmin_default_axis_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[1,2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT:.*]] = torch.constant.int 0
|
||||
|
|
|
@ -1,15 +1,5 @@
|
|||
// RUN: torch-mlir-opt <%s -split-input-file -verify-diagnostics -convert-torch-onnx-to-torch
|
||||
|
||||
module {
|
||||
func.func @test_argmax_no_keepdims_random_select_last_index(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,4],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// TODO: Unsupported torch.onnx.select_last_index
|
||||
// expected-error @+1 {{failed to legalize operation 'torch.operator'}}
|
||||
%0 = torch.operator "onnx.ArgMax"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64, torch.onnx.select_last_index = 1 : si64} : (!torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,4],si64>
|
||||
return %0 : !torch.vtensor<[2,4],si64>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
func.func @test_argmin_no_keepdims_example_select_last_index(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// TODO: Unsupported torch.onnx.select_last_index
|
||||
// expected-error @+1 {{failed to legalize operation 'torch.operator'}}
|
||||
|
|
Loading…
Reference in New Issue