From 1201babb9fe7fb248a5195c4944335f7309fccfd Mon Sep 17 00:00:00 2001 From: zjgarvey <47986913+zjgarvey@users.noreply.github.com> Date: Thu, 14 Nov 2024 10:25:28 -0600 Subject: [PATCH] [ONNX] rework some reduction op lowerings (#3870) - Refactors more "onnx.ReduceXXX" patterns through helper function. - Fixes bug with iterating unconditionally on `output_dim == 1` during `dimList` inference. This change results in passes for the following 11 models: crossvit_15_240 crossvit_15_dagger_240 crossvit_15_dagger_408 crossvit_18_240 crossvit_18_dagger_240 crossvit_18_dagger_408 crossvit_9_240 crossvit_9_dagger_240 crossvit_base_240 crossvit_small_240 crossvit_tiny_240 --------- Co-authored-by: Vinayak Dev <104419489+vinayakdsci@users.noreply.github.com> --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 536 +++++------------- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 128 ++--- 2 files changed, 191 insertions(+), 473 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 1793af959..85b51ca7e 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -36,21 +36,24 @@ namespace { // we provide the original operand through storeResult, which will be modified // if the result will be passed onto another operation, and will be used for // noop_with_empty_axes handling before that. -LogicalResult reducedSumImpl(OpBinder binder, - ConversionPatternRewriter &rewriter, Value data, - Torch::ValueTensorType resultType, - Value &storeResult, int64_t keepDims, - int64_t noop_with_empty_axes, - bool isIntermediateOp) { +template +LogicalResult reduceOpImpl(OpBinder binder, ConversionPatternRewriter &rewriter, + Value data, Torch::ValueTensorType resultType, + Value &storeResult, int64_t keepDims, + int64_t noop_with_empty_axes, + bool isIntermediateOp) { + auto inputType = dyn_cast(data.getType()); + if (!inputType) + return failure(); SmallVector axesList; Value axesVal; if (!binder.tensorOperandAtIndex(axesVal, 1)) { - auto inputType = dyn_cast(data.getType()); - if (!inputType.hasSizes() || !resultType.hasSizes()) { - return rewriter.notifyMatchFailure( - binder.op, "unimplemented: expected input and result to have shapes"); - } + auto axesTy = dyn_cast(axesVal.getType()); + if (!axesTy || !axesTy.areAllSizesKnown() || axesTy.getSizes().size() > 1) + return failure(); + auto axesShape = axesTy.getSizes(); + uint64_t numAxes = (axesShape.empty()) ? 1 : axesShape.front(); if (inputType.areAllSizesKnown() && resultType.areAllSizesKnown()) { SmallVector inputShape{inputType.getSizes()}; @@ -77,22 +80,25 @@ LogicalResult reducedSumImpl(OpBinder binder, } else { reduceDims.push_back(i); if (resultShapeCounter < resultShape.size() && - resultShape[resultShapeCounter] == 1) + resultShape[resultShapeCounter] == 1 && keepDims == 1) resultShapeCounter++; } } - for (auto i : reduceDims) { - axesList.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i))); - } + if (reduceDims.size() == numAxes) { + for (auto i : reduceDims) { + axesList.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + } else + binder.op->emitWarning( + "Number of inferred reduce dims, " + + std::to_string(reduceDims.size()) + + ", does not match the provided number of axes, " + + std::to_string(numAxes) + "."); } } if (axesList.empty()) { - Torch::BaseTensorType axesType = - cast(axesVal.getType()); - auto axesTy = dyn_cast(axesVal.getType()); - auto axesShape = axesTy.getSizes(); - if (axesShape.size() != 1 || axesShape[0] == Torch::kUnknownSize) + if (axesTy.getSizes()[0] == Torch::kUnknownSize) return failure(); Value zero = rewriter.create( @@ -100,9 +106,8 @@ LogicalResult reducedSumImpl(OpBinder binder, rewriter.getI64IntegerAttr(0)); SmallVector selectSizes{1}; auto selType = rewriter.getType( - selectSizes, axesType.getOptionalDtype()); - int64_t numAxes = axesShape[0]; - for (int64_t i = 0; i < numAxes; ++i) { + selectSizes, axesTy.getOptionalDtype()); + for (uint64_t i = 0; i < numAxes; ++i) { Value iv = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getI64IntegerAttr(i)); @@ -117,38 +122,60 @@ LogicalResult reducedSumImpl(OpBinder binder, SmallVector axesInts; if (!binder.s64IntegerArrayAttr(axesInts, "axes", {})) { - for (int64_t i = 0, s = axesInts.size(); i < s; ++i) { - Value iv = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getI64IntegerAttr(axesInts[i])); - axesList.push_back(iv); + for (int64_t i : axesInts) { + axesList.push_back( + rewriter.create(binder.getLoc(), i)); } } // Do not include absolute value in the noop - if (axesList.empty() && noop_with_empty_axes) { - rewriter.replaceOp(binder.op, storeResult); + if (axesList.empty() && noop_with_empty_axes == 1) { + if (!isIntermediateOp) + rewriter.replaceOp(binder.op, data); + else + storeResult = data; return success(); } + // if the axes list is still empty, reduce everything. + if (axesList.empty()) { + if (keepDims == 0 && !resultType.getSizes().empty()) + return rewriter.notifyMatchFailure( + binder.op, + "no axes provided & no keepdim: expected result to be rank zero."); + if (keepDims == 1 && + (resultType.getSizes().size() != inputType.getSizes().size() || + llvm::any_of(resultType.getSizes(), + [](int64_t size) { return size != 1; }))) + return rewriter.notifyMatchFailure( + binder.op, "no axes provided & keepdim: expected result to have all " + "dimensions equal to 1."); + for (uint64_t i = 0; i < inputType.getSizes().size(); i++) { + axesList.push_back( + rewriter.create(binder.getLoc(), i)); + } + } + Value dimValueList = rewriter.create( binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), axesList); Value keepDimBool = rewriter.create(binder.getLoc(), keepDims); - Value dType = rewriter.create(binder.getLoc()); - // If we are using the ReducedSum as an intermediate op to be passed into + // If we are using the reduction op as an intermediate op to be passed into // another operation, we might not want to replace the Op. So we create a new // Op and store the result in a variable. + SmallVector operands = {data, dimValueList, keepDimBool}; + if (llvm::is_one_of()) + operands.push_back( + /*dtype=*/rewriter.create(binder.getLoc())); if (!isIntermediateOp) { - rewriter.replaceOpWithNewOp( - binder.op, resultType, data, dimValueList, keepDimBool, - /*dtype=*/dType); + rewriter.replaceOpWithNewOp(binder.op, resultType, + operands); } else { - storeResult = rewriter.create( - binder.getLoc(), resultType, data, dimValueList, keepDimBool, - /*dtype=*/dType); + storeResult = rewriter.create(binder.getLoc(), + resultType, operands); } return success(); } @@ -1039,25 +1066,25 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, resultType, operand, vAlpha, vScale, vInputScale); return success(); }); - patterns.onOp("ReduceL1", 1, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - int64_t keepDims, noop_with_empty_axes; - Value operand; - if (binder.tensorOperandAtIndex(operand, 0) || - binder.tensorResultType(resultType) || - binder.s64IntegerAttr(keepDims, "keepdims", 1) || - binder.s64IntegerAttr(noop_with_empty_axes, - "noop_with_empty_axes", 0)) - return failure(); + patterns.onOp( + "ReduceL1", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + int64_t keepDims, noop_with_empty_axes; + Value operand; + if (binder.tensorOperandAtIndex(operand, 0) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", + 0)) + return failure(); - Value data = rewriter.create( - binder.getLoc(), operand.getType(), operand); + Value data = rewriter.create( + binder.getLoc(), operand.getType(), operand); - return reducedSumImpl(binder, rewriter, data, resultType, - /*storeValue=*/operand, keepDims, - noop_with_empty_axes, false); - }); + return reduceOpImpl( + binder, rewriter, data, resultType, + /*storeValue=*/operand, keepDims, noop_with_empty_axes, false); + }); patterns.onOp( "ReduceL2", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -1075,9 +1102,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value squareOfOperand = rewriter.create( binder.getLoc(), operand.getType(), operand, operand); - auto reducedSum = - reducedSumImpl(binder, rewriter, squareOfOperand, resultType, - operand, keepDims, noop_with_empty_axes, true); + auto reducedSum = reduceOpImpl( + binder, rewriter, squareOfOperand, resultType, operand, keepDims, + noop_with_empty_axes, true); if (failed(reducedSum)) return rewriter.notifyMatchFailure( binder.op, @@ -1112,32 +1139,32 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( /*memory_format=*/noneVal); return success(); }); - patterns.onOp("ReduceLogSum", 1, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value data; - int64_t keepDims, noop_with_empty_axes; - if (binder.tensorOperandAtIndex(data, 0) || - binder.tensorResultType(resultType) || - binder.s64IntegerAttr(keepDims, "keepdims", 1) || - binder.s64IntegerAttr(noop_with_empty_axes, - "noop_with_empty_axes", 0)) - return failure(); + patterns.onOp( + "ReduceLogSum", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data; + int64_t keepDims, noop_with_empty_axes; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", + 0)) + return failure(); - auto reducedSumBool = - reducedSumImpl(binder, rewriter, data, resultType, - /*storeValue=*/data, keepDims, - noop_with_empty_axes, true); + auto reducedSumBool = reduceOpImpl( + binder, rewriter, data, resultType, + /*storeValue=*/data, keepDims, noop_with_empty_axes, true); - if (failed(reducedSumBool)) - return rewriter.notifyMatchFailure( - binder.op, - "Failed to perform sum operation on square of operand"); + if (failed(reducedSumBool)) + return rewriter.notifyMatchFailure( + binder.op, + "Failed to perform sum operation on square of operand"); - rewriter.replaceOpWithNewOp( - binder.op, resultType, data); - return success(); - }); + rewriter.replaceOpWithNewOp(binder.op, resultType, + data); + return success(); + }); patterns.onOp( "ReduceLogSumExp", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { @@ -1169,7 +1196,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.getLoc(), f64ResultType, dataCast); auto f64ReduceType = rewriter.getType( resultType.getOptionalSizes(), rewriter.getF64Type()); - auto reducedSumBool = reducedSumImpl( + auto reducedSumBool = reduceOpImpl( binder, rewriter, dataExp, f64ReduceType, /*storeValue=*/data, keepDims, noop_with_empty_axes, true); if (failed(reducedSumBool)) @@ -1186,22 +1213,22 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( /*memory_format=*/noneVal); return success(); }); - patterns.onOp("ReduceSum", 1, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value data; - int64_t keepDims, noop_with_empty_axes; - if (binder.tensorOperandAtIndex(data, 0) || - binder.tensorResultType(resultType) || - binder.s64IntegerAttr(keepDims, "keepdims", 1) || - binder.s64IntegerAttr(noop_with_empty_axes, - "noop_with_empty_axes", 0)) - return failure(); + patterns.onOp( + "ReduceSum", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data; + int64_t keepDims, noop_with_empty_axes; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", + 0)) + return failure(); - return reducedSumImpl(binder, rewriter, data, resultType, - /*storeValue=*/data, keepDims, - noop_with_empty_axes, false); - }); + return reduceOpImpl( + binder, rewriter, data, resultType, + /*storeValue=*/data, keepDims, noop_with_empty_axes, false); + }); patterns.onOp("ReduceSumSquare", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -1217,137 +1244,35 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value dataSquare = rewriter.create( binder.getLoc(), data.getType(), data, data); - return reducedSumImpl(binder, rewriter, dataSquare, - resultType, - /*storeValue=*/data, keepDims, - noop_with_empty_axes, false); + return reduceOpImpl( + binder, rewriter, dataSquare, resultType, + /*storeValue=*/data, keepDims, noop_with_empty_axes, + false); }); - patterns.onOp( - "ReduceMean", 1, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value data; - int64_t keepDims, noop_with_empty_axes; - if (binder.tensorOperandAtIndex(data, 0) || - binder.tensorResultType(resultType) || - binder.s64IntegerAttr(keepDims, "keepdims", 1) || - binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", - 0)) - return failure(); + patterns.onOp("ReduceMean", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data; + int64_t keepDims, noop_with_empty_axes; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, + "noop_with_empty_axes", 0)) + return failure(); - SmallVector axesList; - - Value axesVal; - if (!binder.tensorOperandAtIndex(axesVal, 1)) { - auto inputType = dyn_cast(data.getType()); - if (!inputType.hasSizes() || !resultType.hasSizes()) { - return rewriter.notifyMatchFailure( - binder.op, - "unimplemented: expected input and result to have shapes"); - } - - // If the input shape and result shape is statically known then the - // list of dims to be squeezed can be derived from those shapes. As a - // result, we don't have to wait for the dim values to be known at - // runtime which is also expected by the downstream pipeline. - if (inputType.areAllSizesKnown() && resultType.areAllSizesKnown()) { - SmallVector inputShape{inputType.getSizes()}; - SmallVector resultShape{resultType.getSizes()}; - if (llvm::equal(inputShape, resultShape)) { - // Case: none of the dimension is reduced. - rewriter.replaceOp(binder.op, data); - return success(); - } - if (areAllElementsDistinct(inputShape)) { - // The check for the input shape elements to be distinct is added - // for the cases like: - // Input: [3, 2, 2] -> Output: [3, 2] - // For the above case, from the input and output shape it can't be - // inferred whether the dim:1 is reduced or dim:2. To avoid these - // type of cases, the check has been placed. - SmallVector reduceDims; - unsigned resultShapeCounter = 0; - for (unsigned i = 0; i < inputShape.size(); i++) { - if (resultShapeCounter < resultShape.size() && - inputShape[i] == resultShape[resultShapeCounter]) { - resultShapeCounter++; - } else { - reduceDims.push_back(i); - if (resultShapeCounter < resultShape.size() && - resultShape[resultShapeCounter] == 1) - resultShapeCounter++; - } - } - for (auto i : reduceDims) { - axesList.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i))); - } - } - } - - if (axesList.empty()) { - Torch::BaseTensorType axesType = - cast(axesVal.getType()); - auto axesTy = dyn_cast(axesVal.getType()); - auto axesShape = axesTy.getSizes(); - if (axesShape.size() != 1 || axesShape[0] == Torch::kUnknownSize) - return failure(); - - Value zero = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getI64IntegerAttr(0)); - SmallVector selectSizes{1}; - auto selType = rewriter.getType( - selectSizes, axesType.getOptionalDtype()); - int64_t numAxes = axesShape[0]; - for (int64_t i = 0; i < numAxes; ++i) { - Value iv = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getI64IntegerAttr(i)); - Value extract = rewriter.create( - binder.getLoc(), selType, axesVal, zero, iv); - Value dim = rewriter.create( - binder.getLoc(), rewriter.getType(), extract); - axesList.push_back(dim); - } - } - } - - SmallVector axesInts; - if (!binder.s64IntegerArrayAttr(axesInts, "axes", {})) { - for (int64_t i = 0, s = axesInts.size(); i < s; ++i) { - Value iv = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getI64IntegerAttr(axesInts[i])); - axesList.push_back(iv); - } - } - - // deal with case when axes is empty - if (axesList.empty() && noop_with_empty_axes) { - rewriter.replaceOp(binder.op, data); - return success(); - } - - Value dimValueList = rewriter.create( - binder.getLoc(), - Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), - axesList); - Value keepDimBool = - rewriter.create(binder.getLoc(), keepDims); - Value noneVal = rewriter.create(binder.getLoc()); - rewriter.replaceOpWithNewOp( - binder.op, resultType, data, dimValueList, keepDimBool, - /*dtype=*/noneVal); - return success(); - }); + Value reduceSum = data; + return reduceOpImpl( + binder, rewriter, data, resultType, + /*storeValue=*/reduceSum, keepDims, noop_with_empty_axes, + false); + }); patterns.onOp( "ReduceMax", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { // AtenAmaxOp allows us to pass a list of dims Torch::ValueTensorType resultType; Value data; - Value axes; int64_t keepDims; int64_t noop_with_empty_axes; if (binder.tensorOperandAtIndex(data, 0) || @@ -1412,87 +1337,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return success(); } - // Previous version of the operation had the axes as an attribute: - SmallVector axesList; - llvm::SmallVector axesAttr; - if (!binder.s64IntegerArrayAttr(axesAttr, "axes", {})) { - for (int i = 0, s = axesAttr.size(); i < s; ++i) { - axesList.push_back(rewriter.create( - binder.getLoc(), torchIntTy, - rewriter.getI64IntegerAttr(axesAttr[i]))); - } - } - - // Extract the axes values from the axes operand: - if (!binder.tensorOperandAtIndex(axes, 1)) { - Torch::BaseTensorType axesType = - cast(axes.getType()); - SmallVector selectSizes{1}; - Type selectResultType = axesType.getWithSizesAndDtype( - selectSizes, axesType.getOptionalDtype()); - auto sizes = axesType.getSizes(); - - Value zero = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); - - // Extract the value of each axes: - for (int i = 0; i < sizes[0]; i++) { - // Go through the axes list and get each dim in the list - Value selectIndex = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - Value extract = rewriter.create( - binder.getLoc(), selectResultType, axes, zero, selectIndex); - Value dim = rewriter.create( - binder.getLoc(), rewriter.getType(), extract); - axesList.push_back(dim); - } - } - - // Handle the noop case: - if (axesList.empty() && noop_with_empty_axes) { - rewriter.replaceOp(binder.op, data); - return success(); - } - - // Deal with case when no axes arg is passed but not a noop: - if (axesList.empty()) { - int64_t numDims = dyn_cast(data.getType()) - .getSizes() - .size(); - for (int i = 0; i < numDims; i++) { - Value curr = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - axesList.push_back(curr); - } - } - - // Handle negative axis: - Value rankVal = rewriter.create(binder.getLoc(), - torchIntTy, data); - Value zero = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getI64IntegerAttr(0)); - for (Value &axes : axesList) { - Value isNegative = - rewriter.create(binder.getLoc(), axes, zero); - isNegative = rewriter.create(binder.getLoc(), - isNegative); - Value finalOffset = rewriter.create( - binder.getLoc(), isNegative, rankVal); - axes = rewriter.create(binder.getLoc(), axes, - finalOffset); - } - - Value dimValueList = rewriter.create( - binder.getLoc(), Torch::ListType::get(torchIntTy), axesList); - Value keepDimBool = - rewriter.create(binder.getLoc(), keepDims); - rewriter.replaceOpWithNewOp( - binder.op, resultType, data, dimValueList, keepDimBool); - return success(); + return reduceOpImpl( + binder, rewriter, data, resultType, + /*storeValue=*/data, keepDims, noop_with_empty_axes, false); }); patterns.onOp( @@ -1501,7 +1348,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // AtenAminOp allows us to pass a list of dims Torch::ValueTensorType resultType; Value data; - Value axes; int64_t keepDims; int64_t noop_with_empty_axes; if (binder.tensorOperandAtIndex(data, 0) || @@ -1565,87 +1411,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return success(); } - // Previous version of the operation had the axes as an attribute: - SmallVector axesList; - llvm::SmallVector axesAttr; - if (!binder.s64IntegerArrayAttr(axesAttr, "axes", {})) { - for (int i = 0, s = axesAttr.size(); i < s; ++i) { - axesList.push_back(rewriter.create( - binder.getLoc(), torchIntTy, - rewriter.getI64IntegerAttr(axesAttr[i]))); - } - } - - // Extract the axes values from the axes operand: - if (!binder.tensorOperandAtIndex(axes, 1)) { - Torch::BaseTensorType axesType = - cast(axes.getType()); - SmallVector selectSizes{1}; - Type selectResultType = axesType.getWithSizesAndDtype( - selectSizes, axesType.getOptionalDtype()); - auto sizes = axesType.getSizes(); - - Value zero = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); - - // Extract the value of each axes: - for (int i = 0; i < sizes[0]; i++) { - // Go through the axes list and get each dim in the list - Value selectIndex = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - Value extract = rewriter.create( - binder.getLoc(), selectResultType, axes, zero, selectIndex); - Value dim = rewriter.create( - binder.getLoc(), rewriter.getType(), extract); - axesList.push_back(dim); - } - } - - // Handle the noop case: - if (axesList.empty() && noop_with_empty_axes) { - rewriter.replaceOp(binder.op, data); - return success(); - } - - // Deal with case when no axes arg is passed but not a noop: - if (axesList.empty()) { - int64_t numDims = dyn_cast(data.getType()) - .getSizes() - .size(); - for (int i = 0; i < numDims; i++) { - Value curr = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - axesList.push_back(curr); - } - } - - // Handle negative axis: - Value rankVal = rewriter.create(binder.getLoc(), - torchIntTy, data); - Value zero = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getI64IntegerAttr(0)); - for (Value &axes : axesList) { - Value isNegative = - rewriter.create(binder.getLoc(), axes, zero); - isNegative = rewriter.create(binder.getLoc(), - isNegative); - Value finalOffset = rewriter.create( - binder.getLoc(), isNegative, rankVal); - axes = rewriter.create(binder.getLoc(), axes, - finalOffset); - } - - Value dimValueList = rewriter.create( - binder.getLoc(), Torch::ListType::get(torchIntTy), axesList); - Value keepDimBool = - rewriter.create(binder.getLoc(), keepDims); - rewriter.replaceOpWithNewOp( - binder.op, resultType, data, dimValueList, keepDimBool); - return success(); + return reduceOpImpl( + binder, rewriter, data, resultType, + /*storeValue=*/data, keepDims, noop_with_empty_axes, false); }); patterns.onOp( diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 30fd60dbd..16c86218d 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -707,17 +707,8 @@ func.func @test_reduce_max_empty_set_int(%arg0: !torch.vtensor<[2,0,4],si32>, %a // CHECK-LABEL: func.func @test_reduce_max_bool_inputs func.func @test_reduce_max_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[IDX:.+]] = torch.constant.int 0 - // CHECK: %[[SZ:.+]] = torch.constant.int 0 - // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]] - // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]] - // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int - // CHECK: %[[C0:.+]] = torch.constant.int 0 - // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int - // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list + // CHECK: %[[I1:.*]] = torch.constant.int 1 + // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[AMAX:.+]] = torch.aten.amax %arg0, %[[LST]], %[[TRUE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4,1],i1> // CHECK: return %[[AMAX]] : !torch.vtensor<[4,1],i1> @@ -729,17 +720,8 @@ func.func @test_reduce_max_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: ! // CHECK-LABEL: func.func @test_reduce_max_bool_inputs_nokeepdims func.func @test_reduce_max_bool_inputs_nokeepdims(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[IDX:.+]] = torch.constant.int 0 - // CHECK: %[[SZ:.+]] = torch.constant.int 0 - // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]] - // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]] - // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int - // CHECK: %[[C0:.+]] = torch.constant.int 0 - // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int - // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list + // CHECK: %[[I1:.*]] = torch.constant.int 1 + // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list // CHECK: %[[FALSE:.+]] = torch.constant.bool false // CHECK: %[[AMAX:.+]] = torch.aten.amax %arg0, %[[LST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4],i1> // CHECK: return %[[AMAX]] : !torch.vtensor<[4],i1> @@ -751,19 +733,9 @@ func.func @test_reduce_max_bool_inputs_nokeepdims(%arg0: !torch.vtensor<[4,2],i1 // CHECK-LABEL: func.func @test_reduce_max_all_dims_default func.func @test_reduce_max_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[I0:.+]] = torch.constant.int 0 - // CHECK: %[[I1:.+]] = torch.constant.int 1 - // CHECK: %[[RANK:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int - // CHECK: %[[C0:.+]] = torch.constant.int 0 - // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I0]], %[[C0]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int - // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[A0:.+]] = torch.aten.add.int %[[I0]], %[[MUL]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I1]], %[[C0]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int - // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[A1:.+]] = torch.aten.add.int %[[I1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[A0]], %[[A1]] + // CHECK: %[[I0:.*]] = torch.constant.int 0 + // CHECK: %[[I1:.*]] = torch.constant.int 1 + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[I0]], %[[I1]] // CHECK: %[[FALSE:.+]] = torch.constant.bool false // CHECK: %[[MAX:.+]] = torch.aten.amax %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[],i1> // CHECK: return %[[MAX]] : !torch.vtensor<[],i1> @@ -775,13 +747,7 @@ func.func @test_reduce_max_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) -> func.func @test_reduce_max_attr(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[INT1]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int - // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[INT1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list // CHECK: %[[FALSE:.+]] = torch.constant.bool false // CHECK: %[[AMAX:.+]] = torch.aten.amax %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4],i1> // CHECK: return %[[AMAX]] @@ -793,9 +759,12 @@ func.func @test_reduce_max_attr(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtens // CHECK-LABEL: func.func @test_reduce_l1_default_axes_keepdims_example func.func @test_reduce_l1_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[ABS:.+]] = torch.aten.abs %arg0 : !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32> + // CHECK: %[[ABS:.*]] = torch.aten.abs %arg0 // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK-DAG: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[NONE:.+]] = torch.constant.none // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[ABS]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> @@ -845,8 +814,11 @@ func.func @test_reduce_l1_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2,2],f // CHECK-LABEL: func.func @test_reduce_l2_default_axes_keepdims_example func.func @test_reduce_l2_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32> + // CHECK: %[[INT0:.+]] = torch.constant.int 0 // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[TRUE_0:.+]] = torch.constant.bool true // CHECK: %[[NONE_0:.+]] = torch.constant.none // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE_0]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> @@ -944,7 +916,10 @@ func.func @test_reduce_l2_keep_dims_int_input_example(%arg0: !torch.vtensor<[3,2 // CHECK-LABEL: func.func @test_reduce_log_sum_default_axes_keepdims_example func.func @test_reduce_log_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[NONE:.+]] = torch.constant.none // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> @@ -1000,7 +975,10 @@ func.func @test_reduce_log_sum_exp_default_axes_keepdims_example(%arg0: !torch.v // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f64> // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f64> -> !torch.vtensor<[3,2,2],f64> // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[NONE_1:.+]] = torch.constant.none // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIMS]], %[[TRUE]], %[[NONE_1]] : !torch.vtensor<[3,2,2],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f64> @@ -1092,7 +1070,10 @@ func.func @test_reduce_log_sum_exp_keep_dims_int_input_example(%arg0: !torch.vte // CHECK-LABEL: func.func @test_reduce_sum_default_axes_keepdims_example func.func @test_reduce_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[NONE:.+]] = torch.constant.none // CHECK: torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> @@ -1177,7 +1158,10 @@ func.func @test_reduce_sum_negative_axes_keepdims_example(%arg0: !torch.vtensor< func.func @test_reduce_sum_square_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32> // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[NONE:.+]] = torch.constant.none // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> @@ -1385,17 +1369,8 @@ func.func @test_reduce_min_empty_set_int(%arg0: !torch.vtensor<[2,0,4],si32>, %a // CHECK-LABEL: func.func @test_reduce_min_bool_inputs func.func @test_reduce_min_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[IDX:.+]] = torch.constant.int 0 - // CHECK: %[[SZ:.+]] = torch.constant.int 0 - // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]] - // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]] - // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int - // CHECK: %[[C0:.+]] = torch.constant.int 0 - // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int - // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %6 : (!torch.int) -> !torch.list + // CHECK: %[[I1:.+]] = torch.constant.int 1 + // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[AMIN:.+]] = torch.aten.amin %arg0, %[[LST]], %[[TRUE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4,1],i1> // CHECK: return %[[AMIN]] : !torch.vtensor<[4,1],i1> @@ -1407,17 +1382,8 @@ func.func @test_reduce_min_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: ! // CHECK-LABEL: func.func @test_reduce_min_bool_inputs_nokeepdims func.func @test_reduce_min_bool_inputs_nokeepdims(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[IDX:.+]] = torch.constant.int 0 - // CHECK: %[[SZ:.+]] = torch.constant.int 0 - // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]] - // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]] - // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int - // CHECK: %[[C0:.+]] = torch.constant.int 0 - // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int - // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %6 : (!torch.int) -> !torch.list + // CHECK: %[[I1:.+]] = torch.constant.int 1 + // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list // CHECK: %[[FALSE:.+]] = torch.constant.bool false // CHECK: %[[AMIN:.+]] = torch.aten.amin %arg0, %[[LST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4],i1> // CHECK: return %[[AMIN]] : !torch.vtensor<[4],i1> @@ -1431,17 +1397,7 @@ func.func @test_reduce_min_bool_inputs_nokeepdims(%arg0: !torch.vtensor<[4,2],i1 func.func @test_reduce_min_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[I0:.+]] = torch.constant.int 0 // CHECK: %[[I1:.+]] = torch.constant.int 1 - // CHECK: %[[RANK:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int - // CHECK: %[[C0:.+]] = torch.constant.int 0 - // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I0]], %[[C0]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int - // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[A0:.+]] = torch.aten.add.int %[[I0]], %[[MUL]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I1]], %[[C0]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int - // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[A1:.+]] = torch.aten.add.int %[[I1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[A0]], %[[A1]] + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[I0]], %[[I1]] // CHECK: %[[FALSE:.+]] = torch.constant.bool false // CHECK: %[[MIN:.+]] = torch.aten.amin %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[],i1> // CHECK: return %[[MIN]] : !torch.vtensor<[],i1> @@ -1453,13 +1409,7 @@ func.func @test_reduce_min_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) -> func.func @test_reduce_min_attr(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[INT1]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int - // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[INT1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list // CHECK: %[[FALSE:.+]] = torch.constant.bool false // CHECK: %[[AMIN:.+]] = torch.aten.amin %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4],i1> // CHECK: return %[[AMIN]]