mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Fix OnnxToLinalg lowering issue for Squeeze and Unsqueeze op (#2991)
This commit also cleans up the OnnxToTorch lowering for the Squeeze and Unsqueeze op and adds the support for handling edge cases. Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>pull/3203/head
parent
e5bdd71baf
commit
6abc7371c8
|
@ -10,11 +10,26 @@
|
|||
#ifndef TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H
|
||||
#define TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H
|
||||
|
||||
#include "mlir/IR/DialectResourceBlobManager.h"
|
||||
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
||||
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
|
||||
class Endian {
|
||||
private:
|
||||
static constexpr uint32_t uint32_ = 0x01020304;
|
||||
static constexpr uint8_t magic_ = (const uint8_t &)uint32_;
|
||||
|
||||
public:
|
||||
static constexpr bool little = magic_ == 0x04;
|
||||
static constexpr bool big = magic_ == 0x01;
|
||||
static_assert(little || big, "Cannot determine endianness!");
|
||||
|
||||
private:
|
||||
Endian() = delete;
|
||||
};
|
||||
|
||||
namespace mlir::torch::onnx_c {
|
||||
|
||||
Value createConstantIntList(OpBinder binder,
|
||||
|
@ -28,6 +43,50 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
|
|||
|
||||
bool areAllElementsDistinct(SmallVector<int64_t> array);
|
||||
|
||||
namespace detail {
|
||||
/// Matches the constant integers stored in a `onnx.Constant`.
|
||||
struct onnx_list_of_constant_ints_op_binder {
|
||||
SmallVectorImpl<int64_t> &bind_values;
|
||||
|
||||
/// Creates a matcher instance that binds the value to bvs if match succeeds.
|
||||
onnx_list_of_constant_ints_op_binder(SmallVectorImpl<int64_t> &bvs)
|
||||
: bind_values(bvs) {}
|
||||
|
||||
bool match(Operation *op) {
|
||||
auto constOp = dyn_cast<Torch::OperatorOp>(op);
|
||||
if (!constOp || !constOp.getName().equals("onnx.Constant"))
|
||||
return false;
|
||||
|
||||
if (DenseResourceElementsAttr attr =
|
||||
constOp->getAttr("torch.onnx.value")
|
||||
.dyn_cast_or_null<DenseResourceElementsAttr>()) {
|
||||
// Bytes are stored in little endian order. Big endian support will
|
||||
// require swizzling.
|
||||
if (!Endian::little) {
|
||||
op->emitError("unimplemented: importing on big endian systems");
|
||||
return false;
|
||||
}
|
||||
|
||||
auto ty = cast<ShapedType>(attr.getType());
|
||||
ElementsAttr denseAttr;
|
||||
auto ptr = attr.getRawHandle().getBlob()->getData();
|
||||
denseAttr = DenseElementsAttr::getFromRawBuffer(ty, ptr);
|
||||
for (auto axis : denseAttr.getValues<llvm::APInt>()) {
|
||||
bind_values.push_back(axis.getSExtValue());
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
/// Matches the constant integers stored in a `onnx.Constant`.
|
||||
inline detail::onnx_list_of_constant_ints_op_binder
|
||||
m_OnnxListOfConstantInts(SmallVectorImpl<int64_t> &bind_values) {
|
||||
return detail::onnx_list_of_constant_ints_op_binder(bind_values);
|
||||
}
|
||||
|
||||
} // namespace mlir::torch::onnx_c
|
||||
|
||||
#endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H
|
||||
|
|
|
@ -142,7 +142,7 @@ m_TorchConstantBool(bool *bind_value) {
|
|||
}
|
||||
|
||||
namespace detail {
|
||||
/// Matches the constant integers stored in a `torch.ListConstruct`.
|
||||
/// Matches the constant integers stored in a `torch.prim.ListConstruct`.
|
||||
struct torch_list_of_constant_ints_op_binder {
|
||||
SmallVectorImpl<int64_t> &bind_values;
|
||||
|
||||
|
|
|
@ -661,57 +661,86 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
patterns.onOp(
|
||||
"Squeeze", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
Value data;
|
||||
Value axes;
|
||||
if (binder.tensorOperands(data, axes) ||
|
||||
SmallVector<Value> inputOperands;
|
||||
if (binder.tensorOperands(inputOperands, binder.op->getNumOperands()) ||
|
||||
binder.tensorResultType(resultType))
|
||||
return failure();
|
||||
Torch::BaseTensorType axesType =
|
||||
axes.getType().cast<Torch::BaseTensorType>();
|
||||
SmallVector<Value> dimList;
|
||||
SmallVector<int64_t> selectSizes;
|
||||
selectSizes.push_back(1);
|
||||
Type selectResultType = axesType.getWithSizesAndDtype(
|
||||
llvm::ArrayRef(selectSizes), axesType.getOptionalDtype());
|
||||
auto sizes =
|
||||
dyn_cast<Torch::ValueTensorType>(axes.getType()).getSizes();
|
||||
if (sizes.size() == 0) {
|
||||
|
||||
Value data = inputOperands[0];
|
||||
auto inputType = data.getType().dyn_cast<Torch::ValueTensorType>();
|
||||
if (!inputType.hasSizes() || !resultType.hasSizes())
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op,
|
||||
"unimplemented: expected input and result to have shapes");
|
||||
|
||||
int64_t inputRank = inputType.getSizes().size();
|
||||
int64_t resultRank = resultType.getSizes().size();
|
||||
int64_t rankDiff = inputRank - resultRank;
|
||||
if (rankDiff == 0) {
|
||||
// In this case, no dimension is squeezed. Hence just replace the op
|
||||
// with input.
|
||||
rewriter.replaceOp(binder.op, data);
|
||||
return success();
|
||||
}
|
||||
|
||||
if (inputOperands.size() == 1) {
|
||||
// Case: `axes` value is not present which means squeeze all the
|
||||
// dimensions with shape value 1.
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenSqueezeOp>(binder.op,
|
||||
resultType, data);
|
||||
return success();
|
||||
}
|
||||
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
|
||||
int64_t adjustmentInt =
|
||||
cast<Torch::ValueTensorType>(data.getType()).getSizes().size();
|
||||
Value adjustment = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
|
||||
adjustmentInt));
|
||||
for (int i = 0; i < sizes[0]; i++) {
|
||||
// Go through the axes list and get each dim in the list
|
||||
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
|
||||
|
||||
SmallVector<Value> dimList;
|
||||
if (inputType.areAllSizesKnown() && resultType.areAllSizesKnown()) {
|
||||
// If the input shape and result shape is statically known then the
|
||||
// list of dims to be squeezed can be derived from those shapes. As a
|
||||
// result, we don't have to wait for the dim values to be known at
|
||||
// runtime which is also expected by the downstream pipeline.
|
||||
SmallVector<int64_t> inputShape(inputType.getSizes());
|
||||
SmallVector<int64_t> resultShape(resultType.getSizes());
|
||||
SmallVector<int64_t> squeezeDims;
|
||||
unsigned resultShapeCounter = 0;
|
||||
for (unsigned i = 0; i < inputRank; i++) {
|
||||
if (resultShapeCounter < resultRank &&
|
||||
inputShape[i] == resultShape[resultShapeCounter]) {
|
||||
resultShapeCounter++;
|
||||
} else {
|
||||
squeezeDims.push_back(i);
|
||||
}
|
||||
}
|
||||
for (auto i : squeezeDims) {
|
||||
dimList.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
|
||||
}
|
||||
}
|
||||
|
||||
if (dimList.empty()) {
|
||||
Value axes = inputOperands[1];
|
||||
Torch::BaseTensorType axesType =
|
||||
axes.getType().cast<Torch::BaseTensorType>();
|
||||
SmallVector<int64_t> selectSizes{1};
|
||||
Type selectResultType = axesType.getWithSizesAndDtype(
|
||||
selectSizes, axesType.getOptionalDtype());
|
||||
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
|
||||
Value extract = rewriter.create<Torch::AtenSelectIntOp>(
|
||||
binder.getLoc(), selectResultType, axes, zero, selectIndex);
|
||||
Value dim = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
|
||||
// deal with neg axis: if (axis < 0) axis += rank
|
||||
Value isNegative =
|
||||
rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), dim, zero);
|
||||
isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
|
||||
isNegative);
|
||||
Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
|
||||
binder.getLoc(), isNegative, adjustment);
|
||||
Value finalDim = rewriter.create<Torch::AtenAddIntOp>(
|
||||
binder.getLoc(), dim, finalOffset);
|
||||
dimList.push_back(finalDim);
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
|
||||
for (int i = 0; i < rankDiff; i++) {
|
||||
// Go through the axes list and get each dim in the list
|
||||
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
|
||||
Value extract = rewriter.create<Torch::AtenSelectIntOp>(
|
||||
binder.getLoc(), selectResultType, axes, zero, selectIndex);
|
||||
Value dim = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
|
||||
dimList.push_back(dim);
|
||||
}
|
||||
}
|
||||
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
||||
rewriter.getType<Torch::ListType>(
|
||||
rewriter.getType<Torch::IntType>()),
|
||||
dimList);
|
||||
rewriter.replaceOpWithNewOp<Torch::PrimsSqueezeOp>(
|
||||
binder.op, resultType, data, dimValueList);
|
||||
|
@ -725,103 +754,67 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
// discussion can be found here:
|
||||
// https://github.com/pytorch/pytorch/issues/9410
|
||||
// So, for now, we unroll into multiple unsqueezes.
|
||||
Location loc = binder.getLoc();
|
||||
Torch::ValueTensorType resultType;
|
||||
Value data;
|
||||
Value axes;
|
||||
Value data, axes;
|
||||
if (binder.tensorOperands(data, axes) ||
|
||||
binder.tensorResultType(resultType))
|
||||
return failure();
|
||||
Torch::BaseTensorType axesType =
|
||||
axes.getType().cast<Torch::BaseTensorType>();
|
||||
SmallVector<Value> dimList;
|
||||
SmallVector<int64_t> selectSizes;
|
||||
selectSizes.push_back(1);
|
||||
Type selectResultType = axesType.getWithSizesAndDtype(
|
||||
llvm::ArrayRef(selectSizes), axesType.getOptionalDtype());
|
||||
auto sizes =
|
||||
dyn_cast<Torch::ValueTensorType>(axes.getType()).getSizes();
|
||||
if (sizes.size() == 0) {
|
||||
auto inputType = data.getType().dyn_cast<Torch::ValueTensorType>();
|
||||
if (!inputType.hasSizes() || !resultType.hasSizes())
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op,
|
||||
"unimplemented: expected input and result to have shapes");
|
||||
|
||||
int64_t inputRank = inputType.getSizes().size();
|
||||
int64_t resultRank = resultType.getSizes().size();
|
||||
int64_t rankDiff = resultRank - inputRank;
|
||||
if (rankDiff == 0) {
|
||||
// In this case, no dimension is unsqueezed. Hence just replace the op
|
||||
// with input.
|
||||
rewriter.replaceOp(binder.op, data);
|
||||
return success();
|
||||
}
|
||||
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
|
||||
int64_t adjustmentInt =
|
||||
cast<Torch::ValueTensorType>(data.getType()).getSizes().size();
|
||||
Value adjustment = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
|
||||
adjustmentInt));
|
||||
for (int i = 0; i < sizes[0]; i++) {
|
||||
// Go through the axes list and get each dim in the list
|
||||
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
|
||||
Value extract = rewriter.create<Torch::AtenSelectIntOp>(
|
||||
binder.getLoc(), selectResultType, axes, zero, selectIndex);
|
||||
Value dim = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
|
||||
// deal with neg axis: if (axis < 0) axis += rank
|
||||
Value isNegative =
|
||||
rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), dim, zero);
|
||||
isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
|
||||
isNegative);
|
||||
Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
|
||||
binder.getLoc(), isNegative, adjustment);
|
||||
Value finalDim = rewriter.create<Torch::AtenAddIntOp>(
|
||||
binder.getLoc(), dim, finalOffset);
|
||||
dimList.push_back(finalDim);
|
||||
}
|
||||
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
||||
dimList);
|
||||
Value cstFalse =
|
||||
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
|
||||
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
||||
Value updatedAxes = rewriter.create<Torch::AtenTensorOp>(
|
||||
binder.getLoc(),
|
||||
axesType.getWithSizesAndDtype(sizes, axesType.getOptionalDtype()),
|
||||
dimValueList, /*dtype=*/noneVal, /*device=*/noneVal, cstFalse);
|
||||
// Sort the list of dims, so we don't run into this situation:
|
||||
// data.sizes = [2, 3, 4]
|
||||
// dims = [4, 0]
|
||||
// index 4 will be invalid to add a singleton dimension because
|
||||
// data.sizes.size == 3 We have to work with sorted dims to avoid this
|
||||
// situation.
|
||||
auto sortIndicesType = axesType.getWithSizesAndDtype(
|
||||
axesType.getOptionalSizes(),
|
||||
IntegerType::get(binder.op->getContext(), 64, IntegerType::Signed));
|
||||
auto sortOpResult = rewriter.create<Torch::AtenSortOp>(
|
||||
binder.getLoc(), axes.getType(), sortIndicesType, updatedAxes, zero,
|
||||
cstFalse);
|
||||
Value result;
|
||||
auto baseType = Torch::ValueTensorType::getWithLeastStaticInformation(
|
||||
binder.op->getContext());
|
||||
// Go through the updated, sorted axes. Do unsqueeze for each dim.
|
||||
for (int i = 0; i < sizes[0]; i++) {
|
||||
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
|
||||
Value extract = rewriter.create<Torch::AtenSelectIntOp>(
|
||||
binder.getLoc(), selectResultType, sortOpResult->getResult(0),
|
||||
zero, selectIndex);
|
||||
Value dim = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
|
||||
if (sizes[0] == 1) {
|
||||
result = rewriter.create<Torch::AtenUnsqueezeOp>(
|
||||
binder.getLoc(), resultType, data, dim);
|
||||
} else if (i == 0) {
|
||||
result = rewriter.create<Torch::AtenUnsqueezeOp>(
|
||||
binder.getLoc(), baseType, data, dim);
|
||||
} else if (i == sizes[0] - 1) {
|
||||
result = rewriter.create<Torch::AtenUnsqueezeOp>(
|
||||
binder.getLoc(), resultType, result, dim);
|
||||
} else {
|
||||
result = rewriter.create<Torch::AtenUnsqueezeOp>(
|
||||
binder.getLoc(), baseType, result, dim);
|
||||
|
||||
SmallVector<int64_t> unsqueezeDims;
|
||||
SmallVector<int64_t> inputShape(inputType.getSizes());
|
||||
if (inputType.areAllSizesKnown() && resultType.areAllSizesKnown()) {
|
||||
// If the input shape and result shape is statically known then the
|
||||
// list of dims to be squeezed can be derived from those shapes. As a
|
||||
// result, we don't have to wait for the dim values to be known at
|
||||
// runtime which is also expected by the downstream pipeline.
|
||||
SmallVector<int64_t> resultShape(resultType.getSizes());
|
||||
unsigned inputShapeCounter = 0;
|
||||
for (unsigned i = 0; i < resultRank; i++) {
|
||||
if (inputShapeCounter < inputRank &&
|
||||
inputShape[inputShapeCounter] == resultShape[i]) {
|
||||
inputShapeCounter++;
|
||||
} else {
|
||||
unsqueezeDims.push_back(i);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
SmallVector<int64_t> unsqueezeDimsInts;
|
||||
if (!matchPattern(axes, m_OnnxListOfConstantInts(unsqueezeDimsInts)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "only support constant int axes values");
|
||||
|
||||
for (auto dim : unsqueezeDimsInts)
|
||||
unsqueezeDims.push_back(dim < 0 ? dim + resultRank : dim);
|
||||
// If we don't sort, unsqueezing first on 4 and then on 0 would fail
|
||||
// for shape = {x,y,z}, and axes [4,0]
|
||||
llvm::sort(unsqueezeDims.begin(), unsqueezeDims.end());
|
||||
}
|
||||
Value result = data;
|
||||
SmallVector<int64_t> unsqueezeShape = inputShape;
|
||||
for (auto dim : unsqueezeDims) {
|
||||
unsqueezeShape.insert(unsqueezeShape.begin() + dim, 1);
|
||||
Type unsqueezeType = resultType.getWithSizesAndDtype(
|
||||
unsqueezeShape, resultType.getOptionalDtype());
|
||||
Value cstDim = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(dim));
|
||||
result = rewriter.create<Torch::AtenUnsqueezeOp>(loc, unsqueezeType,
|
||||
result, cstDim);
|
||||
}
|
||||
rewriter.replaceOp(binder.op, result);
|
||||
return success();
|
||||
|
|
|
@ -2643,12 +2643,8 @@ ONNX_XFAIL_SET = {
|
|||
# Failure - onnx_lowering: onnx.ScatterElements
|
||||
"ScatterReduceFloatMaxModuleIncludeSelf",
|
||||
"ScatterReduceFloatMinModuleIncludeSelf",
|
||||
"ScatterReduceFloatProdModuleIncludeSelf",
|
||||
"ScatterReduceFloatSumModuleIncludeSelf",
|
||||
"ScatterReduceIntMaxModuleIncludeSelf",
|
||||
"ScatterReduceIntMinModuleIncludeSelf",
|
||||
"ScatterReduceIntProdModuleIncludeSelf",
|
||||
"ScatterReduceIntSumModuleIncludeSelf",
|
||||
"ScatterValueFloatModule_basic",
|
||||
|
||||
# Failure - onnx_lowering: onnx.ScatterND
|
||||
|
@ -2680,22 +2676,12 @@ ONNX_XFAIL_SET = {
|
|||
# Failure - onnx_lowering: onnx.SoftmaxCrossEntropyLoss
|
||||
"CrossEntropyLossModule_basic",
|
||||
"CrossEntropyLossNoReductionModule_basic",
|
||||
|
||||
# Failure - onnx_lowering: onnx.Squeeze
|
||||
"SqueezeModule_allUnitDim",
|
||||
"SqueezeModule_broadcast",
|
||||
"SqueezeModule_static",
|
||||
|
||||
# RuntimeError: unsupported input type: Device
|
||||
"PrimsIotaModule_basic",
|
||||
|
||||
|
||||
# Failure - unknown
|
||||
"BernoulliModule_basic",
|
||||
"BucketizeTensorFloatModule_basic",
|
||||
"BucketizeTensorModule_basic",
|
||||
"BucketizeTensorOutInt32RightModule_basic",
|
||||
"BucketizeTensorStaticFloatModule_basic",
|
||||
"BucketizeTensorStaticModule_basic",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
|
||||
"CopyWithDifferentDTypesAndSizesModule_basic",
|
||||
"CopyWithDifferentDTypesModule_basic",
|
||||
|
@ -2712,22 +2698,16 @@ ONNX_XFAIL_SET = {
|
|||
"ElementwiseErfIntModule_basic",
|
||||
"ElementwiseExpIntModule_basic",
|
||||
"ElementwiseLogIntModule_basic",
|
||||
"ElementwisePreluModule_basic",
|
||||
"ElementwisePreluStaticModule_basic",
|
||||
"ElementwiseSigmoidIntModule_basic",
|
||||
"ElementwiseSinIntModule_basic",
|
||||
"ElementwiseTanIntModule_basic",
|
||||
"ElementwiseToDtypeI64ToUI8Module_basic",
|
||||
"ElementwiseUnaryIntModule_basic",
|
||||
"ElementwiseUnsqueezeNegDimsModule_basic",
|
||||
"GroupNormModule_basic",
|
||||
"MaskedFillTensorFloatValueModule_basic",
|
||||
"NativeDropoutTrainModule_basic",
|
||||
"NativeDropoutTrainStaticShapeModule_basic",
|
||||
"ReduceMaxAlongDimUnsignedInt_basic",
|
||||
"ReduceMinAlongDimUnsignedInt_basic",
|
||||
"TensorsStackNegativeDimModule_basic",
|
||||
"TensorsStackPromoteDTypeModule_basic",
|
||||
}
|
||||
|
||||
if torch_version_for_comparison() >= version.parse("2.4.0.dev"):
|
||||
|
@ -2746,6 +2726,10 @@ if torch_version_for_comparison() < version.parse('2.3.0.dev'):
|
|||
ONNX_CRASHING_SET = {
|
||||
"FakeQuantizePerTensorAffineModule_basic",
|
||||
"FakeQuantizePerTensorAffineDynamicShapeModule_basic",
|
||||
|
||||
"ElementwisePreluModule_basic",
|
||||
"ViewDynamicExpandCollapseWithParallelUnknownDimModule_basic",
|
||||
"ScatterReduceFloatProdModuleIncludeSelf",
|
||||
"ScatterReduceFloatSumModuleIncludeSelf",
|
||||
"ScatterReduceIntProdModuleIncludeSelf",
|
||||
"ScatterReduceIntSumModuleIncludeSelf",
|
||||
}
|
||||
|
|
|
@ -424,19 +424,34 @@ func.func @test_xor_bcast4v4d(%arg0: !torch.vtensor<[1,4,1,6],i1>, %arg1: !torch
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_squeeze_no_axes
|
||||
func.func @test_squeeze_no_axes(%arg0: !torch.vtensor<[1,3,1,4,1,5,1,1],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: torch.aten.squeeze %arg0 : !torch.vtensor<[1,3,1,4,1,5,1,1],f32> -> !torch.vtensor<[3,4,5],f32>
|
||||
%0 = torch.operator "onnx.Squeeze"(%arg0) : (!torch.vtensor<[1,3,1,4,1,5,1,1],f32>) -> !torch.vtensor<[3,4,5],f32>
|
||||
return %0 : !torch.vtensor<[3,4,5],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_squeeze_five_axes
|
||||
func.func @test_squeeze_five_axes(%arg0: !torch.vtensor<[1,3,1,4,1,5,1,1],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[3,1,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[INT4:.*]] = torch.constant.int 4
|
||||
// CHECK: %[[INT6:.*]] = torch.constant.int 6
|
||||
// CHECK: %[[INT7:.*]] = torch.constant.int 7
|
||||
// CHECK: %[[SQUEEZE_DIMS:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT4]], %[[INT6]], %[[INT7]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: torch.prims.squeeze %arg0, %[[SQUEEZE_DIMS]] : !torch.vtensor<[1,3,1,4,1,5,1,1],f32>, !torch.list<int> -> !torch.vtensor<[3,1,4,5],f32>
|
||||
%0 = torch.operator "onnx.Squeeze"(%arg0, %arg1) : (!torch.vtensor<[1,3,1,4,1,5,1,1],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[3,1,4,5],f32>
|
||||
return %0 : !torch.vtensor<[3,1,4,5],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_squeeze
|
||||
func.func @test_squeeze(%arg0: !torch.vtensor<[1,3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[INT4:.*]] = torch.constant.int 4
|
||||
// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
|
||||
// CHECK: torch.aten.mul.int %3, %int4 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: torch.prims.squeeze %arg0, %6 : !torch.vtensor<[1,3,4,5],f32>, !torch.list<int> -> !torch.vtensor<[3,4,5],f32>
|
||||
// CHECK: %[[SQUEEZE_DIMS:.*]] = torch.prim.ListConstruct %[[INT0]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: torch.prims.squeeze %arg0, %[[SQUEEZE_DIMS]] : !torch.vtensor<[1,3,4,5],f32>, !torch.list<int> -> !torch.vtensor<[3,4,5],f32>
|
||||
%0 = torch.operator "onnx.Squeeze"(%arg0, %arg1) : (!torch.vtensor<[1,3,4,5],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,4,5],f32>
|
||||
return %0 : !torch.vtensor<[3,4,5],f32>
|
||||
}
|
||||
|
@ -445,24 +460,10 @@ func.func @test_squeeze(%arg0: !torch.vtensor<[1,3,4,5],f32>, %arg1: !torch.vten
|
|||
|
||||
// CHECK-LABEL: func.func @test_squeeze_two_axes
|
||||
func.func @test_squeeze_two_axes(%arg0: !torch.vtensor<[3,1,4,5,1],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[INT5:.*]] = torch.constant.int 5
|
||||
// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
|
||||
// CHECK: torch.aten.mul.int %3, %int5 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.lt.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int
|
||||
// CHECK: torch.aten.mul.int %9, %int5 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.prim.ListConstruct %5, %11 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: torch.prims.squeeze %arg0, %12 : !torch.vtensor<[3,1,4,5,1],f32>, !torch.list<int> -> !torch.vtensor<[3,4,5],f32>
|
||||
// CHECK: %[[INT4:.*]] = torch.constant.int 4
|
||||
// CHECK: %[[SQUEEZE_DIMS:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT4]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: torch.prims.squeeze %arg0, %[[SQUEEZE_DIMS]] : !torch.vtensor<[3,1,4,5,1],f32>, !torch.list<int> -> !torch.vtensor<[3,4,5],f32>
|
||||
%0 = torch.operator "onnx.Squeeze"(%arg0, %arg1) : (!torch.vtensor<[3,1,4,5,1],f32>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,4,5],f32>
|
||||
return %0 : !torch.vtensor<[3,4,5],f32>
|
||||
}
|
||||
|
@ -472,23 +473,7 @@ func.func @test_squeeze_two_axes(%arg0: !torch.vtensor<[3,1,4,5,1],f32>, %arg1:
|
|||
// CHECK-LABEL: func.func @test_unsqueeze_axis_0
|
||||
func.func @test_unsqueeze_axis_0(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[INT3:.*]] = torch.constant.int 3
|
||||
// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
|
||||
// CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: torch.constant.bool false
|
||||
// CHECK: torch.constant.none
|
||||
// CHECK: torch.aten.tensor %6, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.sort %7, %int0, %false : !torch.vtensor<[1],si64>, !torch.int, !torch.bool -> !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[INT0_1:.*]] = torch.constant.int 0
|
||||
// CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %8 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.unsqueeze %arg0, %9 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[1,3,4,5],f32>
|
||||
// CHECK: torch.aten.unsqueeze %arg0, %[[INT0:.*]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[1,3,4,5],f32>
|
||||
%0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,4,5],f32>
|
||||
return %0 : !torch.vtensor<[1,3,4,5],f32>
|
||||
}
|
||||
|
@ -497,24 +482,8 @@ func.func @test_unsqueeze_axis_0(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !tor
|
|||
|
||||
// CHECK-LABEL: func.func @test_unsqueeze_axis_1
|
||||
func.func @test_unsqueeze_axis_1(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[INT3:.*]] = torch.constant.int 3
|
||||
// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
|
||||
// CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: torch.aten.tensor %6, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.sort %7, %int0, %false : !torch.vtensor<[1],si64>, !torch.int, !torch.bool -> !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[INT0_1:.*]] = torch.constant.int 0
|
||||
// CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %8 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.unsqueeze %arg0, %9 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,1,4,5],f32>
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: torch.aten.unsqueeze %arg0, %[[INT1]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,1,4,5],f32>
|
||||
%0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,4,5],f32>
|
||||
return %0 : !torch.vtensor<[3,1,4,5],f32>
|
||||
}
|
||||
|
@ -523,146 +492,22 @@ func.func @test_unsqueeze_axis_1(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !tor
|
|||
|
||||
// CHECK-LABEL: func.func @test_unsqueeze_axis_2
|
||||
func.func @test_unsqueeze_axis_2(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,4,1,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[INT3:.*]] = torch.constant.int 3
|
||||
// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
|
||||
// CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: torch.aten.tensor %6, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.sort %7, %int0, %false : !torch.vtensor<[1],si64>, !torch.int, !torch.bool -> !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[INT0_1:.*]] = torch.constant.int 0
|
||||
// CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %8 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.unsqueeze %arg0, %9 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,1,5],f32>
|
||||
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||
// CHECK: torch.aten.unsqueeze %arg0, %[[INT2]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,1,5],f32>
|
||||
%0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,4,1,5],f32>
|
||||
return %0 : !torch.vtensor<[3,4,1,5],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_unsqueeze_negative_axes
|
||||
func.func @test_unsqueeze_negative_axes(%arg0: !torch.vtensor<[1,3,1,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,1,1,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[INT4:.*]] = torch.constant.int 4
|
||||
// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
|
||||
// CHECK: torch.aten.mul.int %3, %int4 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: torch.aten.tensor %6, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.sort %7, %int0, %false : !torch.vtensor<[1],si64>, !torch.int, !torch.bool -> !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[INT0_1:.*]] = torch.constant.int 0
|
||||
// CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %8 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.unsqueeze %arg0, %9 : !torch.vtensor<[1,3,1,5],f32>, !torch.int -> !torch.vtensor<[1,3,1,1,5],f32>
|
||||
%0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[1,3,1,5],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,1,1,5],f32>
|
||||
return %0 : !torch.vtensor<[1,3,1,1,5],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_unsqueeze_three_axes
|
||||
func.func @test_unsqueeze_three_axes(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[3,4,1,5,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[INT3:.*]] = torch.constant.int 3
|
||||
// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
|
||||
// CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.lt.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int
|
||||
// CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.lt.int %13, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.Int.bool %14 : !torch.bool -> !torch.int
|
||||
// CHECK: torch.aten.mul.int %15, %int3 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.aten.add.int %13, %16 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.prim.ListConstruct %5, %11, %17 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: torch.aten.tensor %18, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[3],si64>
|
||||
// CHECK: torch.aten.sort %19, %int0, %false : !torch.vtensor<[3],si64>, !torch.int, !torch.bool -> !torch.vtensor<[3],si64>, !torch.vtensor<[3],si64>
|
||||
// CHECK: %[[INT0_1:.*]] = torch.constant.int 0
|
||||
// CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %20 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.unsqueeze %arg0, %21 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor
|
||||
// CHECK: %[[INT1_2:.*]] = torch.constant.int 1
|
||||
// CHECK: torch.aten.select.int %values, %int0, %int1_2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %23 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.unsqueeze %22, %24 : !torch.vtensor, !torch.int -> !torch.vtensor
|
||||
// CHECK: %[[INT2_3:.*]] = torch.constant.int 2
|
||||
// CHECK: torch.aten.select.int %values, %int0, %int2_3 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %26 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.unsqueeze %25, %27 : !torch.vtensor, !torch.int -> !torch.vtensor<[3,4,1,5,1,1],f32>
|
||||
%0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[3,4,1,5,1,1],f32>
|
||||
return %0 : !torch.vtensor<[3,4,1,5,1,1],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_unsqueeze_unsorted_axes
|
||||
func.func @test_unsqueeze_unsorted_axes(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[3,4,1,5,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[INT3:.*]] = torch.constant.int 3
|
||||
// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
|
||||
// CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.lt.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int
|
||||
// CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.lt.int %13, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.Int.bool %14 : !torch.bool -> !torch.int
|
||||
// CHECK: torch.aten.mul.int %15, %int3 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.aten.add.int %13, %16 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.prim.ListConstruct %5, %11, %17 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: torch.aten.tensor %18, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[3],si64>
|
||||
// CHECK: torch.aten.sort %19, %int0, %false : !torch.vtensor<[3],si64>, !torch.int, !torch.bool -> !torch.vtensor<[3],si64>, !torch.vtensor<[3],si64>
|
||||
// CHECK: %[[INT0_1:.*]] = torch.constant.int 0
|
||||
// CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %20 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.unsqueeze %arg0, %21 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor
|
||||
// CHECK: %[[INT1_2:.*]] = torch.constant.int 1
|
||||
// CHECK: torch.aten.select.int %values, %int0, %int1_2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %23 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.unsqueeze %22, %24 : !torch.vtensor, !torch.int -> !torch.vtensor
|
||||
// CHECK: %[[INT2_3:.*]] = torch.constant.int 2
|
||||
// CHECK: torch.aten.select.int %values, %int0, %int2_3 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %26 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.unsqueeze %25, %27 : !torch.vtensor, !torch.int -> !torch.vtensor<[3,4,1,5,1,1],f32>
|
||||
// CHECK: %[[UNSQUEEZE:.*]] = torch.aten.unsqueeze %arg0, %[[INT2]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,1,5],f32>
|
||||
// CHECK: %[[INT4:.*]] = torch.constant.int 4
|
||||
// CHECK: %[[UNSQUEEZE_1:.*]] = torch.aten.unsqueeze %[[UNSQUEEZE]], %[[INT4]] : !torch.vtensor<[3,4,1,5],f32>, !torch.int -> !torch.vtensor<[3,4,1,5,1],f32>
|
||||
// CHECK: %[[INT5:.*]] = torch.constant.int 5
|
||||
// CHECK: torch.aten.unsqueeze %[[UNSQUEEZE_1]], %[[INT5]] : !torch.vtensor<[3,4,1,5,1],f32>, !torch.int -> !torch.vtensor<[3,4,1,5,1,1],f32>
|
||||
%0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[3,4,1,5,1,1],f32>
|
||||
return %0 : !torch.vtensor<[3,4,1,5,1,1],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue