mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Modify Onnx.Reshape lowering for static shape cases (#2852)
This commit modifies the OnnxToTorch lowering of Onnx.Reshape op by creating the result shape list for the aten.reshape using the result shape values inferred from the op's result shape. Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>pull/2891/head
parent
a8aad2a5ab
commit
4df96616db
|
@ -1656,6 +1656,29 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
binder.tensorResultType(resultType) ||
|
||||
binder.s64IntegerAttr(allowzero, "allowzero", 0))
|
||||
return failure();
|
||||
|
||||
// If the result shape is static then we can create a result shape list
|
||||
// directly using the result shape values (integers).
|
||||
if (resultType.hasSizes()) {
|
||||
bool hasStaticShape = resultType.areAllSizesKnown();
|
||||
ArrayRef<int64_t> resultShapeInt = resultType.getSizes();
|
||||
if (hasStaticShape) {
|
||||
SmallVector<Value> resultShape;
|
||||
for (int64_t dim : resultShapeInt) {
|
||||
resultShape.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(dim)));
|
||||
}
|
||||
Value resultShapeList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ListType::get(
|
||||
Torch::IntType::get(binder.op->getContext())),
|
||||
resultShape);
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenReshapeOp>(
|
||||
binder.op, resultType, data, resultShapeList);
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
||||
Torch::BaseTensorType shapeType =
|
||||
shape.getType().cast<Torch::BaseTensorType>();
|
||||
SmallVector<Value> dimList;
|
||||
|
|
|
@ -1256,33 +1256,11 @@ func.func @test_slice_default_steps(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1:
|
|||
|
||||
// CHECK-LABEL: func.func @test_reshape_negative_dim
|
||||
func.func @test_reshape_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,6,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
|
||||
// CHECK: %[[INT2:.+]] = torch.constant.int 2
|
||||
// CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int
|
||||
// CHECK: %[[INT3:.+]] = torch.constant.int 3
|
||||
// CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[INT2_1:.+]] = torch.constant.int 2
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int2_1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.eq.int %13, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.Int.bool %14 : !torch.bool -> !torch.int
|
||||
// CHECK: %[[INT4:.+]] = torch.constant.int 4
|
||||
// CHECK: torch.aten.mul.int %15, %int4 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.aten.add.int %13, %16 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.prim.ListConstruct %5, %11, %17 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: torch.aten.reshape %arg0, %18 : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[2,6,2],f32>
|
||||
// CHECK: %[[INT6:.+]] = torch.constant.int 6
|
||||
// CHECK: %[[INT2_0:.+]] = torch.constant.int 2
|
||||
// CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT6]], %[[INT2_0]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: torch.aten.reshape %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[2,6,2],f32>
|
||||
%0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,6,2],f32>
|
||||
return %0 : !torch.vtensor<[2,6,2],f32>
|
||||
}
|
||||
|
@ -1291,40 +1269,12 @@ func.func @test_reshape_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1:
|
|||
|
||||
// CHECK-LABEL: func.func @test_reshape_negative_extended_dims
|
||||
func.func @test_reshape_negative_extended_dims(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,2,3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
|
||||
// CHECK: %[[INT2:.+]] = torch.constant.int 2
|
||||
// CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int
|
||||
// CHECK: %[[INT2:.+]] = torch.constant.int 2
|
||||
// CHECK: %[[INT3:.+]] = torch.constant.int 3
|
||||
// CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[INT2_1:.+]] = torch.constant.int 2
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int2_1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.eq.int %13, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.Int.bool %14 : !torch.bool -> !torch.int
|
||||
// CHECK: %[[INT4:.+]] = torch.constant.int 4
|
||||
// CHECK: torch.aten.mul.int %15, %int4 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.aten.add.int %13, %16 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[INT3_2:.+]] = torch.constant.int 3
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int3_2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %18 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.eq.int %19, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.Int.bool %20 : !torch.bool -> !torch.int
|
||||
// CHECK: torch.aten.mul.int %21, %int0 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.aten.add.int %19, %22 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.prim.ListConstruct %5, %11, %17, %23 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: torch.aten.reshape %arg0, %24 : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[1,2,3,4],f32>
|
||||
// CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT1]], %[[INT2]], %[[INT3]], %[[INT4]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: torch.aten.reshape %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[1,2,3,4],f32>
|
||||
%0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,2,3,4],f32>
|
||||
return %0 : !torch.vtensor<[1,2,3,4],f32>
|
||||
}
|
||||
|
@ -1333,17 +1283,9 @@ func.func @test_reshape_negative_extended_dims(%arg0: !torch.vtensor<[2,3,4],f32
|
|||
|
||||
// CHECK-LABEL: func.func @test_reshape_one_dim
|
||||
func.func @test_reshape_one_dim(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[24],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
|
||||
// CHECK: %[[INT2:.+]] = torch.constant.int 2
|
||||
// CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: torch.aten.reshape %arg0, %6 : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[24],f32>
|
||||
// CHECK: %[[INT24:.+]] = torch.constant.int 24
|
||||
// CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT24]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: torch.aten.reshape %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[24],f32>
|
||||
%0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[24],f32>
|
||||
return %0 : !torch.vtensor<[24],f32>
|
||||
}
|
||||
|
@ -1352,25 +1294,10 @@ func.func @test_reshape_one_dim(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torc
|
|||
|
||||
// CHECK-LABEL: func.func @test_reshape_reduced_dims
|
||||
func.func @test_reshape_reduced_dims(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,12],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
|
||||
// CHECK: %[[INT2:.+]] = torch.constant.int 2
|
||||
// CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int
|
||||
// CHECK: %[[INT3:.+]] = torch.constant.int 3
|
||||
// CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.prim.ListConstruct %5, %11 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: torch.aten.reshape %arg0, %12 : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[2,12],f32>
|
||||
// CHECK: %[[INT12:.+]] = torch.constant.int 12
|
||||
// CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT12]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: torch.aten.reshape %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[2,12],f32>
|
||||
%0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,12],f32>
|
||||
return %0 : !torch.vtensor<[2,12],f32>
|
||||
}
|
||||
|
@ -1379,33 +1306,11 @@ func.func @test_reshape_reduced_dims(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1:
|
|||
|
||||
// CHECK-LABEL: func.func @test_reshape_reordered_all_dims
|
||||
func.func @test_reshape_reordered_all_dims(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[4,2,3],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
|
||||
// CHECK: %[[INT2:.+]] = torch.constant.int 2
|
||||
// CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int
|
||||
// CHECK: %[[INT3:.+]] = torch.constant.int 3
|
||||
// CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[INT2_1:.+]] = torch.constant.int 2
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int2_1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.eq.int %13, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.Int.bool %14 : !torch.bool -> !torch.int
|
||||
// CHECK: %[[INT4:.+]] = torch.constant.int 4
|
||||
// CHECK: torch.aten.mul.int %15, %int4 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.aten.add.int %13, %16 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.prim.ListConstruct %5, %11, %17 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: torch.aten.reshape %arg0, %18 : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[4,2,3],f32>
|
||||
// CHECK: %[[INT2:.+]] = torch.constant.int 2
|
||||
// CHECK: %[[INT3:.+]] = torch.constant.int 3
|
||||
// CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT4]], %[[INT2]], %[[INT3]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: torch.aten.reshape %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[4,2,3],f32>
|
||||
%0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[4,2,3],f32>
|
||||
return %0 : !torch.vtensor<[4,2,3],f32>
|
||||
}
|
||||
|
@ -1414,40 +1319,12 @@ func.func @test_reshape_reordered_all_dims(%arg0: !torch.vtensor<[2,3,4],f32>, %
|
|||
|
||||
// CHECK-LABEL: func.func @test_reshape_zero_and_negative_dim
|
||||
func.func @test_reshape_zero_and_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[2,3,1,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
|
||||
// CHECK: %[[INT2:.+]] = torch.constant.int 2
|
||||
// CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int
|
||||
// CHECK: %[[INT3:.+]] = torch.constant.int 3
|
||||
// CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[INT2_1:.+]] = torch.constant.int 2
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int2_1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.eq.int %13, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.Int.bool %14 : !torch.bool -> !torch.int
|
||||
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[INT4:.+]] = torch.constant.int 4
|
||||
// CHECK: torch.aten.mul.int %15, %int4 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.aten.add.int %13, %16 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[INT3_2:.+]] = torch.constant.int 3
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int3_2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %18 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.eq.int %19, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.Int.bool %20 : !torch.bool -> !torch.int
|
||||
// CHECK: torch.aten.mul.int %21, %int0 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.aten.add.int %19, %22 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.prim.ListConstruct %5, %11, %17, %23 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: torch.aten.reshape %arg0, %24 : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[2,3,1,4],f32>
|
||||
// CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT3]], %[[INT1]], %[[INT4]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: torch.aten.reshape %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[2,3,1,4],f32>
|
||||
%0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[2,3,1,4],f32>
|
||||
return %0 : !torch.vtensor<[2,3,1,4],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue