mirror of https://github.com/llvm/torch-mlir
[onnx] Fix `onnx-to-torch` lowering for flatten shape (#2834)
The existing `flatten` lowering did not define what the intermediate shape was. This could result in failures to lower further to linalg as the intermediate shape was unknown. Added a shape refinement section.pull/2870/head
parent
b3a56c0711
commit
cb52c4b3cc
|
@ -501,7 +501,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
auto message = llvm::formatv("unimplemented support for the given "
|
auto message = llvm::formatv("unimplemented support for the given "
|
||||||
"dtype conversion (onnx 'type' = {0})",
|
"dtype conversion (onnx 'type' = {0})",
|
||||||
dtypeIntOnnx);
|
dtypeIntOnnx);
|
||||||
llvm::errs() << message << "\n";
|
|
||||||
auto y = rewriter.notifyMatchFailure(binder.op, message);
|
auto y = rewriter.notifyMatchFailure(binder.op, message);
|
||||||
|
|
||||||
return y;
|
return y;
|
||||||
|
@ -1444,16 +1443,30 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
binder.tensorResultType(resultType))
|
binder.tensorResultType(resultType))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
auto operandTy = cast<Torch::ValueTensorType>(operand.getType());
|
||||||
|
llvm::SmallVector<int64_t> shape(operandTy.getSizes());
|
||||||
|
int64_t rank = shape.size();
|
||||||
|
|
||||||
// If axis is negative, count from the right instead of left
|
// If axis is negative, count from the right instead of left
|
||||||
int64_t rank =
|
|
||||||
cast<Torch::ValueTensorType>(operand.getType()).getSizes().size();
|
|
||||||
if (axis < 0)
|
if (axis < 0)
|
||||||
axis = rank + axis;
|
axis = rank + axis;
|
||||||
|
|
||||||
Value collapsedRight;
|
// We collapse in the dimensions to the right of the axis.
|
||||||
auto baseType = Torch::ValueTensorType::getWithLeastStaticInformation(
|
for (int i = axis + 1; i < rank; ++i) {
|
||||||
binder.op->getContext());
|
bool dynamic = shape[axis] == Torch::kUnknownSize ||
|
||||||
|
shape[i] == Torch::kUnknownSize;
|
||||||
|
if (dynamic) {
|
||||||
|
shape[axis] = Torch::kUnknownSize;
|
||||||
|
} else {
|
||||||
|
shape[axis] = shape[axis] * shape[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
shape.resize(axis + 1, 1);
|
||||||
|
|
||||||
|
auto baseType = rewriter.getType<Torch::ValueTensorType>(
|
||||||
|
shape, operandTy.getDtype());
|
||||||
|
Value collapsedRight;
|
||||||
if (axis >= rank) {
|
if (axis >= rank) {
|
||||||
// If the right range is empty, add a dim of size 1 to the
|
// If the right range is empty, add a dim of size 1 to the
|
||||||
// right side of the shape:
|
// right side of the shape:
|
||||||
|
|
|
@ -1311,23 +1311,23 @@ func.func @dense_constant() -> () attributes {torch.onnx_meta.ir_version = 8 : s
|
||||||
func.func @test_flatten_4d_axis_2(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_flatten_4d_axis_2(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 2
|
// CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 2
|
||||||
// CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
|
// CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
|
||||||
// CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor
|
// CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,3,20],f32>
|
||||||
// CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
|
// CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
|
||||||
// CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 1
|
// CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 1
|
||||||
// CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[6,20],f32>
|
// CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor<[2,3,20],f32>, !torch.int, !torch.int -> !torch.vtensor<[6,20],f32>
|
||||||
%0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 2 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32>
|
%0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 2 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32>
|
||||||
return %0 : !torch.vtensor<[6,20],f32>
|
return %0 : !torch.vtensor<[6,20],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: @test_flatten_4d_axis_0
|
// // CHECK-LABEL: @test_flatten_4d_axis_0
|
||||||
func.func @test_flatten_4d_axis_0(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_flatten_4d_axis_0(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0
|
// CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0
|
||||||
// CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
|
// CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
|
||||||
// CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor
|
// CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[120],f32>
|
||||||
// CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0
|
// CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0
|
||||||
// CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,120],f32>
|
// CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor<[120],f32>, !torch.int -> !torch.vtensor<[1,120],f32>
|
||||||
%0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32>
|
%0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32>
|
||||||
return %0 : !torch.vtensor<[1,120],f32>
|
return %0 : !torch.vtensor<[1,120],f32>
|
||||||
}
|
}
|
||||||
|
@ -1337,10 +1337,10 @@ func.func @test_flatten_4d_axis_0(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torc
|
||||||
// CHECK-LABEL: @test_flatten_4d_axis_4
|
// CHECK-LABEL: @test_flatten_4d_axis_4
|
||||||
func.func @test_flatten_4d_axis_4(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[120,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_flatten_4d_axis_4(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[120,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK-DAG: %[[RIGHT_INDEX:.*]] = torch.constant.int 4
|
// CHECK-DAG: %[[RIGHT_INDEX:.*]] = torch.constant.int 4
|
||||||
// CHECK-DAG: %[[CR:.*]] = torch.aten.unsqueeze %arg0, %[[RIGHT_INDEX]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int -> !torch.vtensor
|
// CHECK-DAG: %[[CR:.*]] = torch.aten.unsqueeze %arg0, %[[RIGHT_INDEX]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int -> !torch.vtensor<[2,3,4,5,1],f32>
|
||||||
// CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
|
// CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
|
||||||
// CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 3
|
// CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 3
|
||||||
// CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[120,1],f32>
|
// CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor<[2,3,4,5,1],f32>, !torch.int, !torch.int -> !torch.vtensor<[120,1],f32>
|
||||||
%0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 4 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[120,1],f32>
|
%0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 4 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[120,1],f32>
|
||||||
return %0 : !torch.vtensor<[120,1],f32>
|
return %0 : !torch.vtensor<[120,1],f32>
|
||||||
}
|
}
|
||||||
|
@ -1351,10 +1351,10 @@ func.func @test_flatten_4d_axis_4(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torc
|
||||||
func.func @test_flatten_4d_axis_negative_2(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_flatten_4d_axis_negative_2(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 2
|
// CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 2
|
||||||
// CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
|
// CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
|
||||||
// CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor
|
// CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,3,20],f32>
|
||||||
// CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
|
// CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
|
||||||
// CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 1
|
// CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 1
|
||||||
// CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[6,20],f32>
|
// CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor<[2,3,20],f32>, !torch.int, !torch.int -> !torch.vtensor<[6,20],f32>
|
||||||
%0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -2 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32>
|
%0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -2 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32>
|
||||||
return %0 : !torch.vtensor<[6,20],f32>
|
return %0 : !torch.vtensor<[6,20],f32>
|
||||||
}
|
}
|
||||||
|
@ -1365,10 +1365,10 @@ func.func @test_flatten_4d_axis_negative_2(%arg0: !torch.vtensor<[2,3,4,5],f32>)
|
||||||
func.func @test_flatten_4d_axis_negative_1(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[24,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_flatten_4d_axis_negative_1(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[24,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 3
|
// CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 3
|
||||||
// CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
|
// CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
|
||||||
// CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor
|
// CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,3,4,5],f32>
|
||||||
// CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
|
// CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
|
||||||
// CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 2
|
// CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 2
|
||||||
// CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[24,5],f32>
|
// CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[24,5],f32>
|
||||||
%0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[24,5],f32>
|
%0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[24,5],f32>
|
||||||
return %0 : !torch.vtensor<[24,5],f32>
|
return %0 : !torch.vtensor<[24,5],f32>
|
||||||
}
|
}
|
||||||
|
@ -1379,9 +1379,9 @@ func.func @test_flatten_4d_axis_negative_1(%arg0: !torch.vtensor<[2,3,4,5],f32>)
|
||||||
func.func @test_flatten_4d_axis_negative_4(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_flatten_4d_axis_negative_4(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0
|
// CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0
|
||||||
// CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
|
// CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
|
||||||
// CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor
|
// CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[120],f32>
|
||||||
// CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0
|
// CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0
|
||||||
// CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,120],f32>
|
// CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor<[120],f32>, !torch.int -> !torch.vtensor<[1,120],f32>
|
||||||
%0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -4 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32>
|
%0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -4 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32>
|
||||||
return %0 : !torch.vtensor<[1,120],f32>
|
return %0 : !torch.vtensor<[1,120],f32>
|
||||||
}
|
}
|
||||||
|
@ -1392,10 +1392,10 @@ func.func @test_flatten_4d_axis_negative_4(%arg0: !torch.vtensor<[2,3,4,5],f32>)
|
||||||
func.func @test_flatten_2d_axis_1(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_flatten_2d_axis_1(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 1
|
// CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 1
|
||||||
// CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 1
|
// CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 1
|
||||||
// CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3],f32>, !torch.int, !torch.int -> !torch.vtensor
|
// CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32>
|
||||||
// CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
|
// CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
|
||||||
// CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 0
|
// CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 0
|
||||||
// CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32>
|
// CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor<[2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32>
|
||||||
%0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32>
|
%0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32>
|
||||||
return %0 : !torch.vtensor<[2,3],f32>
|
return %0 : !torch.vtensor<[2,3],f32>
|
||||||
}
|
}
|
||||||
|
@ -1406,9 +1406,9 @@ func.func @test_flatten_2d_axis_1(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vt
|
||||||
func.func @test_flatten_1d_axis_0(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_flatten_1d_axis_0(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0
|
// CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0
|
||||||
// CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 0
|
// CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 0
|
||||||
// CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2],f32>, !torch.int, !torch.int -> !torch.vtensor
|
// CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2],f32>, !torch.int, !torch.int -> !torch.vtensor<[2],f32>
|
||||||
// CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0
|
// CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0
|
||||||
// CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,2],f32>
|
// CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor<[1,2],f32>
|
||||||
%0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32>
|
%0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32>
|
||||||
return %0 : !torch.vtensor<[1,2],f32>
|
return %0 : !torch.vtensor<[1,2],f32>
|
||||||
}
|
}
|
||||||
|
@ -1419,9 +1419,9 @@ func.func @test_flatten_1d_axis_0(%arg0: !torch.vtensor<[2],f32>) -> !torch.vten
|
||||||
func.func @test_flatten_1d_axis_negative_1(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_flatten_1d_axis_negative_1(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0
|
// CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0
|
||||||
// CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 0
|
// CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 0
|
||||||
// CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2],f32>, !torch.int, !torch.int -> !torch.vtensor
|
// CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2],f32>, !torch.int, !torch.int -> !torch.vtensor<[2],f32>
|
||||||
// CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0
|
// CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0
|
||||||
// CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,2],f32>
|
// CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor<[1,2],f32>
|
||||||
%0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32>
|
%0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32>
|
||||||
return %0 : !torch.vtensor<[1,2],f32>
|
return %0 : !torch.vtensor<[1,2],f32>
|
||||||
}
|
}
|
||||||
|
@ -1431,10 +1431,10 @@ func.func @test_flatten_1d_axis_negative_1(%arg0: !torch.vtensor<[2],f32>) -> !t
|
||||||
// COM: CHECK-LABEL: @test_flatten_1d_axis_1
|
// COM: CHECK-LABEL: @test_flatten_1d_axis_1
|
||||||
func.func @test_flatten_1d_axis_1(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_flatten_1d_axis_1(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK-DAG: %[[RIGHT_INDEX:.*]] = torch.constant.int 1
|
// CHECK-DAG: %[[RIGHT_INDEX:.*]] = torch.constant.int 1
|
||||||
// CHECK-DAG: %[[CR:.*]] = torch.aten.unsqueeze %arg0, %[[RIGHT_INDEX]] : !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor
|
// CHECK-DAG: %[[CR:.*]] = torch.aten.unsqueeze %arg0, %[[RIGHT_INDEX]] : !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor<[2,1],f32>
|
||||||
// CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
|
// CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
|
||||||
// CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 0
|
// CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 0
|
||||||
// CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[2,1],f32>
|
// CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor<[2,1],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,1],f32>
|
||||||
%0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[2,1],f32>
|
%0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[2,1],f32>
|
||||||
return %0 : !torch.vtensor<[2,1],f32>
|
return %0 : !torch.vtensor<[2,1],f32>
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue