[torch] Add OnnxToTorch lowering for Onnx.Unique op (#3523)

Adds OnnxToTorch lowering for the `Onnx.Unique` op.
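
For context, `onnx.Unique` returns four outputs (`Y`, `indices`, `inverse_indices`, `counts`), while PyTorch's `torch.unique` covers only three of them: the first-occurrence `indices` output has no direct Torch equivalent, which is what most of the lowering below synthesizes. A minimal sketch of that gap in plain PyTorch (illustration only, not part of this patch):

```python
import torch

x = torch.tensor([2.0, 1.0, 1.0, 3.0, 4.0, 3.0])
# torch.unique produces Y, inverse_indices, and counts...
y, inverse, counts = torch.unique(
    x, sorted=True, return_inverse=True, return_counts=True)
print(y)        # tensor([1., 2., 3., 4.])
print(inverse)  # tensor([1, 0, 0, 2, 3, 2])
print(counts)   # tensor([2, 1, 2, 1])
# ...but onnx.Unique additionally returns the index of the first
# occurrence of each unique value, which for this input is [1, 0, 3, 4].
```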
Vinayak Dev 2024-07-29 17:32:44 +05:30 committed by GitHub
parent a211ccbcff
commit 30c4d2f2b8
5 changed files with 293 additions and 0 deletions


@@ -284,6 +284,16 @@ struct OpBinder {
    return failure();
  }

  // Succeeds only when the attribute is actually present, letting callers
  // distinguish an omitted attribute from one explicitly set to a default.
  ParseResult optionalS64IntegerAttr(int64_t &value, StringRef nameSuffix) {
    SmallString<64> name("torch.onnx.");
    name.append(nameSuffix);
    auto attr = op->getAttr(name);
    if (!attr) {
      return failure();
    }
    return s64IntegerAttr(value, nameSuffix);
  }

  ParseResult f32FloatAttr(float &value, StringRef nameSuffix,
                           float defaultValue = 0.0f) {
    SmallString<64> name("torch.onnx.");


@@ -12784,6 +12784,35 @@ def Torch_AtenUniqueConsecutiveOp : Torch_Op<"aten.unique_consecutive", [
  }];
}

def Torch_AtenUniqueDimOp : Torch_Op<"aten.unique_dim", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::unique_dim : (Tensor, int, bool, bool, bool) -> (Tensor, Tensor, Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    Torch_IntType:$dim,
    Torch_BoolType:$sorted,
    Torch_BoolType:$return_inverse,
    Torch_BoolType:$return_counts
  );
  let results = (outs
    AnyTorchOptionalTensorType:$result0,
    AnyTorchOptionalTensorType:$result1,
    AnyTorchOptionalTensorType:$result2
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenUniqueDimOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 5, 3);
    }
    void AtenUniqueDimOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 5, 3);
    }
  }];
}

def Torch_AtenLinspaceOp : Torch_Op<"aten.linspace", [
    AllowsTypeRefinement,
    HasValueSemantics,


@@ -4095,4 +4095,122 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
            binder.op, "Handling of this kind of inputs is not there");
        }
      });
  patterns.onOp(
      "Unique", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Value input;
        int64_t axis, sorted;
        SmallVector<Type> resultTypes;
        if (binder.tensorOperand(input) ||
            binder.s64IntegerAttr(sorted, "sorted", 1) ||
            binder.tensorResultTypes(resultTypes))
          return failure();
        Value zero =
            rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), 0);
        auto inputTy = cast<Torch::ValueTensorType>(input.getType());
        if (!inputTy.hasSizes()) {
          return rewriter.notifyMatchFailure(
              binder.op, "Expected input type to have sizes");
        }
        auto inputShape = inputTy.getSizes();
        int64_t inputDim = static_cast<int64_t>(inputShape.size());
        Value axisVal;
        SmallVector<int64_t> outputTensorSizes(inputDim);
        bool axisWasNone;
        if (!binder.optionalS64IntegerAttr(axis, "axis")) {
          if (axis < -1 * inputDim || axis > inputDim - 1)
            return rewriter.notifyMatchFailure(binder.op,
                                               "invalid value for axis");
          axisVal = rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), rewriter.getI64IntegerAttr(axis));
          axisWasNone = false;
        } else {
          axisVal = zero;
          axisWasNone = true;
        }
        Value sortedVal = rewriter.create<Torch::ConstantBoolOp>(
            binder.getLoc(), rewriter.getBoolAttr(sorted));
        Value trueVal =
            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);

        // The shape of inverse_indices is the same as the input shape, but
        // resultTypes[2] must be used to avoid leaving a live value after
        // conversion.
        Torch::ValueTensorType outputTy =
            cast<Torch::ValueTensorType>(resultTypes[0]);
        Torch::ValueTensorType countsTy =
            cast<Torch::ValueTensorType>(resultTypes[3]);
        Torch::ValueTensorType inverseTy =
            cast<Torch::ValueTensorType>(resultTypes[2]);

        if (axisWasNone) {
          // With no axis, ONNX treats the input as flat, so flatten it
          // before calling aten.unique_dim on dimension zero.
          int64_t inputNumel = 1;
          for (auto elem : inputShape) {
            if (elem == Torch::kUnknownSize) {
              return rewriter.notifyMatchFailure(
                  binder.op,
                  "Expected all sizes in input shape to be statically known");
            }
            inputNumel *= elem;
          }
          auto flattenResultTy = rewriter.getType<Torch::ValueTensorType>(
              ArrayRef({inputNumel}), inputTy.getDtype());
          Value negativeOne =
              rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), -1);
          input = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
              binder.getLoc(), flattenResultTy, input, zero, negativeOne);
        }

        Torch::AtenUniqueDimOp intermResults =
            rewriter.create<Torch::AtenUniqueDimOp>(
                binder.getLoc(), outputTy, inverseTy, countsTy, input,
                axisVal, sortedVal, trueVal, trueVal);
        SmallVector<Value> uniqueResults = intermResults.getResults();

        // Calculate the indices where each unique element first appears in
        // the original input tensor. The counts tensor and the indices
        // tensor share the same dtype (int64), so reuse it here.
        auto arangeResultType = rewriter.getType<Torch::ValueTensorType>(
            ArrayRef<int64_t>({inputShape[0]}), countsTy.getOptionalDtype());
        Value inputDimZero = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(inputShape[0]));
        // `4` is the torch dtype code for int64.
        Value int64Type = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(4));
        Value noneVal =
            rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
        Value perm = rewriter.create<Torch::AtenArangeOp>(
            binder.getLoc(), arangeResultType, inputDimZero,
            /*dtype=*/int64Type,
            /*layout=*/noneVal, /*device=*/noneVal, /*pin_memory=*/noneVal);

        // Inverse has the same shape as the input, but its dtype differs.
        // Flip both the inverse map and the position tensor so that, after
        // the scatter below, the first occurrence of each element wins.
        Value flipDims = createConstantIntList(binder, rewriter, {0});
        Value inverse = rewriter.create<Torch::AtenFlipOp>(
            binder.getLoc(),
            inputTy.getWithSizesAndDtype(inputShape, countsTy.getDtype()),
            uniqueResults[1], flipDims);
        perm = rewriter.create<Torch::AtenFlipOp>(
            binder.getLoc(), cast<Torch::ValueTensorType>(perm.getType()),
            perm, flipDims);
        auto newInverseTy = rewriter.getType<Torch::ValueTensorType>(
            ArrayRef<int64_t>({outputTy.getSizes()[0]}), countsTy.getDtype());
        Value newInverseSize =
            createConstantIntList(binder, rewriter, {outputTy.getSizes()[0]});
        Value newInverse = rewriter.create<Torch::AtenNewEmptyOp>(
            binder.getLoc(), newInverseTy, inverse, newInverseSize,
            /*dtype=*/int64Type, /*layout=*/noneVal, /*device=*/noneVal,
            /*pin_memory=*/noneVal);
        Value firstOccurIndices = rewriter.create<Torch::AtenScatterSrcOp>(
            binder.getLoc(), resultTypes[1], newInverse, zero, inverse,
            perm);
        rewriter.replaceOp(binder.op, {uniqueResults[0], firstOccurIndices,
                                       uniqueResults[1], uniqueResults[2]});
        return success();
      });
}
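
The first-occurrence indices above are recovered with a flip-and-scatter trick: scatter writes colliding entries in index order, so flipping both the inverse map and the position tensor makes the earliest occurrence the last write, and it wins. A minimal PyTorch sketch of the same idea (variable names are mine, not the lowering's):

```python
import torch

x = torch.tensor([2.0, 1.0, 1.0, 3.0, 4.0, 3.0])
y, inverse, counts = torch.unique(
    x, sorted=True, return_inverse=True, return_counts=True)

# Positions 0..n-1 of the original elements (the arange in the lowering).
perm = torch.arange(x.numel(), dtype=torch.int64)
# Flip both tensors: when several positions scatter into the same unique
# slot, the first occurrence is now written last and overwrites the rest.
inv_flipped = inverse.flip(0)
perm_flipped = perm.flip(0)
first = torch.empty(y.numel(), dtype=torch.int64)
first.scatter_(0, inv_flipped, perm_flipped)
print(first)  # tensor([1, 0, 3, 4])
```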


@@ -936,6 +936,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
    emit(
        "aten::unique_consecutive : (Tensor, bool, bool, int?) -> (Tensor, Tensor, Tensor)"
    )
    emit(
        "aten::unique_dim : (Tensor, int, bool, bool, bool) -> (Tensor, Tensor, Tensor)"
    )
    emit(
        "aten::linspace : (Scalar, Scalar, int, int?, int?, Device?, bool?) -> (Tensor)"
    )
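
The signature string registered here is what the generator turns into the `Torch_AtenUniqueDimOp` definition above; the operand and result counts it encodes are also where the `5, 3` arguments to `parseDefaultTorchOp`/`printDefaultTorchOp` come from. A rough, hypothetical illustration of that counting (not the actual generator code, which works from the JIT operator registry):

```python
# Hypothetical sketch: derive operand/result counts from the signature
# string passed to emit(...).
sig = "aten::unique_dim : (Tensor, int, bool, bool, bool) -> (Tensor, Tensor, Tensor)"
name, types = sig.split(" : ")
operands, results = (t.strip("()").split(",") for t in types.split(" -> "))
print(name, len(operands), len(results))  # aten::unique_dim 5 3
```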


@@ -3161,3 +3161,136 @@ func.func @test_split_to_sequence_2(%arg0: !torch.vtensor<[2,6],f32>, %arg1: !to
  %0 = torch.operator "onnx.SplitToSequence"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4,6],f32>, !torch.vtensor<[2],si64>) -> !torch.list<vtensor<[2,6],f32>>
  return %0 : !torch.list<vtensor<[2,6],f32>>
}

// -----

// CHECK-LABEL: func.func @test_unique_not_sorted_without_axis
func.func @test_unique_not_sorted_without_axis(%arg0: !torch.vtensor<[6],f32>) -> (!torch.vtensor<[4],f32>, !torch.vtensor<[4],si64>, !torch.vtensor<[6],si64>, !torch.vtensor<[4],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[INT0_0:.*]] = torch.constant.int 0
  // CHECK: %[[FALSEVAL:.*]] = torch.constant.bool false
  // CHECK: %[[TRUEVAL:.*]] = torch.constant.bool true
  // CHECK: %[[NEGATIVEONE:.*]] = torch.constant.int -1
  // CHECK: %[[FLATTEN:.*]] = torch.aten.flatten.using_ints %arg0, %[[INT0_0]], %[[NEGATIVEONE]] : !torch.vtensor<[6],f32>, !torch.int, !torch.int -> !torch.vtensor<[6],f32>
  // CHECK: %[[UNIQUEOUTPUT:.*]], %[[INVERSEINDEX:.*]], %[[COUNTS:.*]] = torch.aten.unique_dim %[[FLATTEN]], %[[INT0_0]], %[[FALSEVAL]], %[[TRUEVAL]], %[[TRUEVAL]] : !torch.vtensor<[6],f32>, !torch.int, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[4],f32>, !torch.vtensor<[6],si64>, !torch.vtensor<[4],si64>
  // CHECK: %[[INPUTDIM0:.*]] = torch.constant.int 6
  // CHECK: %[[INT64TYPE:.*]] = torch.constant.int 4
  // CHECK: %[[NONEVAL:.*]] = torch.constant.none
  // CHECK: %[[ARANGE:.*]] = torch.aten.arange %[[INPUTDIM0]], %[[INT64TYPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[6],si64>
  // CHECK: %[[INT0_1:.*]] = torch.constant.int 0
  // CHECK: %[[FLIPDIMS:.*]] = torch.prim.ListConstruct %[[INT0_1]] : (!torch.int) -> !torch.list<int>
  // CHECK: %[[FLIPINVERSE:.*]] = torch.aten.flip %[[INVERSEINDEX]], %[[FLIPDIMS]] : !torch.vtensor<[6],si64>, !torch.list<int> -> !torch.vtensor<[6],si64>
  // CHECK: %[[FLIPPERM:.*]] = torch.aten.flip %[[ARANGE]], %[[FLIPDIMS]] : !torch.vtensor<[6],si64>, !torch.list<int> -> !torch.vtensor<[6],si64>
  // CHECK: %[[OUTPUTDIMZERO:.*]] = torch.constant.int 4
  // CHECK: %[[NEWEMPTYSIZE:.*]] = torch.prim.ListConstruct %[[OUTPUTDIMZERO]] : (!torch.int) -> !torch.list<int>
  // CHECK: %[[NEWEMPTY:.*]] = torch.aten.new_empty %[[FLIPINVERSE]], %[[NEWEMPTYSIZE]], %[[INT64TYPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.vtensor<[6],si64>, !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[4],si64>
  // CHECK: %[[SCATTER:.*]] = torch.aten.scatter.src %[[NEWEMPTY]], %[[INT0_0]], %[[FLIPINVERSE]], %[[FLIPPERM]] : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[6],si64>, !torch.vtensor<[6],si64> -> !torch.vtensor<[4],si64>
  // CHECK: return %[[UNIQUEOUTPUT]], %[[SCATTER]], %[[INVERSEINDEX]], %[[COUNTS]] : !torch.vtensor<[4],f32>, !torch.vtensor<[4],si64>, !torch.vtensor<[6],si64>, !torch.vtensor<[4],si64>
  %0:4 = torch.operator "onnx.Unique"(%arg0) {torch.onnx.sorted = 0 : si64} : (!torch.vtensor<[6],f32>) -> (!torch.vtensor<[4],f32>, !torch.vtensor<[4],si64>, !torch.vtensor<[6],si64>, !torch.vtensor<[4],si64>)
  return %0#0, %0#1, %0#2, %0#3 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],si64>, !torch.vtensor<[6],si64>, !torch.vtensor<[4],si64>
}

// -----

// CHECK-LABEL: func.func @test_unique_sorted_without_axis
func.func @test_unique_sorted_without_axis(%arg0: !torch.vtensor<[6],f32>) -> (!torch.vtensor<[4],f32>, !torch.vtensor<[4],si64>, !torch.vtensor<[6],si64>, !torch.vtensor<[4],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[INT0_0:.*]] = torch.constant.int 0
  // CHECK: %[[TRUEVAL_0:.*]] = torch.constant.bool true
  // CHECK: %[[TRUEVAL_1:.*]] = torch.constant.bool true
  // CHECK: %[[NEGATIVEONE:.*]] = torch.constant.int -1
  // CHECK: %[[FLATTEN:.*]] = torch.aten.flatten.using_ints %arg0, %[[INT0_0]], %[[NEGATIVEONE]] : !torch.vtensor<[6],f32>, !torch.int, !torch.int -> !torch.vtensor<[6],f32>
  // CHECK: %[[UNIQUEOUTPUT:.*]], %[[INVERSEINDEX:.*]], %[[COUNTS:.*]] = torch.aten.unique_dim %[[FLATTEN]], %[[INT0_0]], %[[TRUEVAL_0]], %[[TRUEVAL_1]], %[[TRUEVAL_1]] : !torch.vtensor<[6],f32>, !torch.int, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[4],f32>, !torch.vtensor<[6],si64>, !torch.vtensor<[4],si64>
  // CHECK: %[[INPUTDIM0:.*]] = torch.constant.int 6
  // CHECK: %[[INT64TYPE:.*]] = torch.constant.int 4
  // CHECK: %[[NONEVAL:.*]] = torch.constant.none
  // CHECK: %[[ARANGE:.*]] = torch.aten.arange %[[INPUTDIM0]], %[[INT64TYPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[6],si64>
  // CHECK: %[[INT0_1:.*]] = torch.constant.int 0
  // CHECK: %[[FLIPDIMS:.*]] = torch.prim.ListConstruct %[[INT0_1]] : (!torch.int) -> !torch.list<int>
  // CHECK: %[[FLIPINVERSE:.*]] = torch.aten.flip %[[INVERSEINDEX]], %[[FLIPDIMS]] : !torch.vtensor<[6],si64>, !torch.list<int> -> !torch.vtensor<[6],si64>
  // CHECK: %[[FLIPPERM:.*]] = torch.aten.flip %[[ARANGE]], %[[FLIPDIMS]] : !torch.vtensor<[6],si64>, !torch.list<int> -> !torch.vtensor<[6],si64>
  // CHECK: %[[OUTPUTDIMZERO:.*]] = torch.constant.int 4
  // CHECK: %[[NEWEMPTYSIZE:.*]] = torch.prim.ListConstruct %[[OUTPUTDIMZERO]] : (!torch.int) -> !torch.list<int>
  // CHECK: %[[NEWEMPTY:.*]] = torch.aten.new_empty %[[FLIPINVERSE]], %[[NEWEMPTYSIZE]], %[[INT64TYPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.vtensor<[6],si64>, !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[4],si64>
  // CHECK: %[[SCATTER:.*]] = torch.aten.scatter.src %[[NEWEMPTY]], %[[INT0_0]], %[[FLIPINVERSE]], %[[FLIPPERM]] : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[6],si64>, !torch.vtensor<[6],si64> -> !torch.vtensor<[4],si64>
  // CHECK: return %[[UNIQUEOUTPUT]], %[[SCATTER]], %[[INVERSEINDEX]], %[[COUNTS]] : !torch.vtensor<[4],f32>, !torch.vtensor<[4],si64>, !torch.vtensor<[6],si64>, !torch.vtensor<[4],si64>
  %0:4 = torch.operator "onnx.Unique"(%arg0) : (!torch.vtensor<[6],f32>) -> (!torch.vtensor<[4],f32>, !torch.vtensor<[4],si64>, !torch.vtensor<[6],si64>, !torch.vtensor<[4],si64>)
  return %0#0, %0#1, %0#2, %0#3 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],si64>, !torch.vtensor<[6],si64>, !torch.vtensor<[4],si64>
}

// -----

// CHECK-LABEL: func.func @test_unique_sorted_with_axis_3d
func.func @test_unique_sorted_with_axis_3d(%arg0: !torch.vtensor<[2,4,2],f32>) -> (!torch.vtensor<[2,3,2],f32>, !torch.vtensor<[3],si64>, !torch.vtensor<[4],si64>, !torch.vtensor<[3],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[INT0_0:.*]] = torch.constant.int 0
  // CHECK: %[[INT1:.*]] = torch.constant.int 1
  // CHECK: %[[TRUEVAL_0:.*]] = torch.constant.bool true
  // CHECK: %[[TRUEVAL_1:.*]] = torch.constant.bool true
  // CHECK: %[[UNIQUEOUTPUT:.*]], %[[INVERSEINDEX:.*]], %[[COUNTS:.*]] = torch.aten.unique_dim %arg0, %[[INT1]], %[[TRUEVAL_0]], %[[TRUEVAL_1]], %[[TRUEVAL_1]] : !torch.vtensor<[2,4,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[2,3,2],f32>, !torch.vtensor<[4],si64>, !torch.vtensor<[3],si64>
  // CHECK: %[[INPUTDIM0:.*]] = torch.constant.int 2
  // CHECK: %[[INT64TYPE:.*]] = torch.constant.int 4
  // CHECK: %[[NONEVAL:.*]] = torch.constant.none
  // CHECK: %[[ARANGE:.*]] = torch.aten.arange %[[INPUTDIM0]], %[[INT64TYPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si64>
  // CHECK: %[[INT0_1:.*]] = torch.constant.int 0
  // CHECK: %[[FLIPDIMS:.*]] = torch.prim.ListConstruct %[[INT0_1]] : (!torch.int) -> !torch.list<int>
  // CHECK: %[[FLIPINVERSE:.*]] = torch.aten.flip %[[INVERSEINDEX]], %[[FLIPDIMS]] : !torch.vtensor<[4],si64>, !torch.list<int> -> !torch.vtensor<[2,4,2],si64>
  // CHECK: %[[FLIPPERM:.*]] = torch.aten.flip %[[ARANGE]], %[[FLIPDIMS]] : !torch.vtensor<[2],si64>, !torch.list<int> -> !torch.vtensor<[2],si64>
  // CHECK: %[[OUTPUTDIM0:.*]] = torch.constant.int 2
  // CHECK: %[[NEWEMPTYSIZE:.*]] = torch.prim.ListConstruct %[[OUTPUTDIM0]] : (!torch.int) -> !torch.list<int>
  // CHECK: %[[NEWEMPTY:.*]] = torch.aten.new_empty %[[FLIPINVERSE]], %[[NEWEMPTYSIZE]], %[[INT64TYPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.vtensor<[2,4,2],si64>, !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si64>
  // CHECK: %[[SCATTER:.*]] = torch.aten.scatter.src %[[NEWEMPTY]], %[[INT0_0]], %[[FLIPINVERSE]], %[[FLIPPERM]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[2,4,2],si64>, !torch.vtensor<[2],si64> -> !torch.vtensor<[3],si64>
  // CHECK: return %[[UNIQUEOUTPUT]], %[[SCATTER]], %[[INVERSEINDEX]], %[[COUNTS]] : !torch.vtensor<[2,3,2],f32>, !torch.vtensor<[3],si64>, !torch.vtensor<[4],si64>, !torch.vtensor<[3],si64>
  %0:4 = torch.operator "onnx.Unique"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.sorted = 1 : si64} : (!torch.vtensor<[2,4,2],f32>) -> (!torch.vtensor<[2,3,2],f32>, !torch.vtensor<[3],si64>, !torch.vtensor<[4],si64>, !torch.vtensor<[3],si64>)
  return %0#0, %0#1, %0#2, %0#3 : !torch.vtensor<[2,3,2],f32>, !torch.vtensor<[3],si64>, !torch.vtensor<[4],si64>, !torch.vtensor<[3],si64>
}

// -----

// CHECK-LABEL: func.func @test_unique_sorted_with_axis
func.func @test_unique_sorted_with_axis(%arg0: !torch.vtensor<[3,3],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[INT0_0:.*]] = torch.constant.int 0
  // CHECK: %[[INT0_1:.*]] = torch.constant.int 0
  // CHECK: %[[TRUEVAL:.*]] = torch.constant.bool true
  // CHECK: %[[TRUEVAL_1:.*]] = torch.constant.bool true
  // CHECK: %[[UNIQUEOUTPUT:.*]], %[[INVERSEINDEX:.*]], %[[COUNTS:.*]] = torch.aten.unique_dim %arg0, %[[INT0_1]], %[[TRUEVAL]], %[[TRUEVAL_1]], %[[TRUEVAL_1]] : !torch.vtensor<[3,3],f32>, !torch.int, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[2,3],f32>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64>
  // CHECK: %[[INPUTDIM0:.*]] = torch.constant.int 3
  // CHECK: %[[INT64TYPE:.*]] = torch.constant.int 4
  // CHECK: %[[NONEVAL:.*]] = torch.constant.none
  // CHECK: %[[ARANGE:.*]] = torch.aten.arange %[[INPUTDIM0]], %[[INT64TYPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64>
  // CHECK: %[[INT0_2:.*]] = torch.constant.int 0
  // CHECK: %[[FLIPDIMS:.*]] = torch.prim.ListConstruct %[[INT0_2]] : (!torch.int) -> !torch.list<int>
  // CHECK: %[[FLIPINVERSE:.*]] = torch.aten.flip %[[INVERSEINDEX]], %[[FLIPDIMS]] : !torch.vtensor<[3],si64>, !torch.list<int> -> !torch.vtensor<[3,3],si64>
  // CHECK: %[[FLIPPERM:.*]] = torch.aten.flip %[[ARANGE]], %[[FLIPDIMS]] : !torch.vtensor<[3],si64>, !torch.list<int> -> !torch.vtensor<[3],si64>
  // CHECK: %[[OUTPUTDIM0:.*]] = torch.constant.int 2
  // CHECK: %[[NEWEMPTYSIZE:.*]] = torch.prim.ListConstruct %[[OUTPUTDIM0]] : (!torch.int) -> !torch.list<int>
  // CHECK: %[[NEWEMPTY:.*]] = torch.aten.new_empty %[[FLIPINVERSE]], %[[NEWEMPTYSIZE]], %[[INT64TYPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.vtensor<[3,3],si64>, !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si64>
  // CHECK: %[[SCATTER:.*]] = torch.aten.scatter.src %[[NEWEMPTY]], %[[INT0_0]], %[[FLIPINVERSE]], %[[FLIPPERM]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[3,3],si64>, !torch.vtensor<[3],si64> -> !torch.vtensor<[2],si64>
  // CHECK: return %[[UNIQUEOUTPUT]], %[[SCATTER]], %[[INVERSEINDEX]], %[[COUNTS]] : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64>
  %0:4 = torch.operator "onnx.Unique"(%arg0) {torch.onnx.axis = 0 : si64, torch.onnx.sorted = 1 : si64} : (!torch.vtensor<[3,3],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64>)
  return %0#0, %0#1, %0#2, %0#3 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64>
}

// -----

// CHECK-LABEL: func.func @test_unique_sorted_with_negative_axis
func.func @test_unique_sorted_with_negative_axis(%arg0: !torch.vtensor<[3,3],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[INT0_0:.*]] = torch.constant.int 0
  // CHECK: %[[NEGATIVEONE:.*]] = torch.constant.int -1
  // CHECK: %[[TRUEVAL:.*]] = torch.constant.bool true
  // CHECK: %[[TRUEVAL_1:.*]] = torch.constant.bool true
  // CHECK: %[[UNIQUEOUTPUT:.*]], %[[INVERSEINDEX:.*]], %[[COUNTS:.*]] = torch.aten.unique_dim %arg0, %[[NEGATIVEONE]], %[[TRUEVAL]], %[[TRUEVAL_1]], %[[TRUEVAL_1]] : !torch.vtensor<[3,3],f32>, !torch.int, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[2,3],f32>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64>
  // CHECK: %[[INPUTDIM0:.*]] = torch.constant.int 3
  // CHECK: %[[INT64TYPE:.*]] = torch.constant.int 4
  // CHECK: %[[NONEVAL:.*]] = torch.constant.none
  // CHECK: %[[ARANGE:.*]] = torch.aten.arange %[[INPUTDIM0]], %[[INT64TYPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64>
  // CHECK: %[[INT0_1:.*]] = torch.constant.int 0
  // CHECK: %[[FLIPDIMS:.*]] = torch.prim.ListConstruct %[[INT0_1]] : (!torch.int) -> !torch.list<int>
  // CHECK: %[[FLIPINVERSE:.*]] = torch.aten.flip %[[INVERSEINDEX]], %[[FLIPDIMS]] : !torch.vtensor<[3],si64>, !torch.list<int> -> !torch.vtensor<[3,3],si64>
  // CHECK: %[[FLIPPERM:.*]] = torch.aten.flip %[[ARANGE]], %[[FLIPDIMS]] : !torch.vtensor<[3],si64>, !torch.list<int> -> !torch.vtensor<[3],si64>
  // CHECK: %[[OUTPUTDIM0:.*]] = torch.constant.int 2
  // CHECK: %[[NEWEMPTYSIZE:.*]] = torch.prim.ListConstruct %[[OUTPUTDIM0]] : (!torch.int) -> !torch.list<int>
  // CHECK: %[[NEWEMPTY:.*]] = torch.aten.new_empty %[[FLIPINVERSE]], %[[NEWEMPTYSIZE]], %[[INT64TYPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.vtensor<[3,3],si64>, !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si64>
  // CHECK: %[[SCATTER:.*]] = torch.aten.scatter.src %[[NEWEMPTY]], %[[INT0_0]], %[[FLIPINVERSE]], %[[FLIPPERM]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[3,3],si64>, !torch.vtensor<[3],si64> -> !torch.vtensor<[2],si64>
  // CHECK: return %[[UNIQUEOUTPUT]], %[[SCATTER]], %[[INVERSEINDEX]], %[[COUNTS]] : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64>
  %0:4 = torch.operator "onnx.Unique"(%arg0) {torch.onnx.axis = -1 : si64, torch.onnx.sorted = 1 : si64} : (!torch.vtensor<[3,3],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64>)
  return %0#0, %0#1, %0#2, %0#3 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64>
}