diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h index b62f9dbaf..8e9de1ff5 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h @@ -10,11 +10,26 @@ #ifndef TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H #define TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H +#include "mlir/IR/DialectResourceBlobManager.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" +class Endian { +private: + static constexpr uint32_t uint32_ = 0x01020304; + static constexpr uint8_t magic_ = (const uint8_t &)uint32_; + +public: + static constexpr bool little = magic_ == 0x04; + static constexpr bool big = magic_ == 0x01; + static_assert(little || big, "Cannot determine endianness!"); + +private: + Endian() = delete; +}; + namespace mlir::torch::onnx_c { Value createConstantIntList(OpBinder binder, @@ -28,6 +43,50 @@ LogicalResult OnnxLstmExpander(OpBinder binder, bool areAllElementsDistinct(SmallVector array); +namespace detail { +/// Matches the constant integers stored in a `onnx.Constant`. +struct onnx_list_of_constant_ints_op_binder { + SmallVectorImpl &bind_values; + + /// Creates a matcher instance that binds the value to bvs if match succeeds. + onnx_list_of_constant_ints_op_binder(SmallVectorImpl &bvs) + : bind_values(bvs) {} + + bool match(Operation *op) { + auto constOp = dyn_cast(op); + if (!constOp || !constOp.getName().equals("onnx.Constant")) + return false; + + if (DenseResourceElementsAttr attr = + constOp->getAttr("torch.onnx.value") + .dyn_cast_or_null()) { + // Bytes are stored in little endian order. Big endian support will + // require swizzling. + if (!Endian::little) { + op->emitError("unimplemented: importing on big endian systems"); + return false; + } + + auto ty = cast(attr.getType()); + ElementsAttr denseAttr; + auto ptr = attr.getRawHandle().getBlob()->getData(); + denseAttr = DenseElementsAttr::getFromRawBuffer(ty, ptr); + for (auto axis : denseAttr.getValues()) { + bind_values.push_back(axis.getSExtValue()); + } + return true; + } + return false; + } +}; +} // namespace detail + +/// Matches the constant integers stored in a `onnx.Constant`. +inline detail::onnx_list_of_constant_ints_op_binder +m_OnnxListOfConstantInts(SmallVectorImpl &bind_values) { + return detail::onnx_list_of_constant_ints_op_binder(bind_values); +} + } // namespace mlir::torch::onnx_c #endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h index e6a9e1622..4508518bf 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h @@ -142,7 +142,7 @@ m_TorchConstantBool(bool *bind_value) { } namespace detail { -/// Matches the constant integers stored in a `torch.ListConstruct`. +/// Matches the constant integers stored in a `torch.prim.ListConstruct`. struct torch_list_of_constant_ints_op_binder { SmallVectorImpl &bind_values; diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 65bfb6257..7630fcfa1 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -661,57 +661,86 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( patterns.onOp( "Squeeze", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; - Value data; - Value axes; - if (binder.tensorOperands(data, axes) || + SmallVector inputOperands; + if (binder.tensorOperands(inputOperands, binder.op->getNumOperands()) || binder.tensorResultType(resultType)) return failure(); - Torch::BaseTensorType axesType = - axes.getType().cast(); - SmallVector dimList; - SmallVector selectSizes; - selectSizes.push_back(1); - Type selectResultType = axesType.getWithSizesAndDtype( - llvm::ArrayRef(selectSizes), axesType.getOptionalDtype()); - auto sizes = - dyn_cast(axes.getType()).getSizes(); - if (sizes.size() == 0) { + + Value data = inputOperands[0]; + auto inputType = data.getType().dyn_cast(); + if (!inputType.hasSizes() || !resultType.hasSizes()) + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented: expected input and result to have shapes"); + + int64_t inputRank = inputType.getSizes().size(); + int64_t resultRank = resultType.getSizes().size(); + int64_t rankDiff = inputRank - resultRank; + if (rankDiff == 0) { + // In this case, no dimension is squeezed. Hence just replace the op + // with input. + rewriter.replaceOp(binder.op, data); + return success(); + } + + if (inputOperands.size() == 1) { + // Case: `axes` value is not present which means squeeze all the + // dimensions with shape value 1. rewriter.replaceOpWithNewOp(binder.op, resultType, data); return success(); } - Value zero = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); - int64_t adjustmentInt = - cast(data.getType()).getSizes().size(); - Value adjustment = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), - adjustmentInt)); - for (int i = 0; i < sizes[0]; i++) { - // Go through the axes list and get each dim in the list - Value selectIndex = rewriter.create( + + SmallVector dimList; + if (inputType.areAllSizesKnown() && resultType.areAllSizesKnown()) { + // If the input shape and result shape is statically known then the + // list of dims to be squeezed can be derived from those shapes. As a + // result, we don't have to wait for the dim values to be known at + // runtime which is also expected by the downstream pipeline. + SmallVector inputShape(inputType.getSizes()); + SmallVector resultShape(resultType.getSizes()); + SmallVector squeezeDims; + unsigned resultShapeCounter = 0; + for (unsigned i = 0; i < inputRank; i++) { + if (resultShapeCounter < resultRank && + inputShape[i] == resultShape[resultShapeCounter]) { + resultShapeCounter++; + } else { + squeezeDims.push_back(i); + } + } + for (auto i : squeezeDims) { + dimList.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + } + + if (dimList.empty()) { + Value axes = inputOperands[1]; + Torch::BaseTensorType axesType = + axes.getType().cast(); + SmallVector selectSizes{1}; + Type selectResultType = axesType.getWithSizesAndDtype( + selectSizes, axesType.getOptionalDtype()); + Value zero = rewriter.create( binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - Value extract = rewriter.create( - binder.getLoc(), selectResultType, axes, zero, selectIndex); - Value dim = rewriter.create( - binder.getLoc(), rewriter.getType(), extract); - // deal with neg axis: if (axis < 0) axis += rank - Value isNegative = - rewriter.create(binder.getLoc(), dim, zero); - isNegative = rewriter.create(binder.getLoc(), - isNegative); - Value finalOffset = rewriter.create( - binder.getLoc(), isNegative, adjustment); - Value finalDim = rewriter.create( - binder.getLoc(), dim, finalOffset); - dimList.push_back(finalDim); + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + for (int i = 0; i < rankDiff; i++) { + // Go through the axes list and get each dim in the list + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value extract = rewriter.create( + binder.getLoc(), selectResultType, axes, zero, selectIndex); + Value dim = rewriter.create( + binder.getLoc(), rewriter.getType(), extract); + dimList.push_back(dim); + } } Value dimValueList = rewriter.create( binder.getLoc(), - Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + rewriter.getType( + rewriter.getType()), dimList); rewriter.replaceOpWithNewOp( binder.op, resultType, data, dimValueList); @@ -725,103 +754,67 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // discussion can be found here: // https://github.com/pytorch/pytorch/issues/9410 // So, for now, we unroll into multiple unsqueezes. + Location loc = binder.getLoc(); Torch::ValueTensorType resultType; - Value data; - Value axes; + Value data, axes; if (binder.tensorOperands(data, axes) || binder.tensorResultType(resultType)) return failure(); - Torch::BaseTensorType axesType = - axes.getType().cast(); - SmallVector dimList; - SmallVector selectSizes; - selectSizes.push_back(1); - Type selectResultType = axesType.getWithSizesAndDtype( - llvm::ArrayRef(selectSizes), axesType.getOptionalDtype()); - auto sizes = - dyn_cast(axes.getType()).getSizes(); - if (sizes.size() == 0) { + auto inputType = data.getType().dyn_cast(); + if (!inputType.hasSizes() || !resultType.hasSizes()) + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented: expected input and result to have shapes"); + + int64_t inputRank = inputType.getSizes().size(); + int64_t resultRank = resultType.getSizes().size(); + int64_t rankDiff = resultRank - inputRank; + if (rankDiff == 0) { + // In this case, no dimension is unsqueezed. Hence just replace the op + // with input. rewriter.replaceOp(binder.op, data); return success(); } - Value zero = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); - int64_t adjustmentInt = - cast(data.getType()).getSizes().size(); - Value adjustment = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), - adjustmentInt)); - for (int i = 0; i < sizes[0]; i++) { - // Go through the axes list and get each dim in the list - Value selectIndex = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - Value extract = rewriter.create( - binder.getLoc(), selectResultType, axes, zero, selectIndex); - Value dim = rewriter.create( - binder.getLoc(), rewriter.getType(), extract); - // deal with neg axis: if (axis < 0) axis += rank - Value isNegative = - rewriter.create(binder.getLoc(), dim, zero); - isNegative = rewriter.create(binder.getLoc(), - isNegative); - Value finalOffset = rewriter.create( - binder.getLoc(), isNegative, adjustment); - Value finalDim = rewriter.create( - binder.getLoc(), dim, finalOffset); - dimList.push_back(finalDim); - } - Value dimValueList = rewriter.create( - binder.getLoc(), - Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), - dimList); - Value cstFalse = - rewriter.create(binder.getLoc(), false); - Value noneVal = rewriter.create(binder.getLoc()); - Value updatedAxes = rewriter.create( - binder.getLoc(), - axesType.getWithSizesAndDtype(sizes, axesType.getOptionalDtype()), - dimValueList, /*dtype=*/noneVal, /*device=*/noneVal, cstFalse); - // Sort the list of dims, so we don't run into this situation: - // data.sizes = [2, 3, 4] - // dims = [4, 0] - // index 4 will be invalid to add a singleton dimension because - // data.sizes.size == 3 We have to work with sorted dims to avoid this - // situation. - auto sortIndicesType = axesType.getWithSizesAndDtype( - axesType.getOptionalSizes(), - IntegerType::get(binder.op->getContext(), 64, IntegerType::Signed)); - auto sortOpResult = rewriter.create( - binder.getLoc(), axes.getType(), sortIndicesType, updatedAxes, zero, - cstFalse); - Value result; - auto baseType = Torch::ValueTensorType::getWithLeastStaticInformation( - binder.op->getContext()); - // Go through the updated, sorted axes. Do unsqueeze for each dim. - for (int i = 0; i < sizes[0]; i++) { - Value selectIndex = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - Value extract = rewriter.create( - binder.getLoc(), selectResultType, sortOpResult->getResult(0), - zero, selectIndex); - Value dim = rewriter.create( - binder.getLoc(), rewriter.getType(), extract); - if (sizes[0] == 1) { - result = rewriter.create( - binder.getLoc(), resultType, data, dim); - } else if (i == 0) { - result = rewriter.create( - binder.getLoc(), baseType, data, dim); - } else if (i == sizes[0] - 1) { - result = rewriter.create( - binder.getLoc(), resultType, result, dim); - } else { - result = rewriter.create( - binder.getLoc(), baseType, result, dim); + + SmallVector unsqueezeDims; + SmallVector inputShape(inputType.getSizes()); + if (inputType.areAllSizesKnown() && resultType.areAllSizesKnown()) { + // If the input shape and result shape is statically known then the + // list of dims to be squeezed can be derived from those shapes. As a + // result, we don't have to wait for the dim values to be known at + // runtime which is also expected by the downstream pipeline. + SmallVector resultShape(resultType.getSizes()); + unsigned inputShapeCounter = 0; + for (unsigned i = 0; i < resultRank; i++) { + if (inputShapeCounter < inputRank && + inputShape[inputShapeCounter] == resultShape[i]) { + inputShapeCounter++; + } else { + unsqueezeDims.push_back(i); + } } + } else { + SmallVector unsqueezeDimsInts; + if (!matchPattern(axes, m_OnnxListOfConstantInts(unsqueezeDimsInts))) + return rewriter.notifyMatchFailure( + binder.op, "only support constant int axes values"); + + for (auto dim : unsqueezeDimsInts) + unsqueezeDims.push_back(dim < 0 ? dim + resultRank : dim); + // If we don't sort, unsqueezing first on 4 and then on 0 would fail + // for shape = {x,y,z}, and axes [4,0] + llvm::sort(unsqueezeDims.begin(), unsqueezeDims.end()); + } + Value result = data; + SmallVector unsqueezeShape = inputShape; + for (auto dim : unsqueezeDims) { + unsqueezeShape.insert(unsqueezeShape.begin() + dim, 1); + Type unsqueezeType = resultType.getWithSizesAndDtype( + unsqueezeShape, resultType.getOptionalDtype()); + Value cstDim = rewriter.create( + loc, rewriter.getI64IntegerAttr(dim)); + result = rewriter.create(loc, unsqueezeType, + result, cstDim); } rewriter.replaceOp(binder.op, result); return success(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 0607a9720..64a9d3bb6 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2643,12 +2643,8 @@ ONNX_XFAIL_SET = { # Failure - onnx_lowering: onnx.ScatterElements "ScatterReduceFloatMaxModuleIncludeSelf", "ScatterReduceFloatMinModuleIncludeSelf", - "ScatterReduceFloatProdModuleIncludeSelf", - "ScatterReduceFloatSumModuleIncludeSelf", "ScatterReduceIntMaxModuleIncludeSelf", "ScatterReduceIntMinModuleIncludeSelf", - "ScatterReduceIntProdModuleIncludeSelf", - "ScatterReduceIntSumModuleIncludeSelf", "ScatterValueFloatModule_basic", # Failure - onnx_lowering: onnx.ScatterND @@ -2680,22 +2676,12 @@ ONNX_XFAIL_SET = { # Failure - onnx_lowering: onnx.SoftmaxCrossEntropyLoss "CrossEntropyLossModule_basic", "CrossEntropyLossNoReductionModule_basic", - - # Failure - onnx_lowering: onnx.Squeeze - "SqueezeModule_allUnitDim", - "SqueezeModule_broadcast", - "SqueezeModule_static", # RuntimeError: unsupported input type: Device "PrimsIotaModule_basic", - + # Failure - unknown "BernoulliModule_basic", - "BucketizeTensorFloatModule_basic", - "BucketizeTensorModule_basic", - "BucketizeTensorOutInt32RightModule_basic", - "BucketizeTensorStaticFloatModule_basic", - "BucketizeTensorStaticModule_basic", "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", "CopyWithDifferentDTypesAndSizesModule_basic", "CopyWithDifferentDTypesModule_basic", @@ -2712,22 +2698,16 @@ ONNX_XFAIL_SET = { "ElementwiseErfIntModule_basic", "ElementwiseExpIntModule_basic", "ElementwiseLogIntModule_basic", - "ElementwisePreluModule_basic", - "ElementwisePreluStaticModule_basic", "ElementwiseSigmoidIntModule_basic", "ElementwiseSinIntModule_basic", "ElementwiseTanIntModule_basic", "ElementwiseToDtypeI64ToUI8Module_basic", "ElementwiseUnaryIntModule_basic", - "ElementwiseUnsqueezeNegDimsModule_basic", - "GroupNormModule_basic", "MaskedFillTensorFloatValueModule_basic", "NativeDropoutTrainModule_basic", "NativeDropoutTrainStaticShapeModule_basic", "ReduceMaxAlongDimUnsignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic", - "TensorsStackNegativeDimModule_basic", - "TensorsStackPromoteDTypeModule_basic", } if torch_version_for_comparison() >= version.parse("2.4.0.dev"): @@ -2746,6 +2726,10 @@ if torch_version_for_comparison() < version.parse('2.3.0.dev'): ONNX_CRASHING_SET = { "FakeQuantizePerTensorAffineModule_basic", "FakeQuantizePerTensorAffineDynamicShapeModule_basic", - + "ElementwisePreluModule_basic", "ViewDynamicExpandCollapseWithParallelUnknownDimModule_basic", + "ScatterReduceFloatProdModuleIncludeSelf", + "ScatterReduceFloatSumModuleIncludeSelf", + "ScatterReduceIntProdModuleIncludeSelf", + "ScatterReduceIntSumModuleIncludeSelf", } diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 47497d5ea..de3e796f4 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -424,19 +424,34 @@ func.func @test_xor_bcast4v4d(%arg0: !torch.vtensor<[1,4,1,6],i1>, %arg1: !torch // ----- +// CHECK-LABEL: func.func @test_squeeze_no_axes +func.func @test_squeeze_no_axes(%arg0: !torch.vtensor<[1,3,1,4,1,5,1,1],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.squeeze %arg0 : !torch.vtensor<[1,3,1,4,1,5,1,1],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Squeeze"(%arg0) : (!torch.vtensor<[1,3,1,4,1,5,1,1],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_squeeze_five_axes +func.func @test_squeeze_five_axes(%arg0: !torch.vtensor<[1,3,1,4,1,5,1,1],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[3,1,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT4:.*]] = torch.constant.int 4 + // CHECK: %[[INT6:.*]] = torch.constant.int 6 + // CHECK: %[[INT7:.*]] = torch.constant.int 7 + // CHECK: %[[SQUEEZE_DIMS:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT4]], %[[INT6]], %[[INT7]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: torch.prims.squeeze %arg0, %[[SQUEEZE_DIMS]] : !torch.vtensor<[1,3,1,4,1,5,1,1],f32>, !torch.list -> !torch.vtensor<[3,1,4,5],f32> + %0 = torch.operator "onnx.Squeeze"(%arg0, %arg1) : (!torch.vtensor<[1,3,1,4,1,5,1,1],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[3,1,4,5],f32> + return %0 : !torch.vtensor<[3,1,4,5],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_squeeze func.func @test_squeeze(%arg0: !torch.vtensor<[1,3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.*]] = torch.constant.int 0 - // CHECK: %[[INT4:.*]] = torch.constant.int 4 - // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %3, %int4 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list - // CHECK: torch.prims.squeeze %arg0, %6 : !torch.vtensor<[1,3,4,5],f32>, !torch.list -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[SQUEEZE_DIMS:.*]] = torch.prim.ListConstruct %[[INT0]] : (!torch.int) -> !torch.list + // CHECK: torch.prims.squeeze %arg0, %[[SQUEEZE_DIMS]] : !torch.vtensor<[1,3,4,5],f32>, !torch.list -> !torch.vtensor<[3,4,5],f32> %0 = torch.operator "onnx.Squeeze"(%arg0, %arg1) : (!torch.vtensor<[1,3,4,5],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,4,5],f32> return %0 : !torch.vtensor<[3,4,5],f32> } @@ -445,24 +460,10 @@ func.func @test_squeeze(%arg0: !torch.vtensor<[1,3,4,5],f32>, %arg1: !torch.vten // CHECK-LABEL: func.func @test_squeeze_two_axes func.func @test_squeeze_two_axes(%arg0: !torch.vtensor<[3,1,4,5,1],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.*]] = torch.constant.int 0 - // CHECK: %[[INT5:.*]] = torch.constant.int 5 - // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %3, %int5 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int // CHECK: %[[INT1:.*]] = torch.constant.int 1 - // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %7, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %9, %int5 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5, %11 : (!torch.int, !torch.int) -> !torch.list - // CHECK: torch.prims.squeeze %arg0, %12 : !torch.vtensor<[3,1,4,5,1],f32>, !torch.list -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[INT4:.*]] = torch.constant.int 4 + // CHECK: %[[SQUEEZE_DIMS:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT4]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: torch.prims.squeeze %arg0, %[[SQUEEZE_DIMS]] : !torch.vtensor<[3,1,4,5,1],f32>, !torch.list -> !torch.vtensor<[3,4,5],f32> %0 = torch.operator "onnx.Squeeze"(%arg0, %arg1) : (!torch.vtensor<[3,1,4,5,1],f32>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,4,5],f32> return %0 : !torch.vtensor<[3,4,5],f32> } @@ -472,23 +473,7 @@ func.func @test_squeeze_two_axes(%arg0: !torch.vtensor<[3,1,4,5,1],f32>, %arg1: // CHECK-LABEL: func.func @test_unsqueeze_axis_0 func.func @test_unsqueeze_axis_0(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.*]] = torch.constant.int 0 - // CHECK: %[[INT3:.*]] = torch.constant.int 3 - // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list - // CHECK: torch.constant.bool false - // CHECK: torch.constant.none - // CHECK: torch.aten.tensor %6, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.sort %7, %int0, %false : !torch.vtensor<[1],si64>, !torch.int, !torch.bool -> !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64> - // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %8 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.unsqueeze %arg0, %9 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[1,3,4,5],f32> + // CHECK: torch.aten.unsqueeze %arg0, %[[INT0:.*]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[1,3,4,5],f32> %0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,4,5],f32> return %0 : !torch.vtensor<[1,3,4,5],f32> } @@ -497,24 +482,8 @@ func.func @test_unsqueeze_axis_0(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !tor // CHECK-LABEL: func.func @test_unsqueeze_axis_1 func.func @test_unsqueeze_axis_1(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.*]] = torch.constant.int 0 - // CHECK: %[[INT3:.*]] = torch.constant.int 3 - // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list - // CHECK: %[[FALSE:.*]] = torch.constant.bool false - // CHECK: %[[NONE:.*]] = torch.constant.none - // CHECK: torch.aten.tensor %6, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.sort %7, %int0, %false : !torch.vtensor<[1],si64>, !torch.int, !torch.bool -> !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64> - // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %8 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.unsqueeze %arg0, %9 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,1,4,5],f32> + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: torch.aten.unsqueeze %arg0, %[[INT1]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,1,4,5],f32> %0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,4,5],f32> return %0 : !torch.vtensor<[3,1,4,5],f32> } @@ -523,146 +492,22 @@ func.func @test_unsqueeze_axis_1(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !tor // CHECK-LABEL: func.func @test_unsqueeze_axis_2 func.func @test_unsqueeze_axis_2(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,4,1,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.*]] = torch.constant.int 0 - // CHECK: %[[INT3:.*]] = torch.constant.int 3 - // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list - // CHECK: %[[FALSE:.*]] = torch.constant.bool false - // CHECK: %[[NONE:.*]] = torch.constant.none - // CHECK: torch.aten.tensor %6, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.sort %7, %int0, %false : !torch.vtensor<[1],si64>, !torch.int, !torch.bool -> !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64> - // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %8 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.unsqueeze %arg0, %9 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,1,5],f32> + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: torch.aten.unsqueeze %arg0, %[[INT2]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,1,5],f32> %0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,4,1,5],f32> return %0 : !torch.vtensor<[3,4,1,5],f32> } // ----- -// CHECK-LABEL: func.func @test_unsqueeze_negative_axes -func.func @test_unsqueeze_negative_axes(%arg0: !torch.vtensor<[1,3,1,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,1,1,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.*]] = torch.constant.int 0 - // CHECK: %[[INT4:.*]] = torch.constant.int 4 - // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %3, %int4 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list - // CHECK: %[[FALSE:.*]] = torch.constant.bool false - // CHECK: %[[NONE:.*]] = torch.constant.none - // CHECK: torch.aten.tensor %6, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.sort %7, %int0, %false : !torch.vtensor<[1],si64>, !torch.int, !torch.bool -> !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64> - // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %8 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.unsqueeze %arg0, %9 : !torch.vtensor<[1,3,1,5],f32>, !torch.int -> !torch.vtensor<[1,3,1,1,5],f32> - %0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[1,3,1,5],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,1,1,5],f32> - return %0 : !torch.vtensor<[1,3,1,1,5],f32> -} - -// ----- - // CHECK-LABEL: func.func @test_unsqueeze_three_axes func.func @test_unsqueeze_three_axes(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[3,4,1,5,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.*]] = torch.constant.int 0 - // CHECK: %[[INT3:.*]] = torch.constant.int 3 - // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int - // CHECK: %[[INT1:.*]] = torch.constant.int 1 - // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %7, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int // CHECK: %[[INT2:.*]] = torch.constant.int 2 - // CHECK: torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %13, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %14 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %15, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %13, %16 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5, %11, %17 : (!torch.int, !torch.int, !torch.int) -> !torch.list - // CHECK: %[[FALSE:.*]] = torch.constant.bool false - // CHECK: %[[NONE:.*]] = torch.constant.none - // CHECK: torch.aten.tensor %18, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[3],si64> - // CHECK: torch.aten.sort %19, %int0, %false : !torch.vtensor<[3],si64>, !torch.int, !torch.bool -> !torch.vtensor<[3],si64>, !torch.vtensor<[3],si64> - // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %20 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.unsqueeze %arg0, %21 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor - // CHECK: %[[INT1_2:.*]] = torch.constant.int 1 - // CHECK: torch.aten.select.int %values, %int0, %int1_2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %23 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.unsqueeze %22, %24 : !torch.vtensor, !torch.int -> !torch.vtensor - // CHECK: %[[INT2_3:.*]] = torch.constant.int 2 - // CHECK: torch.aten.select.int %values, %int0, %int2_3 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %26 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.unsqueeze %25, %27 : !torch.vtensor, !torch.int -> !torch.vtensor<[3,4,1,5,1,1],f32> - %0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[3,4,1,5,1,1],f32> - return %0 : !torch.vtensor<[3,4,1,5,1,1],f32> -} - -// ----- - -// CHECK-LABEL: func.func @test_unsqueeze_unsorted_axes -func.func @test_unsqueeze_unsorted_axes(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[3,4,1,5,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.*]] = torch.constant.int 0 - // CHECK: %[[INT3:.*]] = torch.constant.int 3 - // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int - // CHECK: %[[INT1:.*]] = torch.constant.int 1 - // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %7, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int - // CHECK: %[[INT2:.*]] = torch.constant.int 2 - // CHECK: torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %13, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %14 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %15, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %13, %16 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5, %11, %17 : (!torch.int, !torch.int, !torch.int) -> !torch.list - // CHECK: %[[FALSE:.*]] = torch.constant.bool false - // CHECK: %[[NONE:.*]] = torch.constant.none - // CHECK: torch.aten.tensor %18, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[3],si64> - // CHECK: torch.aten.sort %19, %int0, %false : !torch.vtensor<[3],si64>, !torch.int, !torch.bool -> !torch.vtensor<[3],si64>, !torch.vtensor<[3],si64> - // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %20 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.unsqueeze %arg0, %21 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor - // CHECK: %[[INT1_2:.*]] = torch.constant.int 1 - // CHECK: torch.aten.select.int %values, %int0, %int1_2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %23 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.unsqueeze %22, %24 : !torch.vtensor, !torch.int -> !torch.vtensor - // CHECK: %[[INT2_3:.*]] = torch.constant.int 2 - // CHECK: torch.aten.select.int %values, %int0, %int2_3 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %26 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.unsqueeze %25, %27 : !torch.vtensor, !torch.int -> !torch.vtensor<[3,4,1,5,1,1],f32> + // CHECK: %[[UNSQUEEZE:.*]] = torch.aten.unsqueeze %arg0, %[[INT2]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,1,5],f32> + // CHECK: %[[INT4:.*]] = torch.constant.int 4 + // CHECK: %[[UNSQUEEZE_1:.*]] = torch.aten.unsqueeze %[[UNSQUEEZE]], %[[INT4]] : !torch.vtensor<[3,4,1,5],f32>, !torch.int -> !torch.vtensor<[3,4,1,5,1],f32> + // CHECK: %[[INT5:.*]] = torch.constant.int 5 + // CHECK: torch.aten.unsqueeze %[[UNSQUEEZE_1]], %[[INT5]] : !torch.vtensor<[3,4,1,5,1],f32>, !torch.int -> !torch.vtensor<[3,4,1,5,1,1],f32> %0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[3,4,1,5,1,1],f32> return %0 : !torch.vtensor<[3,4,1,5,1,1],f32> }