[MLIR][TORCH] Fix OnnxToLinalg lowering issue for Squeeze and Unsqueeze ops (#2991)

This commit also cleans up the OnnxToTorch lowering for the Squeeze and
Unsqueeze ops and adds support for handling edge cases.

Signed-off-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
Vivek Khandelwal 2024-04-22 14:22:42 +05:30 committed by GitHub
parent e5bdd71baf
commit 6abc7371c8
5 changed files with 227 additions and 346 deletions


@@ -10,11 +10,26 @@
#ifndef TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H
#define TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
class Endian {
private:
static constexpr uint32_t uint32_ = 0x01020304;
static constexpr uint8_t magic_ = (const uint8_t &)uint32_;
public:
static constexpr bool little = magic_ == 0x04;
static constexpr bool big = magic_ == 0x01;
static_assert(little || big, "Cannot determine endianness!");
private:
Endian() = delete;
};
namespace mlir::torch::onnx_c {
Value createConstantIntList(OpBinder binder,
@@ -28,6 +43,50 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
bool areAllElementsDistinct(SmallVector<int64_t> array);
namespace detail {
/// Matches the constant integers stored in an `onnx.Constant`.
struct onnx_list_of_constant_ints_op_binder {
SmallVectorImpl<int64_t> &bind_values;
/// Creates a matcher instance that binds the value to bvs if match succeeds.
onnx_list_of_constant_ints_op_binder(SmallVectorImpl<int64_t> &bvs)
: bind_values(bvs) {}
bool match(Operation *op) {
auto constOp = dyn_cast<Torch::OperatorOp>(op);
if (!constOp || !constOp.getName().equals("onnx.Constant"))
return false;
if (DenseResourceElementsAttr attr =
constOp->getAttr("torch.onnx.value")
.dyn_cast_or_null<DenseResourceElementsAttr>()) {
// Bytes are stored in little endian order. Big endian support will
// require swizzling.
if (!Endian::little) {
op->emitError("unimplemented: importing on big endian systems");
return false;
}
auto ty = cast<ShapedType>(attr.getType());
ElementsAttr denseAttr;
auto ptr = attr.getRawHandle().getBlob()->getData();
denseAttr = DenseElementsAttr::getFromRawBuffer(ty, ptr);
for (auto axis : denseAttr.getValues<llvm::APInt>()) {
bind_values.push_back(axis.getSExtValue());
}
return true;
}
return false;
}
};
} // namespace detail
/// Matches the constant integers stored in an `onnx.Constant`.
inline detail::onnx_list_of_constant_ints_op_binder
m_OnnxListOfConstantInts(SmallVectorImpl<int64_t> &bind_values) {
return detail::onnx_list_of_constant_ints_op_binder(bind_values);
}
} // namespace mlir::torch::onnx_c
#endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H
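
For reference, the `Endian` helper above decides at compile time whether the raw bytes of a `DenseResourceElementsAttr` blob can be read directly (little endian) or would first need swizzling. A minimal standalone sketch of the same probe, done at runtime with `memcpy` instead of the constexpr reference cast (illustrative code, not part of the commit):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Write a known 32-bit pattern and inspect which byte lands first in
// memory: 0x04 first means little endian, 0x01 first means big endian.
static bool isLittleEndian() {
  const uint32_t probe = 0x01020304;
  uint8_t first;
  std::memcpy(&first, &probe, 1);
  return first == 0x04;
}

int main() {
  std::printf("host is %s endian\n", isLittleEndian() ? "little" : "big");
  return 0;
}

The matcher itself is meant to be used through `matchPattern`, e.g. `matchPattern(axes, m_OnnxListOfConstantInts(axesInts))`, as the Unsqueeze lowering in this commit does.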


@@ -142,7 +142,7 @@ m_TorchConstantBool(bool *bind_value) {
}
namespace detail {
/// Matches the constant integers stored in a `torch.ListConstruct`.
/// Matches the constant integers stored in a `torch.prim.ListConstruct`.
struct torch_list_of_constant_ints_op_binder {
SmallVectorImpl<int64_t> &bind_values;


@@ -661,57 +661,86 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
patterns.onOp(
"Squeeze", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value data;
Value axes;
if (binder.tensorOperands(data, axes) ||
SmallVector<Value> inputOperands;
if (binder.tensorOperands(inputOperands, binder.op->getNumOperands()) ||
binder.tensorResultType(resultType))
return failure();
Torch::BaseTensorType axesType =
axes.getType().cast<Torch::BaseTensorType>();
SmallVector<Value> dimList;
SmallVector<int64_t> selectSizes;
selectSizes.push_back(1);
Type selectResultType = axesType.getWithSizesAndDtype(
llvm::ArrayRef(selectSizes), axesType.getOptionalDtype());
auto sizes =
dyn_cast<Torch::ValueTensorType>(axes.getType()).getSizes();
if (sizes.size() == 0) {
Value data = inputOperands[0];
auto inputType = data.getType().dyn_cast<Torch::ValueTensorType>();
if (!inputType.hasSizes() || !resultType.hasSizes())
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented: expected input and result to have shapes");
int64_t inputRank = inputType.getSizes().size();
int64_t resultRank = resultType.getSizes().size();
int64_t rankDiff = inputRank - resultRank;
if (rankDiff == 0) {
// In this case, no dimension is squeezed. Hence just replace the op
// with the input.
rewriter.replaceOp(binder.op, data);
return success();
}
if (inputOperands.size() == 1) {
// Case: the `axes` value is not present, which means squeeze all the
// dimensions with shape value 1.
rewriter.replaceOpWithNewOp<Torch::AtenSqueezeOp>(binder.op,
resultType, data);
return success();
}
Value zero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
int64_t adjustmentInt =
cast<Torch::ValueTensorType>(data.getType()).getSizes().size();
Value adjustment = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
adjustmentInt));
for (int i = 0; i < sizes[0]; i++) {
// Go through the axes list and get each dim in the list
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
SmallVector<Value> dimList;
if (inputType.areAllSizesKnown() && resultType.areAllSizesKnown()) {
// If the input shape and result shape are statically known, then the
// list of dims to be squeezed can be derived from those shapes. As a
// result, we don't have to wait for the dim values to be known at
// runtime, which is also expected by the downstream pipeline.
SmallVector<int64_t> inputShape(inputType.getSizes());
SmallVector<int64_t> resultShape(resultType.getSizes());
SmallVector<int64_t> squeezeDims;
unsigned resultShapeCounter = 0;
for (unsigned i = 0; i < inputRank; i++) {
if (resultShapeCounter < resultRank &&
inputShape[i] == resultShape[resultShapeCounter]) {
resultShapeCounter++;
} else {
squeezeDims.push_back(i);
}
}
for (auto i : squeezeDims) {
dimList.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
}
}
if (dimList.empty()) {
Value axes = inputOperands[1];
Torch::BaseTensorType axesType =
axes.getType().cast<Torch::BaseTensorType>();
SmallVector<int64_t> selectSizes{1};
Type selectResultType = axesType.getWithSizesAndDtype(
selectSizes, axesType.getOptionalDtype());
Value zero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
Value extract = rewriter.create<Torch::AtenSelectIntOp>(
binder.getLoc(), selectResultType, axes, zero, selectIndex);
Value dim = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
// deal with neg axis: if (axis < 0) axis += rank
Value isNegative =
rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), dim, zero);
isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
isNegative);
Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
binder.getLoc(), isNegative, adjustment);
Value finalDim = rewriter.create<Torch::AtenAddIntOp>(
binder.getLoc(), dim, finalOffset);
dimList.push_back(finalDim);
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
for (int i = 0; i < rankDiff; i++) {
// Go through the axes list and get each dim in the list
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
Value extract = rewriter.create<Torch::AtenSelectIntOp>(
binder.getLoc(), selectResultType, axes, zero, selectIndex);
Value dim = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
dimList.push_back(dim);
}
}
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::IntType>()),
dimList);
rewriter.replaceOpWithNewOp<Torch::PrimsSqueezeOp>(
binder.op, resultType, data, dimValueList);
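
The statically-known-shapes branch above walks the input and result shapes in lockstep: any input dim that does not line up with the next result dim must be one of the squeezed dims. A standalone sketch of that walk (illustrative names, not the commit's code), using the shapes from the `@test_squeeze_five_axes` test below:

#include <cstdint>
#include <cstdio>
#include <vector>

// Dims present in inputShape but skipped while matching resultShape in
// order are exactly the dims being squeezed.
static std::vector<int64_t>
deriveSqueezeDims(const std::vector<int64_t> &inputShape,
                  const std::vector<int64_t> &resultShape) {
  std::vector<int64_t> squeezeDims;
  size_t resultCounter = 0;
  for (size_t i = 0; i < inputShape.size(); ++i) {
    if (resultCounter < resultShape.size() &&
        inputShape[i] == resultShape[resultCounter])
      ++resultCounter;
    else
      squeezeDims.push_back(static_cast<int64_t>(i));
  }
  return squeezeDims;
}

int main() {
  // [1,3,1,4,1,5,1,1] -> [3,1,4,5] squeezes dims 0, 4, 6, and 7.
  for (int64_t d : deriveSqueezeDims({1, 3, 1, 4, 1, 5, 1, 1}, {3, 1, 4, 5}))
    std::printf("%lld ", static_cast<long long>(d));
  return 0;
}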
@@ -725,103 +754,67 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
// discussion can be found here:
// https://github.com/pytorch/pytorch/issues/9410
// So, for now, we unroll into multiple unsqueezes.
Location loc = binder.getLoc();
Torch::ValueTensorType resultType;
Value data;
Value axes;
Value data, axes;
if (binder.tensorOperands(data, axes) ||
binder.tensorResultType(resultType))
return failure();
Torch::BaseTensorType axesType =
axes.getType().cast<Torch::BaseTensorType>();
SmallVector<Value> dimList;
SmallVector<int64_t> selectSizes;
selectSizes.push_back(1);
Type selectResultType = axesType.getWithSizesAndDtype(
llvm::ArrayRef(selectSizes), axesType.getOptionalDtype());
auto sizes =
dyn_cast<Torch::ValueTensorType>(axes.getType()).getSizes();
if (sizes.size() == 0) {
auto inputType = data.getType().dyn_cast<Torch::ValueTensorType>();
if (!inputType.hasSizes() || !resultType.hasSizes())
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented: expected input and result to have shapes");
int64_t inputRank = inputType.getSizes().size();
int64_t resultRank = resultType.getSizes().size();
int64_t rankDiff = resultRank - inputRank;
if (rankDiff == 0) {
// In this case, no dimension is unsqueezed. Hence just replace the op
// with the input.
rewriter.replaceOp(binder.op, data);
return success();
}
Value zero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
int64_t adjustmentInt =
cast<Torch::ValueTensorType>(data.getType()).getSizes().size();
Value adjustment = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
adjustmentInt));
for (int i = 0; i < sizes[0]; i++) {
// Go through the axes list and get each dim in the list
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
Value extract = rewriter.create<Torch::AtenSelectIntOp>(
binder.getLoc(), selectResultType, axes, zero, selectIndex);
Value dim = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
// deal with neg axis: if (axis < 0) axis += rank
Value isNegative =
rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), dim, zero);
isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
isNegative);
Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
binder.getLoc(), isNegative, adjustment);
Value finalDim = rewriter.create<Torch::AtenAddIntOp>(
binder.getLoc(), dim, finalOffset);
dimList.push_back(finalDim);
}
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
dimList);
Value cstFalse =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value updatedAxes = rewriter.create<Torch::AtenTensorOp>(
binder.getLoc(),
axesType.getWithSizesAndDtype(sizes, axesType.getOptionalDtype()),
dimValueList, /*dtype=*/noneVal, /*device=*/noneVal, cstFalse);
// Sort the list of dims, so we don't run into this situation:
// data.sizes = [2, 3, 4]
// dims = [4, 0]
// index 4 will be invalid to add a singleton dimension because
// data.sizes.size == 3. We have to work with sorted dims to avoid this
// situation.
auto sortIndicesType = axesType.getWithSizesAndDtype(
axesType.getOptionalSizes(),
IntegerType::get(binder.op->getContext(), 64, IntegerType::Signed));
auto sortOpResult = rewriter.create<Torch::AtenSortOp>(
binder.getLoc(), axes.getType(), sortIndicesType, updatedAxes, zero,
cstFalse);
Value result;
auto baseType = Torch::ValueTensorType::getWithLeastStaticInformation(
binder.op->getContext());
// Go through the updated, sorted axes. Do unsqueeze for each dim.
for (int i = 0; i < sizes[0]; i++) {
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
Value extract = rewriter.create<Torch::AtenSelectIntOp>(
binder.getLoc(), selectResultType, sortOpResult->getResult(0),
zero, selectIndex);
Value dim = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
if (sizes[0] == 1) {
result = rewriter.create<Torch::AtenUnsqueezeOp>(
binder.getLoc(), resultType, data, dim);
} else if (i == 0) {
result = rewriter.create<Torch::AtenUnsqueezeOp>(
binder.getLoc(), baseType, data, dim);
} else if (i == sizes[0] - 1) {
result = rewriter.create<Torch::AtenUnsqueezeOp>(
binder.getLoc(), resultType, result, dim);
} else {
result = rewriter.create<Torch::AtenUnsqueezeOp>(
binder.getLoc(), baseType, result, dim);
SmallVector<int64_t> unsqueezeDims;
SmallVector<int64_t> inputShape(inputType.getSizes());
if (inputType.areAllSizesKnown() && resultType.areAllSizesKnown()) {
// If the input shape and result shape are statically known, then the
// list of dims to be unsqueezed can be derived from those shapes. As a
// result, we don't have to wait for the dim values to be known at
// runtime, which is also expected by the downstream pipeline.
SmallVector<int64_t> resultShape(resultType.getSizes());
unsigned inputShapeCounter = 0;
for (unsigned i = 0; i < resultRank; i++) {
if (inputShapeCounter < inputRank &&
inputShape[inputShapeCounter] == resultShape[i]) {
inputShapeCounter++;
} else {
unsqueezeDims.push_back(i);
}
}
} else {
SmallVector<int64_t> unsqueezeDimsInts;
if (!matchPattern(axes, m_OnnxListOfConstantInts(unsqueezeDimsInts)))
return rewriter.notifyMatchFailure(
binder.op, "only support constant int axes values");
for (auto dim : unsqueezeDimsInts)
unsqueezeDims.push_back(dim < 0 ? dim + resultRank : dim);
// If we don't sort, unsqueezing first on 4 and then on 0 would fail
// for shape {x,y,z} and axes [4,0].
llvm::sort(unsqueezeDims.begin(), unsqueezeDims.end());
}
Value result = data;
SmallVector<int64_t> unsqueezeShape = inputShape;
for (auto dim : unsqueezeDims) {
unsqueezeShape.insert(unsqueezeShape.begin() + dim, 1);
Type unsqueezeType = resultType.getWithSizesAndDtype(
unsqueezeShape, resultType.getOptionalDtype());
Value cstDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(dim));
result = rewriter.create<Torch::AtenUnsqueezeOp>(loc, unsqueezeType,
result, cstDim);
}
rewriter.replaceOp(binder.op, result);
return success();
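
The fallback path above normalizes negative axes against the result rank, sorts them, and then unsqueezes one dim at a time while tracking the intermediate static shape. A standalone sketch of that bookkeeping (illustrative values; the unsorted axes are an assumption mirroring the `@test_unsqueeze_unsorted_axes` case below):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int64_t> shape = {3, 4, 5}; // input shape
  std::vector<int64_t> axes = {5, 2, 4};  // e.g. an unsorted axes list
  const int64_t resultRank = static_cast<int64_t>(shape.size() + axes.size());
  // Normalize negative axes against the result rank, as the lowering does.
  for (int64_t &a : axes)
    if (a < 0)
      a += resultRank;
  // Sort so each insertion index is valid for the shape grown so far.
  std::sort(axes.begin(), axes.end());
  for (int64_t d : axes)
    shape.insert(shape.begin() + d, 1);
  for (int64_t s : shape)
    std::printf("%lld ", static_cast<long long>(s)); // prints: 3 4 1 5 1 1
  return 0;
}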


@@ -2643,12 +2643,8 @@ ONNX_XFAIL_SET = {
# Failure - onnx_lowering: onnx.ScatterElements
"ScatterReduceFloatMaxModuleIncludeSelf",
"ScatterReduceFloatMinModuleIncludeSelf",
"ScatterReduceFloatProdModuleIncludeSelf",
"ScatterReduceFloatSumModuleIncludeSelf",
"ScatterReduceIntMaxModuleIncludeSelf",
"ScatterReduceIntMinModuleIncludeSelf",
"ScatterReduceIntProdModuleIncludeSelf",
"ScatterReduceIntSumModuleIncludeSelf",
"ScatterValueFloatModule_basic",
# Failure - onnx_lowering: onnx.ScatterND
@@ -2680,22 +2676,12 @@ ONNX_XFAIL_SET = {
# Failure - onnx_lowering: onnx.SoftmaxCrossEntropyLoss
"CrossEntropyLossModule_basic",
"CrossEntropyLossNoReductionModule_basic",
# Failure - onnx_lowering: onnx.Squeeze
"SqueezeModule_allUnitDim",
"SqueezeModule_broadcast",
"SqueezeModule_static",
# RuntimeError: unsupported input type: Device
"PrimsIotaModule_basic",
# Failure - unknown
"BernoulliModule_basic",
"BucketizeTensorFloatModule_basic",
"BucketizeTensorModule_basic",
"BucketizeTensorOutInt32RightModule_basic",
"BucketizeTensorStaticFloatModule_basic",
"BucketizeTensorStaticModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
"CopyWithDifferentDTypesAndSizesModule_basic",
"CopyWithDifferentDTypesModule_basic",
@@ -2712,22 +2698,16 @@ ONNX_XFAIL_SET = {
"ElementwiseErfIntModule_basic",
"ElementwiseExpIntModule_basic",
"ElementwiseLogIntModule_basic",
"ElementwisePreluModule_basic",
"ElementwisePreluStaticModule_basic",
"ElementwiseSigmoidIntModule_basic",
"ElementwiseSinIntModule_basic",
"ElementwiseTanIntModule_basic",
"ElementwiseToDtypeI64ToUI8Module_basic",
"ElementwiseUnaryIntModule_basic",
"ElementwiseUnsqueezeNegDimsModule_basic",
"GroupNormModule_basic",
"MaskedFillTensorFloatValueModule_basic",
"NativeDropoutTrainModule_basic",
"NativeDropoutTrainStaticShapeModule_basic",
"ReduceMaxAlongDimUnsignedInt_basic",
"ReduceMinAlongDimUnsignedInt_basic",
"TensorsStackNegativeDimModule_basic",
"TensorsStackPromoteDTypeModule_basic",
}
if torch_version_for_comparison() >= version.parse("2.4.0.dev"):
@@ -2746,6 +2726,10 @@ if torch_version_for_comparison() < version.parse('2.3.0.dev'):
ONNX_CRASHING_SET = {
"FakeQuantizePerTensorAffineModule_basic",
"FakeQuantizePerTensorAffineDynamicShapeModule_basic",
"ElementwisePreluModule_basic",
"ViewDynamicExpandCollapseWithParallelUnknownDimModule_basic",
"ScatterReduceFloatProdModuleIncludeSelf",
"ScatterReduceFloatSumModuleIncludeSelf",
"ScatterReduceIntProdModuleIncludeSelf",
"ScatterReduceIntSumModuleIncludeSelf",
}


@@ -424,19 +424,34 @@ func.func @test_xor_bcast4v4d(%arg0: !torch.vtensor<[1,4,1,6],i1>, %arg1: !torch
// -----
// CHECK-LABEL: func.func @test_squeeze_no_axes
func.func @test_squeeze_no_axes(%arg0: !torch.vtensor<[1,3,1,4,1,5,1,1],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.squeeze %arg0 : !torch.vtensor<[1,3,1,4,1,5,1,1],f32> -> !torch.vtensor<[3,4,5],f32>
%0 = torch.operator "onnx.Squeeze"(%arg0) : (!torch.vtensor<[1,3,1,4,1,5,1,1],f32>) -> !torch.vtensor<[3,4,5],f32>
return %0 : !torch.vtensor<[3,4,5],f32>
}
// -----
// CHECK-LABEL: func.func @test_squeeze_five_axes
func.func @test_squeeze_five_axes(%arg0: !torch.vtensor<[1,3,1,4,1,5,1,1],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[3,1,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT4:.*]] = torch.constant.int 4
// CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[INT7:.*]] = torch.constant.int 7
// CHECK: %[[SQUEEZE_DIMS:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT4]], %[[INT6]], %[[INT7]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.prims.squeeze %arg0, %[[SQUEEZE_DIMS]] : !torch.vtensor<[1,3,1,4,1,5,1,1],f32>, !torch.list<int> -> !torch.vtensor<[3,1,4,5],f32>
%0 = torch.operator "onnx.Squeeze"(%arg0, %arg1) : (!torch.vtensor<[1,3,1,4,1,5,1,1],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[3,1,4,5],f32>
return %0 : !torch.vtensor<[3,1,4,5],f32>
}
// -----
// CHECK-LABEL: func.func @test_squeeze
func.func @test_squeeze(%arg0: !torch.vtensor<[1,3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT4:.*]] = torch.constant.int 4
// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
// CHECK: torch.aten.mul.int %3, %int4 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
// CHECK: torch.prims.squeeze %arg0, %6 : !torch.vtensor<[1,3,4,5],f32>, !torch.list<int> -> !torch.vtensor<[3,4,5],f32>
// CHECK: %[[SQUEEZE_DIMS:.*]] = torch.prim.ListConstruct %[[INT0]] : (!torch.int) -> !torch.list<int>
// CHECK: torch.prims.squeeze %arg0, %[[SQUEEZE_DIMS]] : !torch.vtensor<[1,3,4,5],f32>, !torch.list<int> -> !torch.vtensor<[3,4,5],f32>
%0 = torch.operator "onnx.Squeeze"(%arg0, %arg1) : (!torch.vtensor<[1,3,4,5],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,4,5],f32>
return %0 : !torch.vtensor<[3,4,5],f32>
}
@@ -445,24 +460,10 @@ func.func @test_squeeze(%arg0: !torch.vtensor<[1,3,4,5],f32>, %arg1: !torch.vten
// CHECK-LABEL: func.func @test_squeeze_two_axes
func.func @test_squeeze_two_axes(%arg0: !torch.vtensor<[3,1,4,5,1],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT5:.*]] = torch.constant.int 5
// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
// CHECK: torch.aten.mul.int %3, %int5 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.lt.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int
// CHECK: torch.aten.mul.int %9, %int5 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.prim.ListConstruct %5, %11 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.prims.squeeze %arg0, %12 : !torch.vtensor<[3,1,4,5,1],f32>, !torch.list<int> -> !torch.vtensor<[3,4,5],f32>
// CHECK: %[[INT4:.*]] = torch.constant.int 4
// CHECK: %[[SQUEEZE_DIMS:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT4]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.prims.squeeze %arg0, %[[SQUEEZE_DIMS]] : !torch.vtensor<[3,1,4,5,1],f32>, !torch.list<int> -> !torch.vtensor<[3,4,5],f32>
%0 = torch.operator "onnx.Squeeze"(%arg0, %arg1) : (!torch.vtensor<[3,1,4,5,1],f32>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,4,5],f32>
return %0 : !torch.vtensor<[3,4,5],f32>
}
@@ -472,23 +473,7 @@ func.func @test_squeeze_two_axes(%arg0: !torch.vtensor<[3,1,4,5,1],f32>, %arg1:
// CHECK-LABEL: func.func @test_unsqueeze_axis_0
func.func @test_unsqueeze_axis_0(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
// CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
// CHECK: torch.constant.bool false
// CHECK: torch.constant.none
// CHECK: torch.aten.tensor %6, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.sort %7, %int0, %false : !torch.vtensor<[1],si64>, !torch.int, !torch.bool -> !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>
// CHECK: %[[INT0_1:.*]] = torch.constant.int 0
// CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %8 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.unsqueeze %arg0, %9 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[1,3,4,5],f32>
// CHECK: torch.aten.unsqueeze %arg0, %[[INT0:.*]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[1,3,4,5],f32>
%0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,4,5],f32>
return %0 : !torch.vtensor<[1,3,4,5],f32>
}
@@ -497,24 +482,8 @@ func.func @test_unsqueeze_axis_0(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !tor
// CHECK-LABEL: func.func @test_unsqueeze_axis_1
func.func @test_unsqueeze_axis_1(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
// CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: torch.aten.tensor %6, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.sort %7, %int0, %false : !torch.vtensor<[1],si64>, !torch.int, !torch.bool -> !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>
// CHECK: %[[INT0_1:.*]] = torch.constant.int 0
// CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %8 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.unsqueeze %arg0, %9 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,1,4,5],f32>
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: torch.aten.unsqueeze %arg0, %[[INT1]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,1,4,5],f32>
%0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,4,5],f32>
return %0 : !torch.vtensor<[3,1,4,5],f32>
}
@@ -523,146 +492,22 @@ func.func @test_unsqueeze_axis_1(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !tor
// CHECK-LABEL: func.func @test_unsqueeze_axis_2
func.func @test_unsqueeze_axis_2(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,4,1,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
// CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: torch.aten.tensor %6, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.sort %7, %int0, %false : !torch.vtensor<[1],si64>, !torch.int, !torch.bool -> !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>
// CHECK: %[[INT0_1:.*]] = torch.constant.int 0
// CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %8 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.unsqueeze %arg0, %9 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,1,5],f32>
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: torch.aten.unsqueeze %arg0, %[[INT2]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,1,5],f32>
%0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,4,1,5],f32>
return %0 : !torch.vtensor<[3,4,1,5],f32>
}
// -----
// CHECK-LABEL: func.func @test_unsqueeze_negative_axes
func.func @test_unsqueeze_negative_axes(%arg0: !torch.vtensor<[1,3,1,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,1,1,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT4:.*]] = torch.constant.int 4
// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
// CHECK: torch.aten.mul.int %3, %int4 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: torch.aten.tensor %6, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.sort %7, %int0, %false : !torch.vtensor<[1],si64>, !torch.int, !torch.bool -> !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>
// CHECK: %[[INT0_1:.*]] = torch.constant.int 0
// CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %8 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.unsqueeze %arg0, %9 : !torch.vtensor<[1,3,1,5],f32>, !torch.int -> !torch.vtensor<[1,3,1,1,5],f32>
%0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[1,3,1,5],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,1,1,5],f32>
return %0 : !torch.vtensor<[1,3,1,1,5],f32>
}
// -----
// CHECK-LABEL: func.func @test_unsqueeze_three_axes
func.func @test_unsqueeze_three_axes(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[3,4,1,5,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
// CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.lt.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int
// CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.lt.int %13, %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.aten.Int.bool %14 : !torch.bool -> !torch.int
// CHECK: torch.aten.mul.int %15, %int3 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.aten.add.int %13, %16 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.prim.ListConstruct %5, %11, %17 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: torch.aten.tensor %18, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[3],si64>
// CHECK: torch.aten.sort %19, %int0, %false : !torch.vtensor<[3],si64>, !torch.int, !torch.bool -> !torch.vtensor<[3],si64>, !torch.vtensor<[3],si64>
// CHECK: %[[INT0_1:.*]] = torch.constant.int 0
// CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %20 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.unsqueeze %arg0, %21 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor
// CHECK: %[[INT1_2:.*]] = torch.constant.int 1
// CHECK: torch.aten.select.int %values, %int0, %int1_2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %23 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.unsqueeze %22, %24 : !torch.vtensor, !torch.int -> !torch.vtensor
// CHECK: %[[INT2_3:.*]] = torch.constant.int 2
// CHECK: torch.aten.select.int %values, %int0, %int2_3 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %26 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.unsqueeze %25, %27 : !torch.vtensor, !torch.int -> !torch.vtensor<[3,4,1,5,1,1],f32>
%0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[3,4,1,5,1,1],f32>
return %0 : !torch.vtensor<[3,4,1,5,1,1],f32>
}
// -----
// CHECK-LABEL: func.func @test_unsqueeze_unsorted_axes
func.func @test_unsqueeze_unsorted_axes(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[3,4,1,5,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
// CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.lt.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int
// CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.lt.int %13, %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.aten.Int.bool %14 : !torch.bool -> !torch.int
// CHECK: torch.aten.mul.int %15, %int3 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.aten.add.int %13, %16 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.prim.ListConstruct %5, %11, %17 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: torch.aten.tensor %18, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[3],si64>
// CHECK: torch.aten.sort %19, %int0, %false : !torch.vtensor<[3],si64>, !torch.int, !torch.bool -> !torch.vtensor<[3],si64>, !torch.vtensor<[3],si64>
// CHECK: %[[INT0_1:.*]] = torch.constant.int 0
// CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %20 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.unsqueeze %arg0, %21 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor
// CHECK: %[[INT1_2:.*]] = torch.constant.int 1
// CHECK: torch.aten.select.int %values, %int0, %int1_2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %23 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.unsqueeze %22, %24 : !torch.vtensor, !torch.int -> !torch.vtensor
// CHECK: %[[INT2_3:.*]] = torch.constant.int 2
// CHECK: torch.aten.select.int %values, %int0, %int2_3 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %26 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.unsqueeze %25, %27 : !torch.vtensor, !torch.int -> !torch.vtensor<[3,4,1,5,1,1],f32>
// CHECK: %[[UNSQUEEZE:.*]] = torch.aten.unsqueeze %arg0, %[[INT2]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,1,5],f32>
// CHECK: %[[INT4:.*]] = torch.constant.int 4
// CHECK: %[[UNSQUEEZE_1:.*]] = torch.aten.unsqueeze %[[UNSQUEEZE]], %[[INT4]] : !torch.vtensor<[3,4,1,5],f32>, !torch.int -> !torch.vtensor<[3,4,1,5,1],f32>
// CHECK: %[[INT5:.*]] = torch.constant.int 5
// CHECK: torch.aten.unsqueeze %[[UNSQUEEZE_1]], %[[INT5]] : !torch.vtensor<[3,4,1,5,1],f32>, !torch.int -> !torch.vtensor<[3,4,1,5,1,1],f32>
%0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[3,4,1,5,1,1],f32>
return %0 : !torch.vtensor<[3,4,1,5,1,1],f32>
}