From 4df96616dba72400071535c75188d94df7e44184 Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Thu, 8 Feb 2024 07:14:07 +0530
Subject: [PATCH] [MLIR][TORCH] Modify Onnx.Reshape lowering for static shape
 cases (#2852)

This commit modifies the OnnxToTorch lowering of Onnx.Reshape op by
creating the result shape list for the aten.reshape using the result
shape values inferred from the op's result shape.

Signed-Off By: Vivek Khandelwal
---
 .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp    |  23 +++
 .../TorchOnnxToTorch/simple_ops_q_to_z.mlir   | 163 +++---------
 2 files changed, 43 insertions(+), 143 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index 8227514b5..764cfc247 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -1656,6 +1656,29 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
             binder.tensorResultType(resultType) ||
             binder.s64IntegerAttr(allowzero, "allowzero", 0))
           return failure();
+
+        // If the result shape is static then we can create a result shape list
+        // directly using the result shape values (integers).
+        if (resultType.hasSizes()) {
+          bool hasStaticShape = resultType.areAllSizesKnown();
+          ArrayRef<int64_t> resultShapeInt = resultType.getSizes();
+          if (hasStaticShape) {
+            SmallVector<Value> resultShape;
+            for (int64_t dim : resultShapeInt) {
+              resultShape.push_back(rewriter.create<Torch::ConstantIntOp>(
+                  binder.getLoc(), rewriter.getI64IntegerAttr(dim)));
+            }
+            Value resultShapeList = rewriter.create<Torch::PrimListConstructOp>(
+                binder.getLoc(),
+                Torch::ListType::get(
+                    Torch::IntType::get(binder.op->getContext())),
+                resultShape);
+            rewriter.replaceOpWithNewOp<Torch::AtenReshapeOp>(
+                binder.op, resultType, data, resultShapeList);
+            return success();
+          }
+        }
+
         Torch::BaseTensorType shapeType =
             shape.getType().cast<Torch::BaseTensorType>();
         SmallVector<Value> dimList;
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
index ae36661bd..8a5d5b1ef 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -1256,33 +1256,11 @@ func.func @test_slice_default_steps(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1:

 // CHECK-LABEL: func.func @test_reshape_negative_dim
 func.func @test_reshape_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,6,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[INT0:.+]] = torch.constant.int 0
-  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
-  // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
   // CHECK: %[[INT2:.+]] = torch.constant.int 2
-  // CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
-  // CHECK: %[[INT1:.+]] = torch.constant.int 1
-  // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int
-  // CHECK: %[[INT3:.+]] = torch.constant.int 3
-  // CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int
-  // CHECK: %[[INT2_1:.+]] = torch.constant.int 2
-  // CHECK: torch.aten.select.int %arg1, %int0, %int2_1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: torch.aten.eq.int %13, %int0 : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.aten.Int.bool %14 : !torch.bool -> !torch.int
-  // CHECK: %[[INT4:.+]] = torch.constant.int 4
-  // CHECK: torch.aten.mul.int %15, %int4 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.aten.add.int %13, %16 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.prim.ListConstruct %5, %11, %17 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: torch.aten.reshape %arg0, %18 : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[2,6,2],f32>
+  // CHECK: %[[INT6:.+]] = torch.constant.int 6
+  // CHECK: %[[INT2_0:.+]] = torch.constant.int 2
+  // CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT6]], %[[INT2_0]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: torch.aten.reshape %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[2,6,2],f32>
   %0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,6,2],f32>
   return %0 : !torch.vtensor<[2,6,2],f32>
 }
@@ -1291,40 +1269,12 @@ func.func @test_reshape_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1:

 // CHECK-LABEL: func.func @test_reshape_negative_extended_dims
 func.func @test_reshape_negative_extended_dims(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,2,3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[INT0:.+]] = torch.constant.int 0
-  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
-  // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
-  // CHECK: %[[INT2:.+]] = torch.constant.int 2
-  // CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
   // CHECK: %[[INT1:.+]] = torch.constant.int 1
-  // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int
+  // CHECK: %[[INT2:.+]] = torch.constant.int 2
   // CHECK: %[[INT3:.+]] = torch.constant.int 3
-  // CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int
-  // CHECK: %[[INT2_1:.+]] = torch.constant.int 2
-  // CHECK: torch.aten.select.int %arg1, %int0, %int2_1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: torch.aten.eq.int %13, %int0 : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.aten.Int.bool %14 : !torch.bool -> !torch.int
   // CHECK: %[[INT4:.+]] = torch.constant.int 4
-  // CHECK: torch.aten.mul.int %15, %int4 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.aten.add.int %13, %16 : !torch.int, !torch.int -> !torch.int
-  // CHECK: %[[INT3_2:.+]] = torch.constant.int 3
-  // CHECK: torch.aten.select.int %arg1, %int0, %int3_2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: torch.aten.item %18 : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: torch.aten.eq.int %19, %int0 : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.aten.Int.bool %20 : !torch.bool -> !torch.int
-  // CHECK: torch.aten.mul.int %21, %int0 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.aten.add.int %19, %22 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.prim.ListConstruct %5, %11, %17, %23 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: torch.aten.reshape %arg0, %24 : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[1,2,3,4],f32>
+  // CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT1]], %[[INT2]], %[[INT3]], %[[INT4]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: torch.aten.reshape %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[1,2,3,4],f32>
   %0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,2,3,4],f32>
   return %0 : !torch.vtensor<[1,2,3,4],f32>
 }
@@ -1333,17 +1283,9 @@ func.func @test_reshape_negative_extended_dims(%arg0: !torch.vtensor<[2,3,4],f32

 // CHECK-LABEL: func.func @test_reshape_one_dim
 func.func @test_reshape_one_dim(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[24],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[INT0:.+]] = torch.constant.int 0
-  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
-  // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
-  // CHECK: %[[INT2:.+]] = torch.constant.int 2
-  // CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
-  // CHECK: torch.aten.reshape %arg0, %6 : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[24],f32>
+  // CHECK: %[[INT24:.+]] = torch.constant.int 24
+  // CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT24]] : (!torch.int) -> !torch.list<int>
+  // CHECK: torch.aten.reshape %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[24],f32>
   %0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[24],f32>
   return %0 : !torch.vtensor<[24],f32>
 }
@@ -1352,25 +1294,10 @@ func.func @test_reshape_one_dim(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torc

 // CHECK-LABEL: func.func @test_reshape_reduced_dims
 func.func @test_reshape_reduced_dims(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,12],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[INT0:.+]] = torch.constant.int 0
-  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
-  // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
   // CHECK: %[[INT2:.+]] = torch.constant.int 2
-  // CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
-  // CHECK: %[[INT1:.+]] = torch.constant.int 1
-  // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int
-  // CHECK: %[[INT3:.+]] = torch.constant.int 3
-  // CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.prim.ListConstruct %5, %11 : (!torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: torch.aten.reshape %arg0, %12 : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[2,12],f32>
+  // CHECK: %[[INT12:.+]] = torch.constant.int 12
+  // CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT12]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: torch.aten.reshape %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[2,12],f32>
   %0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,12],f32>
   return %0 : !torch.vtensor<[2,12],f32>
 }
@@ -1379,33 +1306,11 @@ func.func @test_reshape_reduced_dims(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1:

 // CHECK-LABEL: func.func @test_reshape_reordered_all_dims
 func.func @test_reshape_reordered_all_dims(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[4,2,3],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[INT0:.+]] = torch.constant.int 0
-  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
-  // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
-  // CHECK: %[[INT2:.+]] = torch.constant.int 2
-  // CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
-  // CHECK: %[[INT1:.+]] = torch.constant.int 1
-  // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int
-  // CHECK: %[[INT3:.+]] = torch.constant.int 3
-  // CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int
-  // CHECK: %[[INT2_1:.+]] = torch.constant.int 2
-  // CHECK: torch.aten.select.int %arg1, %int0, %int2_1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: torch.aten.eq.int %13, %int0 : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.aten.Int.bool %14 : !torch.bool -> !torch.int
   // CHECK: %[[INT4:.+]] = torch.constant.int 4
-  // CHECK: torch.aten.mul.int %15, %int4 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.aten.add.int %13, %16 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.prim.ListConstruct %5, %11, %17 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: torch.aten.reshape %arg0, %18 : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[4,2,3],f32>
+  // CHECK: %[[INT2:.+]] = torch.constant.int 2
+  // CHECK: %[[INT3:.+]] = torch.constant.int 3
+  // CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT4]], %[[INT2]], %[[INT3]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: torch.aten.reshape %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[4,2,3],f32>
   %0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[4,2,3],f32>
   return %0 : !torch.vtensor<[4,2,3],f32>
 }
@@ -1414,40 +1319,12 @@ func.func @test_reshape_reordered_all_dims(%arg0: !torch.vtensor<[2,3,4],f32>, %

 // CHECK-LABEL: func.func @test_reshape_zero_and_negative_dim
 func.func @test_reshape_zero_and_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[2,3,1,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[INT0:.+]] = torch.constant.int 0
-  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
-  // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
   // CHECK: %[[INT2:.+]] = torch.constant.int 2
-  // CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
-  // CHECK: %[[INT1:.+]] = torch.constant.int 1
-  // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int
   // CHECK: %[[INT3:.+]] = torch.constant.int 3
-  // CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int
-  // CHECK: %[[INT2_1:.+]] = torch.constant.int 2
-  // CHECK: torch.aten.select.int %arg1, %int0, %int2_1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: torch.aten.eq.int %13, %int0 : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.aten.Int.bool %14 : !torch.bool -> !torch.int
+  // CHECK: %[[INT1:.+]] = torch.constant.int 1
   // CHECK: %[[INT4:.+]] = torch.constant.int 4
-  // CHECK: torch.aten.mul.int %15, %int4 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.aten.add.int %13, %16 : !torch.int, !torch.int -> !torch.int
-  // CHECK: %[[INT3_2:.+]] = torch.constant.int 3
-  // CHECK: torch.aten.select.int %arg1, %int0, %int3_2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: torch.aten.item %18 : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: torch.aten.eq.int %19, %int0 : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.aten.Int.bool %20 : !torch.bool -> !torch.int
-  // CHECK: torch.aten.mul.int %21, %int0 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.aten.add.int %19, %22 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.prim.ListConstruct %5, %11, %17, %23 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: torch.aten.reshape %arg0, %24 : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[2,3,1,4],f32>
+  // CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT3]], %[[INT1]], %[[INT4]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: torch.aten.reshape %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[2,3,1,4],f32>
   %0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[2,3,1,4],f32>
   return %0 : !torch.vtensor<[2,3,1,4],f32>
 }
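
Note (illustrative, not part of the patch): with the static-shape path added above, an onnx.Reshape whose result type is fully known no longer lowers through the per-dimension select/item/eq/mul/add sequence; it lowers directly to integer constants, a torch.prim.ListConstruct, and torch.aten.reshape. A minimal sketch, reusing the operands of the test_reshape_reduced_dims case above (the SSA names %int2, %int12, and %shape are illustrative only; the FileCheck patterns constrain the op sequence, not the exact names):

  %0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,12],f32>

becomes

  %int2 = torch.constant.int 2
  %int12 = torch.constant.int 12
  %shape = torch.prim.ListConstruct %int2, %int12 : (!torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.aten.reshape %arg0, %shape : !torch.vtensor<[2,3,4],f32>, !torch.list<int> -> !torch.vtensor<[2,12],f32>

Results with dynamic dimensions still take the existing shape-tensor path (select/item on the shape operand) that follows the new early return in DefaultDomainQtoZ.cpp.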