mirror of https://github.com/llvm/torch-mlir
[torch] Add OnnxToTorch lowering for Onnx.Unique op (#3523)
Adds OnnxToTorch Lowering for the `Onnx.Unique` op.pull/3568/head
parent
a211ccbcff
commit
30c4d2f2b8
|
@ -284,6 +284,16 @@ struct OpBinder {
|
|||
return failure();
|
||||
}
|
||||
|
||||
ParseResult optionalS64IntegerAttr(int64_t &value, StringRef nameSuffix) {
|
||||
SmallString<64> name("torch.onnx.");
|
||||
name.append(nameSuffix);
|
||||
auto attr = op->getAttr(name);
|
||||
if (!attr) {
|
||||
return failure();
|
||||
}
|
||||
return s64IntegerAttr(value, nameSuffix);
|
||||
}
|
||||
|
||||
ParseResult f32FloatAttr(float &value, StringRef nameSuffix,
|
||||
float defaultValue = 0.0f) {
|
||||
SmallString<64> name("torch.onnx.");
|
||||
|
|
|
@ -12784,6 +12784,35 @@ def Torch_AtenUniqueConsecutiveOp : Torch_Op<"aten.unique_consecutive", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenUniqueDimOp : Torch_Op<"aten.unique_dim", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::unique_dim : (Tensor, int, bool, bool, bool) -> (Tensor, Tensor, Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_IntType:$dim,
|
||||
Torch_BoolType:$sorted,
|
||||
Torch_BoolType:$return_inverse,
|
||||
Torch_BoolType:$return_counts
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result0,
|
||||
AnyTorchOptionalTensorType:$result1,
|
||||
AnyTorchOptionalTensorType:$result2
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenUniqueDimOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 5, 3);
|
||||
}
|
||||
void AtenUniqueDimOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 5, 3);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenLinspaceOp : Torch_Op<"aten.linspace", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -4095,4 +4095,122 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
binder.op, "Handling of this kind of inputs is not there");
|
||||
}
|
||||
});
|
||||
patterns.onOp(
|
||||
"Unique", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Value input;
|
||||
int64_t axis, sorted;
|
||||
SmallVector<Type> resultTypes;
|
||||
|
||||
if (binder.tensorOperand(input) ||
|
||||
binder.s64IntegerAttr(sorted, "sorted", 1) ||
|
||||
binder.tensorResultTypes(resultTypes))
|
||||
return failure();
|
||||
|
||||
Value zero = rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), 0);
|
||||
|
||||
auto inputTy = cast<Torch::ValueTensorType>(input.getType());
|
||||
if (!inputTy.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "Expected input type to have sizes");
|
||||
}
|
||||
auto inputShape = inputTy.getSizes();
|
||||
int64_t inputDim = static_cast<int64_t>(inputShape.size());
|
||||
|
||||
Value axisVal;
|
||||
SmallVector<int64_t> outputTensorSizes(inputDim);
|
||||
bool axisWasNone;
|
||||
if (!binder.optionalS64IntegerAttr(axis, "axis")) {
|
||||
if (axis < -1 * inputDim || axis > inputDim - 1)
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"invalid value for axis");
|
||||
axisVal = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(axis));
|
||||
axisWasNone = false;
|
||||
} else {
|
||||
axisVal = zero;
|
||||
axisWasNone = true;
|
||||
}
|
||||
|
||||
Value sortedVal = rewriter.create<Torch::ConstantBoolOp>(
|
||||
binder.getLoc(), rewriter.getBoolAttr(sorted));
|
||||
Value trueVal =
|
||||
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
|
||||
|
||||
// The shape of inverse_indices is the same as input shape, but
|
||||
// resulTypes[2] must be used to avoid live value after conversion.
|
||||
Torch::ValueTensorType outputTy;
|
||||
outputTy = cast<Torch::ValueTensorType>(resultTypes[0]);
|
||||
Torch::ValueTensorType countsTy =
|
||||
cast<Torch::ValueTensorType>(resultTypes[3]);
|
||||
Torch::ValueTensorType inverseTy =
|
||||
cast<Torch::ValueTensorType>(resultTypes[2]);
|
||||
|
||||
if (axisWasNone) {
|
||||
int64_t inputNumel = 1;
|
||||
for (auto elem : inputShape) {
|
||||
if (elem == Torch::kUnknownSize) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op,
|
||||
"Expected all sizes in input shape to be statically known");
|
||||
}
|
||||
inputNumel *= elem;
|
||||
}
|
||||
auto flattenResultTy = rewriter.getType<Torch::ValueTensorType>(
|
||||
ArrayRef({inputNumel}), inputTy.getDtype());
|
||||
Value negativeOne =
|
||||
rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), -1);
|
||||
input = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
|
||||
binder.getLoc(), flattenResultTy, input, zero, negativeOne);
|
||||
}
|
||||
|
||||
Torch::AtenUniqueDimOp intermResults =
|
||||
rewriter.create<Torch::AtenUniqueDimOp>(
|
||||
binder.getLoc(), outputTy, inverseTy, countsTy, input, axisVal,
|
||||
sortedVal, trueVal, trueVal);
|
||||
|
||||
SmallVector<Value> uniqueResults = intermResults.getResults();
|
||||
|
||||
// Calculate the indices where each of the unique elements first
|
||||
// appeared in the original input tensor. Also, the counts tensor and
|
||||
// the indices tensor have the same Dtype, int64, so reuse that here.
|
||||
auto arangeResultType = rewriter.getType<Torch::ValueTensorType>(
|
||||
ArrayRef<int64_t>({inputShape[0]}), countsTy.getOptionalDtype());
|
||||
|
||||
Value inputDimZero = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(inputShape[0]));
|
||||
Value int64Type = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(4));
|
||||
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
||||
|
||||
Value perm = rewriter.create<Torch::AtenArangeOp>(
|
||||
binder.getLoc(), arangeResultType, inputDimZero,
|
||||
/*dtype=*/int64Type,
|
||||
/*layout=*/noneVal, /*device=*/noneVal, /*pin_memory=*/noneVal);
|
||||
|
||||
// Inverse has the same shape as input, but the dtype is not the same.
|
||||
Value flipDims = createConstantIntList(binder, rewriter, {0});
|
||||
Value inverse = rewriter.create<Torch::AtenFlipOp>(
|
||||
binder.getLoc(),
|
||||
inputTy.getWithSizesAndDtype(inputShape, countsTy.getDtype()),
|
||||
uniqueResults[1], flipDims);
|
||||
perm = rewriter.create<Torch::AtenFlipOp>(
|
||||
binder.getLoc(), cast<Torch::ValueTensorType>(perm.getType()), perm,
|
||||
flipDims);
|
||||
|
||||
auto newInverseTy = rewriter.getType<Torch::ValueTensorType>(
|
||||
ArrayRef<int64_t>({outputTy.getSizes()[0]}), countsTy.getDtype());
|
||||
Value newInverseSize =
|
||||
createConstantIntList(binder, rewriter, {outputTy.getSizes()[0]});
|
||||
Value newInverse = rewriter.create<Torch::AtenNewEmptyOp>(
|
||||
binder.getLoc(), newInverseTy, inverse, newInverseSize,
|
||||
/*dtype=*/int64Type, /*layout=*/noneVal, /*device=*/noneVal,
|
||||
/*pin_memory=*/noneVal);
|
||||
|
||||
Value firstOccurIndices = rewriter.create<Torch::AtenScatterSrcOp>(
|
||||
binder.getLoc(), resultTypes[1], newInverse, zero, inverse, perm);
|
||||
|
||||
rewriter.replaceOp(binder.op, {uniqueResults[0], firstOccurIndices,
|
||||
uniqueResults[1], uniqueResults[2]});
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
|
|
@ -936,6 +936,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit(
|
||||
"aten::unique_consecutive : (Tensor, bool, bool, int?) -> (Tensor, Tensor, Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::unique_dim : (Tensor, int, bool, bool, bool) -> (Tensor, Tensor, Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::linspace : (Scalar, Scalar, int, int?, int?, Device?, bool?) -> (Tensor)"
|
||||
)
|
||||
|
|
|
@ -3161,3 +3161,136 @@ func.func @test_split_to_sequence_2(%arg0: !torch.vtensor<[2,6],f32>, %arg1: !to
|
|||
%0 = torch.operator "onnx.SplitToSequence"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4,6],f32>, !torch.vtensor<[2],si64>) -> !torch.list<vtensor<[2,6],f32>>
|
||||
return %0 : !torch.list<vtensor<[2,6],f32>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_unique_not_sorted_without_axis
|
||||
func.func @test_unique_not_sorted_without_axis(%arg0: !torch.vtensor<[6],f32>) -> (!torch.vtensor<[4],f32>, !torch.vtensor<[4],si64>, !torch.vtensor<[6],si64>, !torch.vtensor<[4],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[FALSEVAL:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[TRUEVAL:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[NEGATIVEONE:.*]] = torch.constant.int -1
|
||||
// CHECK: %[[FLATTEN:.*]] = torch.aten.flatten.using_ints %arg0, %[[INT0_0]], %[[NEGATIVEONE]] : !torch.vtensor<[6],f32>, !torch.int, !torch.int -> !torch.vtensor<[6],f32>
|
||||
// CHECK: %[[UNIQUEOUTPUT:.*]], %[[INVERSEINDEX:.*]], %[[COUNTS:.*]] = torch.aten.unique_dim %[[FLATTEN]], %[[INT0_0]], %[[FALSEVAL]], %[[TRUEVAL]], %[[TRUEVAL]] : !torch.vtensor<[6],f32>, !torch.int, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[4],f32>, !torch.vtensor<[6],si64>, !torch.vtensor<[4],si64>
|
||||
// CHECK: %[[INPUTDIM0:.*]] = torch.constant.int 6
|
||||
// CHECK: %[[INT64TYPE:.*]] = torch.constant.int 4
|
||||
// CHECK: %[[NONEVAL:.*]] = torch.constant.none
|
||||
// CHECK: %[[ARANGE:.*]] = torch.aten.arange %[[INPUTDIM0]], %[[INT64TYPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[6],si64>
|
||||
// CHECK: %[[INT0_1:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[FLIPDIMS:.*]] = torch.prim.ListConstruct %[[INT0_1]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[FLIPINVERSE:.*]] = torch.aten.flip %[[INVERSEINDEX]], %[[FLIPDIMS]] : !torch.vtensor<[6],si64>, !torch.list<int> -> !torch.vtensor<[6],si64>
|
||||
// CHECK: %[[FLIPPERM:.*]] = torch.aten.flip %[[ARANGE]], %[[FLIPDIMS]] : !torch.vtensor<[6],si64>, !torch.list<int> -> !torch.vtensor<[6],si64>
|
||||
// CHECK: %[[OUTPUTDIMZERO:.*]] = torch.constant.int 4
|
||||
// CHECK: %[[NEWEMPTYSIZE:.*]] = torch.prim.ListConstruct %[[OUTPUTDIMZERO]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[NEWEMPTY:.*]] = torch.aten.new_empty %[[FLIPINVERSE]], %[[NEWEMPTYSIZE]], %[[INT64TYPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.vtensor<[6],si64>, !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[4],si64>
|
||||
// CHECK: %[[SCATTER:.*]] = torch.aten.scatter.src %[[NEWEMPTY]], %[[INT0_0]], %[[FLIPINVERSE]], %[[FLIPPERM]] : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[6],si64>, !torch.vtensor<[6],si64> -> !torch.vtensor<[4],si64>
|
||||
// CHECK: return %[[UNIQUEOUTPUT]], %[[SCATTER]], %[[INVERSEINDEX]], %[[COUNTS]] : !torch.vtensor<[4],f32>, !torch.vtensor<[4],si64>, !torch.vtensor<[6],si64>, !torch.vtensor<[4],si64>
|
||||
%0:4 = torch.operator "onnx.Unique"(%arg0) {torch.onnx.sorted = 0 : si64} : (!torch.vtensor<[6],f32>) -> (!torch.vtensor<[4],f32>, !torch.vtensor<[4],si64>, !torch.vtensor<[6],si64>, !torch.vtensor<[4],si64>)
|
||||
return %0#0, %0#1, %0#2, %0#3 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],si64>, !torch.vtensor<[6],si64>, !torch.vtensor<[4],si64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_unique_sorted_without_axis
|
||||
func.func @test_unique_sorted_without_axis(%arg0: !torch.vtensor<[6],f32>) -> (!torch.vtensor<[4],f32>, !torch.vtensor<[4],si64>, !torch.vtensor<[6],si64>, !torch.vtensor<[4],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[TRUEVAL_0:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[TRUEVAL_1:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[NEGATIVEONE:.*]] = torch.constant.int -1
|
||||
// CHECK: %[[FLATTEN:.*]] = torch.aten.flatten.using_ints %arg0, %[[INT0_0]], %[[NEGATIVEONE]] : !torch.vtensor<[6],f32>, !torch.int, !torch.int -> !torch.vtensor<[6],f32>
|
||||
// CHECK: %[[UNIQUEOUTPUT:.*]], %[[INVERSEINDEX:.*]], %[[COUNTS:.*]] = torch.aten.unique_dim %[[FLATTEN]], %[[INT0_0]], %[[TRUEVAL_0]], %[[TRUEVAL_1]], %[[TRUEVAL_1]] : !torch.vtensor<[6],f32>, !torch.int, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[4],f32>, !torch.vtensor<[6],si64>, !torch.vtensor<[4],si64>
|
||||
// CHECK: %[[INPUTDIM0:.*]] = torch.constant.int 6
|
||||
// CHECK: %[[INT64TYPE:.*]] = torch.constant.int 4
|
||||
// CHECK: %[[NONEVAL:.*]] = torch.constant.none
|
||||
// CHECK: %[[ARANGE:.*]] = torch.aten.arange %[[INPUTDIM0]], %[[INT64TYPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[6],si64>
|
||||
// CHECK: %[[INT0_1:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[FLIPDIMS:.*]] = torch.prim.ListConstruct %[[INT0_1]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[FLIPINVERSE:.*]] = torch.aten.flip %[[INVERSEINDEX]], %[[FLIPDIMS]] : !torch.vtensor<[6],si64>, !torch.list<int> -> !torch.vtensor<[6],si64>
|
||||
// CHECK: %[[FLIPPERM:.*]] = torch.aten.flip %[[ARANGE]], %[[FLIPDIMS]] : !torch.vtensor<[6],si64>, !torch.list<int> -> !torch.vtensor<[6],si64>
|
||||
// CHECK: %[[OUTPUTDIMZERO:.*]] = torch.constant.int 4
|
||||
// CHECK: %[[NEWEMPTYSIZE:.*]] = torch.prim.ListConstruct %[[OUTPUTDIMZERO]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[NEWEMPTY:.*]] = torch.aten.new_empty %[[FLIPINVERSE]], %[[NEWEMPTYSIZE]], %[[INT64TYPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.vtensor<[6],si64>, !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[4],si64>
|
||||
// CHECK: %[[SCATTER:.*]] = torch.aten.scatter.src %[[NEWEMPTY]], %[[INT0_0]], %[[FLIPINVERSE]], %[[FLIPPERM]] : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[6],si64>, !torch.vtensor<[6],si64> -> !torch.vtensor<[4],si64>
|
||||
// CHECK: return %[[UNIQUEOUTPUT]], %[[SCATTER]], %[[INVERSEINDEX]], %[[COUNTS]] : !torch.vtensor<[4],f32>, !torch.vtensor<[4],si64>, !torch.vtensor<[6],si64>, !torch.vtensor<[4],si64>
|
||||
%0:4 = torch.operator "onnx.Unique"(%arg0) : (!torch.vtensor<[6],f32>) -> (!torch.vtensor<[4],f32>, !torch.vtensor<[4],si64>, !torch.vtensor<[6],si64>, !torch.vtensor<[4],si64>)
|
||||
return %0#0, %0#1, %0#2, %0#3 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],si64>, !torch.vtensor<[6],si64>, !torch.vtensor<[4],si64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_unique_sorted_with_axis_3d
|
||||
func.func @test_unique_sorted_with_axis_3d(%arg0: !torch.vtensor<[2,4,2],f32>) -> (!torch.vtensor<[2,3,2],f32>, !torch.vtensor<[3],si64>, !torch.vtensor<[4],si64>, !torch.vtensor<[3],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[TRUEVAL_0:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[TRUEVAL_1:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[UNIQUEOUTPUT:.*]], %[[INVERSEINDEX:.*]], %[[COUNTS:.*]] = torch.aten.unique_dim %arg0, %[[INT1]], %[[TRUEVAL_0]], %[[TRUEVAL_1]], %[[TRUEVAL_1]] : !torch.vtensor<[2,4,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[2,3,2],f32>, !torch.vtensor<[4],si64>, !torch.vtensor<[3],si64>
|
||||
// CHECK: %[[INPUTDIM0:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[INT64TYPE:.*]] = torch.constant.int 4
|
||||
// CHECK: %[[NONEVAL:.*]] = torch.constant.none
|
||||
// CHECK: %[[ARANGE:.*]] = torch.aten.arange %[[INPUTDIM0]], %[[INT64TYPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si64>
|
||||
// CHECK: %[[INTO_1:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[FLIPDIMS:.*]] = torch.prim.ListConstruct %[[INTO_1]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[FLIPINVERSE:.*]] = torch.aten.flip %[[INVERSEINDEX]], %[[FLIPDIMS]] : !torch.vtensor<[4],si64>, !torch.list<int> -> !torch.vtensor<[2,4,2],si64>
|
||||
// CHECK: %[[FLIPPERM:.*]] = torch.aten.flip %[[ARANGE]], %[[FLIPDIMS]] : !torch.vtensor<[2],si64>, !torch.list<int> -> !torch.vtensor<[2],si64>
|
||||
// CHECK: %[[OUTPUTDIM0:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[NEWEMPTYSIZE:.*]] = torch.prim.ListConstruct %[[OUTPUTDIM0]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[NEWEMPTY:.*]] = torch.aten.new_empty %[[FLIPINVERSE]], %[[NEWEMPTYSIZE]], %[[INT64TYPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.vtensor<[2,4,2],si64>, !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si64>
|
||||
// CHECK: %[[SCATTER:.*]] = torch.aten.scatter.src %[[NEWEMPTY]], %[[INT0_0]], %[[FLIPINVERSE]], %[[FLIPPERM]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[2,4,2],si64>, !torch.vtensor<[2],si64> -> !torch.vtensor<[3],si64>
|
||||
// CHECK: return %[[UNIQUEOUTPUT]], %[[SCATTER]], %[[INVERSEINDEX]], %[[COUNTS]] : !torch.vtensor<[2,3,2],f32>, !torch.vtensor<[3],si64>, !torch.vtensor<[4],si64>, !torch.vtensor<[3],si64>
|
||||
%0:4 = torch.operator "onnx.Unique"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.sorted = 1 : si64} : (!torch.vtensor<[2,4,2],f32>) -> (!torch.vtensor<[2,3,2],f32>, !torch.vtensor<[3],si64>, !torch.vtensor<[4],si64>, !torch.vtensor<[3],si64>)
|
||||
return %0#0, %0#1, %0#2, %0#3 : !torch.vtensor<[2,3,2],f32>, !torch.vtensor<[3],si64>, !torch.vtensor<[4],si64>, !torch.vtensor<[3],si64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
// CHECK-LABEL: func.func @test_unique_sorted_with_axis
|
||||
func.func @test_unique_sorted_with_axis(%arg0: !torch.vtensor<[3,3],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[INT0_1:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[TRUEVAL:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[TRUEVAL_1:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[UNIQUEOUTPUT:.*]], %[[INVERSEINDEX:.*]], %[[COUNTS:.*]] = torch.aten.unique_dim %arg0, %[[INT0_1]], %[[TRUEVAL]], %[[TRUEVAL_1]], %[[TRUEVAL_1]] : !torch.vtensor<[3,3],f32>, !torch.int, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[2,3],f32>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64>
|
||||
// CHECK: %[[INPUTDIM0:.*]] = torch.constant.int 3
|
||||
// CHECK: %[[INT64TYPE:.*]] = torch.constant.int 4
|
||||
// CHECK: %[[NONEVAL:.*]] = torch.constant.none
|
||||
// CHECK: %[[ARANGE:.*]] = torch.aten.arange %[[INPUTDIM0]], %[[INT64TYPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64>
|
||||
// CHECK: %[[INT0_2:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[FLIPDIMS:.*]] = torch.prim.ListConstruct %[[INT0_2]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[FLIPINVERSE:.*]] = torch.aten.flip %[[INVERSEINDEX]], %[[FLIPDIMS]] : !torch.vtensor<[3],si64>, !torch.list<int> -> !torch.vtensor<[3,3],si64>
|
||||
// CHECK: %[[FLIPPERM:.*]] = torch.aten.flip %[[ARANGE]], %[[FLIPDIMS]] : !torch.vtensor<[3],si64>, !torch.list<int> -> !torch.vtensor<[3],si64>
|
||||
// CHECK: %[[OUTPUTDIM0:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[NEWEMPTYSIZE:.*]] = torch.prim.ListConstruct %[[OUTPUTDIM0]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[NEWEMPTY:.*]] = torch.aten.new_empty %[[FLIPINVERSE]], %[[NEWEMPTYSIZE]], %[[INT64TYPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.vtensor<[3,3],si64>, !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si64>
|
||||
// CHECK: %[[SCATTER:.*]] = torch.aten.scatter.src %[[NEWEMPTY]], %[[INT0_0]], %[[FLIPINVERSE]], %[[FLIPPERM]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[3,3],si64>, !torch.vtensor<[3],si64> -> !torch.vtensor<[2],si64>
|
||||
// CHECK: return %[[UNIQUEOUTPUT]], %[[SCATTER]], %[[INVERSEINDEX]], %[[COUNTS]] : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64>
|
||||
%0:4 = torch.operator "onnx.Unique"(%arg0) {torch.onnx.axis = 0 : si64, torch.onnx.sorted = 1 : si64} : (!torch.vtensor<[3,3],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64>)
|
||||
return %0#0, %0#1, %0#2, %0#3 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_unique_sorted_with_negative_axis
|
||||
func.func @test_unique_sorted_with_negative_axis(%arg0: !torch.vtensor<[3,3],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[NEGATIVEONE:.*]] = torch.constant.int -1
|
||||
// CHECK: %[[TRUEVAL:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[TRUEVAL_1:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[UNIQUEOUTPUT:.*]], %[[INVERSEINDEX:.*]], %[[COUNTS:.*]] = torch.aten.unique_dim %arg0, %[[NEGATIVEONE]], %[[TRUEVAL]], %[[TRUEVAL_1]], %[[TRUEVAL_1]] : !torch.vtensor<[3,3],f32>, !torch.int, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[2,3],f32>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64>
|
||||
// CHECK: %[[INPUTDIM0:.*]] = torch.constant.int 3
|
||||
// CHECK: %[[INT64TYPE:.*]] = torch.constant.int 4
|
||||
// CHECK: %[[NONEVAL:.*]] = torch.constant.none
|
||||
// CHECK: %[[ARANGE:.*]] = torch.aten.arange %[[INPUTDIM0]], %[[INT64TYPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64>
|
||||
// CHECK: %[[INT0_1:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[FLIPDIMS:.*]] = torch.prim.ListConstruct %[[INT0_1]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[FLIPINVERSE:.*]] = torch.aten.flip %[[INVERSEINDEX]], %[[FLIPDIMS]] : !torch.vtensor<[3],si64>, !torch.list<int> -> !torch.vtensor<[3,3],si64>
|
||||
// CHECK: %[[FLIPPERM:.*]] = torch.aten.flip %[[ARANGE]], %[[FLIPDIMS]] : !torch.vtensor<[3],si64>, !torch.list<int> -> !torch.vtensor<[3],si64>
|
||||
// CHECK: %[[OUTPUTDIM0:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[NEWEMPTYSIZE:.*]] = torch.prim.ListConstruct %[[OUTPUTDIM0]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[NEWEMPTY:.*]] = torch.aten.new_empty %[[FLIPINVERSE]], %[[NEWEMPTYSIZE]], %[[INT64TYPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.vtensor<[3,3],si64>, !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si64>
|
||||
// CHECK: %[[SCATTER:.*]] = torch.aten.scatter.src %[[NEWEMPTY]], %[[INT0_0]], %[[FLIPINVERSE]], %[[FLIPPERM]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[3,3],si64>, !torch.vtensor<[3],si64> -> !torch.vtensor<[2],si64>
|
||||
// CHECK: return %[[UNIQUEOUTPUT]], %[[SCATTER]], %[[INVERSEINDEX]], %[[COUNTS]] : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64>
|
||||
%0:4 = torch.operator "onnx.Unique"(%arg0) {torch.onnx.axis = -1 : si64, torch.onnx.sorted = 1 : si64} : (!torch.vtensor<[3,3],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64>)
|
||||
return %0#0, %0#1, %0#2, %0#3 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue