mirror of https://github.com/llvm/torch-mlir
[ONNX] rework some reduction op lowerings (#3870)
- Refactors more "onnx.ReduceXXX" patterns through helper function. - Fixes bug with iterating unconditionally on `output_dim == 1` during `dimList` inference. This change results in passes for the following 11 models: crossvit_15_240 crossvit_15_dagger_240 crossvit_15_dagger_408 crossvit_18_240 crossvit_18_dagger_240 crossvit_18_dagger_408 crossvit_9_240 crossvit_9_dagger_240 crossvit_base_240 crossvit_small_240 crossvit_tiny_240 --------- Co-authored-by: Vinayak Dev <104419489+vinayakdsci@users.noreply.github.com>pull/3872/merge
parent
30c519369e
commit
1201babb9f
|
@ -36,21 +36,24 @@ namespace {
|
|||
// we provide the original operand through storeResult, which will be modified
|
||||
// if the result will be passed onto another operation, and will be used for
|
||||
// noop_with_empty_axes handling before that.
|
||||
LogicalResult reducedSumImpl(OpBinder binder,
|
||||
ConversionPatternRewriter &rewriter, Value data,
|
||||
Torch::ValueTensorType resultType,
|
||||
Value &storeResult, int64_t keepDims,
|
||||
int64_t noop_with_empty_axes,
|
||||
bool isIntermediateOp) {
|
||||
template <typename AtenReductionTypeOp>
|
||||
LogicalResult reduceOpImpl(OpBinder binder, ConversionPatternRewriter &rewriter,
|
||||
Value data, Torch::ValueTensorType resultType,
|
||||
Value &storeResult, int64_t keepDims,
|
||||
int64_t noop_with_empty_axes,
|
||||
bool isIntermediateOp) {
|
||||
|
||||
auto inputType = dyn_cast<Torch::ValueTensorType>(data.getType());
|
||||
if (!inputType)
|
||||
return failure();
|
||||
SmallVector<Value> axesList;
|
||||
Value axesVal;
|
||||
if (!binder.tensorOperandAtIndex(axesVal, 1)) {
|
||||
auto inputType = dyn_cast<Torch::ValueTensorType>(data.getType());
|
||||
if (!inputType.hasSizes() || !resultType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "unimplemented: expected input and result to have shapes");
|
||||
}
|
||||
auto axesTy = dyn_cast<Torch::ValueTensorType>(axesVal.getType());
|
||||
if (!axesTy || !axesTy.areAllSizesKnown() || axesTy.getSizes().size() > 1)
|
||||
return failure();
|
||||
auto axesShape = axesTy.getSizes();
|
||||
uint64_t numAxes = (axesShape.empty()) ? 1 : axesShape.front();
|
||||
|
||||
if (inputType.areAllSizesKnown() && resultType.areAllSizesKnown()) {
|
||||
SmallVector<int64_t> inputShape{inputType.getSizes()};
|
||||
|
@ -77,22 +80,25 @@ LogicalResult reducedSumImpl(OpBinder binder,
|
|||
} else {
|
||||
reduceDims.push_back(i);
|
||||
if (resultShapeCounter < resultShape.size() &&
|
||||
resultShape[resultShapeCounter] == 1)
|
||||
resultShape[resultShapeCounter] == 1 && keepDims == 1)
|
||||
resultShapeCounter++;
|
||||
}
|
||||
}
|
||||
for (auto i : reduceDims) {
|
||||
axesList.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
|
||||
}
|
||||
if (reduceDims.size() == numAxes) {
|
||||
for (auto i : reduceDims) {
|
||||
axesList.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
|
||||
}
|
||||
} else
|
||||
binder.op->emitWarning(
|
||||
"Number of inferred reduce dims, " +
|
||||
std::to_string(reduceDims.size()) +
|
||||
", does not match the provided number of axes, " +
|
||||
std::to_string(numAxes) + ".");
|
||||
}
|
||||
}
|
||||
if (axesList.empty()) {
|
||||
Torch::BaseTensorType axesType =
|
||||
cast<Torch::BaseTensorType>(axesVal.getType());
|
||||
auto axesTy = dyn_cast<Torch::ValueTensorType>(axesVal.getType());
|
||||
auto axesShape = axesTy.getSizes();
|
||||
if (axesShape.size() != 1 || axesShape[0] == Torch::kUnknownSize)
|
||||
if (axesTy.getSizes()[0] == Torch::kUnknownSize)
|
||||
return failure();
|
||||
|
||||
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
||||
|
@ -100,9 +106,8 @@ LogicalResult reducedSumImpl(OpBinder binder,
|
|||
rewriter.getI64IntegerAttr(0));
|
||||
SmallVector<int64_t> selectSizes{1};
|
||||
auto selType = rewriter.getType<Torch::ValueTensorType>(
|
||||
selectSizes, axesType.getOptionalDtype());
|
||||
int64_t numAxes = axesShape[0];
|
||||
for (int64_t i = 0; i < numAxes; ++i) {
|
||||
selectSizes, axesTy.getOptionalDtype());
|
||||
for (uint64_t i = 0; i < numAxes; ++i) {
|
||||
Value iv = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getI64IntegerAttr(i));
|
||||
|
@ -117,38 +122,60 @@ LogicalResult reducedSumImpl(OpBinder binder,
|
|||
|
||||
SmallVector<int64_t> axesInts;
|
||||
if (!binder.s64IntegerArrayAttr(axesInts, "axes", {})) {
|
||||
for (int64_t i = 0, s = axesInts.size(); i < s; ++i) {
|
||||
Value iv = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getI64IntegerAttr(axesInts[i]));
|
||||
axesList.push_back(iv);
|
||||
for (int64_t i : axesInts) {
|
||||
axesList.push_back(
|
||||
rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), i));
|
||||
}
|
||||
}
|
||||
|
||||
// Do not include absolute value in the noop
|
||||
if (axesList.empty() && noop_with_empty_axes) {
|
||||
rewriter.replaceOp(binder.op, storeResult);
|
||||
if (axesList.empty() && noop_with_empty_axes == 1) {
|
||||
if (!isIntermediateOp)
|
||||
rewriter.replaceOp(binder.op, data);
|
||||
else
|
||||
storeResult = data;
|
||||
return success();
|
||||
}
|
||||
|
||||
// if the axes list is still empty, reduce everything.
|
||||
if (axesList.empty()) {
|
||||
if (keepDims == 0 && !resultType.getSizes().empty())
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op,
|
||||
"no axes provided & no keepdim: expected result to be rank zero.");
|
||||
if (keepDims == 1 &&
|
||||
(resultType.getSizes().size() != inputType.getSizes().size() ||
|
||||
llvm::any_of(resultType.getSizes(),
|
||||
[](int64_t size) { return size != 1; })))
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "no axes provided & keepdim: expected result to have all "
|
||||
"dimensions equal to 1.");
|
||||
for (uint64_t i = 0; i < inputType.getSizes().size(); i++) {
|
||||
axesList.push_back(
|
||||
rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), i));
|
||||
}
|
||||
}
|
||||
|
||||
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
||||
axesList);
|
||||
Value keepDimBool =
|
||||
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), keepDims);
|
||||
Value dType = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
||||
// If we are using the ReducedSum as an intermediate op to be passed into
|
||||
// If we are using the reduction op as an intermediate op to be passed into
|
||||
// another operation, we might not want to replace the Op. So we create a new
|
||||
// Op and store the result in a variable.
|
||||
SmallVector<Value> operands = {data, dimValueList, keepDimBool};
|
||||
if (llvm::is_one_of<AtenReductionTypeOp, Torch::AtenSumDimIntListOp,
|
||||
Torch::AtenMeanDimOp>())
|
||||
operands.push_back(
|
||||
/*dtype=*/rewriter.create<Torch::ConstantNoneOp>(binder.getLoc()));
|
||||
if (!isIntermediateOp) {
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenSumDimIntListOp>(
|
||||
binder.op, resultType, data, dimValueList, keepDimBool,
|
||||
/*dtype=*/dType);
|
||||
rewriter.replaceOpWithNewOp<AtenReductionTypeOp>(binder.op, resultType,
|
||||
operands);
|
||||
} else {
|
||||
storeResult = rewriter.create<Torch::AtenSumDimIntListOp>(
|
||||
binder.getLoc(), resultType, data, dimValueList, keepDimBool,
|
||||
/*dtype=*/dType);
|
||||
storeResult = rewriter.create<AtenReductionTypeOp>(binder.getLoc(),
|
||||
resultType, operands);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
@ -1039,25 +1066,25 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
binder.op, resultType, operand, vAlpha, vScale, vInputScale);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp("ReduceL1", 1,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
int64_t keepDims, noop_with_empty_axes;
|
||||
Value operand;
|
||||
if (binder.tensorOperandAtIndex(operand, 0) ||
|
||||
binder.tensorResultType(resultType) ||
|
||||
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
|
||||
binder.s64IntegerAttr(noop_with_empty_axes,
|
||||
"noop_with_empty_axes", 0))
|
||||
return failure();
|
||||
patterns.onOp(
|
||||
"ReduceL1", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
int64_t keepDims, noop_with_empty_axes;
|
||||
Value operand;
|
||||
if (binder.tensorOperandAtIndex(operand, 0) ||
|
||||
binder.tensorResultType(resultType) ||
|
||||
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
|
||||
binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
|
||||
0))
|
||||
return failure();
|
||||
|
||||
Value data = rewriter.create<Torch::AtenAbsOp>(
|
||||
binder.getLoc(), operand.getType(), operand);
|
||||
Value data = rewriter.create<Torch::AtenAbsOp>(
|
||||
binder.getLoc(), operand.getType(), operand);
|
||||
|
||||
return reducedSumImpl(binder, rewriter, data, resultType,
|
||||
/*storeValue=*/operand, keepDims,
|
||||
noop_with_empty_axes, false);
|
||||
});
|
||||
return reduceOpImpl<Torch::AtenSumDimIntListOp>(
|
||||
binder, rewriter, data, resultType,
|
||||
/*storeValue=*/operand, keepDims, noop_with_empty_axes, false);
|
||||
});
|
||||
patterns.onOp(
|
||||
"ReduceL2", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
|
@ -1075,9 +1102,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
Value squareOfOperand = rewriter.create<Torch::AtenMulTensorOp>(
|
||||
binder.getLoc(), operand.getType(), operand, operand);
|
||||
|
||||
auto reducedSum =
|
||||
reducedSumImpl(binder, rewriter, squareOfOperand, resultType,
|
||||
operand, keepDims, noop_with_empty_axes, true);
|
||||
auto reducedSum = reduceOpImpl<Torch::AtenSumDimIntListOp>(
|
||||
binder, rewriter, squareOfOperand, resultType, operand, keepDims,
|
||||
noop_with_empty_axes, true);
|
||||
if (failed(reducedSum))
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op,
|
||||
|
@ -1112,32 +1139,32 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
/*memory_format=*/noneVal);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp("ReduceLogSum", 1,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
Value data;
|
||||
int64_t keepDims, noop_with_empty_axes;
|
||||
if (binder.tensorOperandAtIndex(data, 0) ||
|
||||
binder.tensorResultType(resultType) ||
|
||||
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
|
||||
binder.s64IntegerAttr(noop_with_empty_axes,
|
||||
"noop_with_empty_axes", 0))
|
||||
return failure();
|
||||
patterns.onOp(
|
||||
"ReduceLogSum", 1,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
Value data;
|
||||
int64_t keepDims, noop_with_empty_axes;
|
||||
if (binder.tensorOperandAtIndex(data, 0) ||
|
||||
binder.tensorResultType(resultType) ||
|
||||
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
|
||||
binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
|
||||
0))
|
||||
return failure();
|
||||
|
||||
auto reducedSumBool =
|
||||
reducedSumImpl(binder, rewriter, data, resultType,
|
||||
/*storeValue=*/data, keepDims,
|
||||
noop_with_empty_axes, true);
|
||||
auto reducedSumBool = reduceOpImpl<Torch::AtenSumDimIntListOp>(
|
||||
binder, rewriter, data, resultType,
|
||||
/*storeValue=*/data, keepDims, noop_with_empty_axes, true);
|
||||
|
||||
if (failed(reducedSumBool))
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op,
|
||||
"Failed to perform sum operation on square of operand");
|
||||
if (failed(reducedSumBool))
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op,
|
||||
"Failed to perform sum operation on square of operand");
|
||||
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenLogOp>(
|
||||
binder.op, resultType, data);
|
||||
return success();
|
||||
});
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenLogOp>(binder.op, resultType,
|
||||
data);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp(
|
||||
"ReduceLogSumExp", 1,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
|
@ -1169,7 +1196,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
binder.getLoc(), f64ResultType, dataCast);
|
||||
auto f64ReduceType = rewriter.getType<Torch::ValueTensorType>(
|
||||
resultType.getOptionalSizes(), rewriter.getF64Type());
|
||||
auto reducedSumBool = reducedSumImpl(
|
||||
auto reducedSumBool = reduceOpImpl<Torch::AtenSumDimIntListOp>(
|
||||
binder, rewriter, dataExp, f64ReduceType,
|
||||
/*storeValue=*/data, keepDims, noop_with_empty_axes, true);
|
||||
if (failed(reducedSumBool))
|
||||
|
@ -1186,22 +1213,22 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
/*memory_format=*/noneVal);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp("ReduceSum", 1,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
Value data;
|
||||
int64_t keepDims, noop_with_empty_axes;
|
||||
if (binder.tensorOperandAtIndex(data, 0) ||
|
||||
binder.tensorResultType(resultType) ||
|
||||
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
|
||||
binder.s64IntegerAttr(noop_with_empty_axes,
|
||||
"noop_with_empty_axes", 0))
|
||||
return failure();
|
||||
patterns.onOp(
|
||||
"ReduceSum", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
Value data;
|
||||
int64_t keepDims, noop_with_empty_axes;
|
||||
if (binder.tensorOperandAtIndex(data, 0) ||
|
||||
binder.tensorResultType(resultType) ||
|
||||
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
|
||||
binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
|
||||
0))
|
||||
return failure();
|
||||
|
||||
return reducedSumImpl(binder, rewriter, data, resultType,
|
||||
/*storeValue=*/data, keepDims,
|
||||
noop_with_empty_axes, false);
|
||||
});
|
||||
return reduceOpImpl<Torch::AtenSumDimIntListOp>(
|
||||
binder, rewriter, data, resultType,
|
||||
/*storeValue=*/data, keepDims, noop_with_empty_axes, false);
|
||||
});
|
||||
patterns.onOp("ReduceSumSquare", 1,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
|
@ -1217,137 +1244,35 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
Value dataSquare = rewriter.create<Torch::AtenMulTensorOp>(
|
||||
binder.getLoc(), data.getType(), data, data);
|
||||
|
||||
return reducedSumImpl(binder, rewriter, dataSquare,
|
||||
resultType,
|
||||
/*storeValue=*/data, keepDims,
|
||||
noop_with_empty_axes, false);
|
||||
return reduceOpImpl<Torch::AtenSumDimIntListOp>(
|
||||
binder, rewriter, dataSquare, resultType,
|
||||
/*storeValue=*/data, keepDims, noop_with_empty_axes,
|
||||
false);
|
||||
});
|
||||
patterns.onOp(
|
||||
"ReduceMean", 1,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
Value data;
|
||||
int64_t keepDims, noop_with_empty_axes;
|
||||
if (binder.tensorOperandAtIndex(data, 0) ||
|
||||
binder.tensorResultType(resultType) ||
|
||||
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
|
||||
binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
|
||||
0))
|
||||
return failure();
|
||||
patterns.onOp("ReduceMean", 1,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
Value data;
|
||||
int64_t keepDims, noop_with_empty_axes;
|
||||
if (binder.tensorOperandAtIndex(data, 0) ||
|
||||
binder.tensorResultType(resultType) ||
|
||||
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
|
||||
binder.s64IntegerAttr(noop_with_empty_axes,
|
||||
"noop_with_empty_axes", 0))
|
||||
return failure();
|
||||
|
||||
SmallVector<Value> axesList;
|
||||
|
||||
Value axesVal;
|
||||
if (!binder.tensorOperandAtIndex(axesVal, 1)) {
|
||||
auto inputType = dyn_cast<Torch::ValueTensorType>(data.getType());
|
||||
if (!inputType.hasSizes() || !resultType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op,
|
||||
"unimplemented: expected input and result to have shapes");
|
||||
}
|
||||
|
||||
// If the input shape and result shape is statically known then the
|
||||
// list of dims to be squeezed can be derived from those shapes. As a
|
||||
// result, we don't have to wait for the dim values to be known at
|
||||
// runtime which is also expected by the downstream pipeline.
|
||||
if (inputType.areAllSizesKnown() && resultType.areAllSizesKnown()) {
|
||||
SmallVector<int64_t> inputShape{inputType.getSizes()};
|
||||
SmallVector<int64_t> resultShape{resultType.getSizes()};
|
||||
if (llvm::equal(inputShape, resultShape)) {
|
||||
// Case: none of the dimension is reduced.
|
||||
rewriter.replaceOp(binder.op, data);
|
||||
return success();
|
||||
}
|
||||
if (areAllElementsDistinct(inputShape)) {
|
||||
// The check for the input shape elements to be distinct is added
|
||||
// for the cases like:
|
||||
// Input: [3, 2, 2] -> Output: [3, 2]
|
||||
// For the above case, from the input and output shape it can't be
|
||||
// inferred whether the dim:1 is reduced or dim:2. To avoid these
|
||||
// type of cases, the check has been placed.
|
||||
SmallVector<int64_t> reduceDims;
|
||||
unsigned resultShapeCounter = 0;
|
||||
for (unsigned i = 0; i < inputShape.size(); i++) {
|
||||
if (resultShapeCounter < resultShape.size() &&
|
||||
inputShape[i] == resultShape[resultShapeCounter]) {
|
||||
resultShapeCounter++;
|
||||
} else {
|
||||
reduceDims.push_back(i);
|
||||
if (resultShapeCounter < resultShape.size() &&
|
||||
resultShape[resultShapeCounter] == 1)
|
||||
resultShapeCounter++;
|
||||
}
|
||||
}
|
||||
for (auto i : reduceDims) {
|
||||
axesList.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (axesList.empty()) {
|
||||
Torch::BaseTensorType axesType =
|
||||
cast<Torch::BaseTensorType>(axesVal.getType());
|
||||
auto axesTy = dyn_cast<Torch::ValueTensorType>(axesVal.getType());
|
||||
auto axesShape = axesTy.getSizes();
|
||||
if (axesShape.size() != 1 || axesShape[0] == Torch::kUnknownSize)
|
||||
return failure();
|
||||
|
||||
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getI64IntegerAttr(0));
|
||||
SmallVector<int64_t> selectSizes{1};
|
||||
auto selType = rewriter.getType<Torch::ValueTensorType>(
|
||||
selectSizes, axesType.getOptionalDtype());
|
||||
int64_t numAxes = axesShape[0];
|
||||
for (int64_t i = 0; i < numAxes; ++i) {
|
||||
Value iv = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getI64IntegerAttr(i));
|
||||
Value extract = rewriter.create<Torch::AtenSelectIntOp>(
|
||||
binder.getLoc(), selType, axesVal, zero, iv);
|
||||
Value dim = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
|
||||
axesList.push_back(dim);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<int64_t> axesInts;
|
||||
if (!binder.s64IntegerArrayAttr(axesInts, "axes", {})) {
|
||||
for (int64_t i = 0, s = axesInts.size(); i < s; ++i) {
|
||||
Value iv = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getI64IntegerAttr(axesInts[i]));
|
||||
axesList.push_back(iv);
|
||||
}
|
||||
}
|
||||
|
||||
// deal with case when axes is empty
|
||||
if (axesList.empty() && noop_with_empty_axes) {
|
||||
rewriter.replaceOp(binder.op, data);
|
||||
return success();
|
||||
}
|
||||
|
||||
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
||||
axesList);
|
||||
Value keepDimBool =
|
||||
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), keepDims);
|
||||
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenMeanDimOp>(
|
||||
binder.op, resultType, data, dimValueList, keepDimBool,
|
||||
/*dtype=*/noneVal);
|
||||
return success();
|
||||
});
|
||||
Value reduceSum = data;
|
||||
return reduceOpImpl<Torch::AtenMeanDimOp>(
|
||||
binder, rewriter, data, resultType,
|
||||
/*storeValue=*/reduceSum, keepDims, noop_with_empty_axes,
|
||||
false);
|
||||
});
|
||||
patterns.onOp(
|
||||
"ReduceMax", 13,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
// AtenAmaxOp allows us to pass a list of dims
|
||||
Torch::ValueTensorType resultType;
|
||||
Value data;
|
||||
Value axes;
|
||||
int64_t keepDims;
|
||||
int64_t noop_with_empty_axes;
|
||||
if (binder.tensorOperandAtIndex(data, 0) ||
|
||||
|
@ -1412,87 +1337,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
return success();
|
||||
}
|
||||
|
||||
// Previous version of the operation had the axes as an attribute:
|
||||
SmallVector<Value> axesList;
|
||||
llvm::SmallVector<int64_t> axesAttr;
|
||||
if (!binder.s64IntegerArrayAttr(axesAttr, "axes", {})) {
|
||||
for (int i = 0, s = axesAttr.size(); i < s; ++i) {
|
||||
axesList.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), torchIntTy,
|
||||
rewriter.getI64IntegerAttr(axesAttr[i])));
|
||||
}
|
||||
}
|
||||
|
||||
// Extract the axes values from the axes operand:
|
||||
if (!binder.tensorOperandAtIndex(axes, 1)) {
|
||||
Torch::BaseTensorType axesType =
|
||||
cast<Torch::BaseTensorType>(axes.getType());
|
||||
SmallVector<int64_t> selectSizes{1};
|
||||
Type selectResultType = axesType.getWithSizesAndDtype(
|
||||
selectSizes, axesType.getOptionalDtype());
|
||||
auto sizes = axesType.getSizes();
|
||||
|
||||
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
|
||||
|
||||
// Extract the value of each axes:
|
||||
for (int i = 0; i < sizes[0]; i++) {
|
||||
// Go through the axes list and get each dim in the list
|
||||
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
|
||||
Value extract = rewriter.create<Torch::AtenSelectIntOp>(
|
||||
binder.getLoc(), selectResultType, axes, zero, selectIndex);
|
||||
Value dim = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
|
||||
axesList.push_back(dim);
|
||||
}
|
||||
}
|
||||
|
||||
// Handle the noop case:
|
||||
if (axesList.empty() && noop_with_empty_axes) {
|
||||
rewriter.replaceOp(binder.op, data);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Deal with case when no axes arg is passed but not a noop:
|
||||
if (axesList.empty()) {
|
||||
int64_t numDims = dyn_cast<Torch::ValueTensorType>(data.getType())
|
||||
.getSizes()
|
||||
.size();
|
||||
for (int i = 0; i < numDims; i++) {
|
||||
Value curr = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
|
||||
axesList.push_back(curr);
|
||||
}
|
||||
}
|
||||
|
||||
// Handle negative axis:
|
||||
Value rankVal = rewriter.create<Torch::AtenDimOp>(binder.getLoc(),
|
||||
torchIntTy, data);
|
||||
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getI64IntegerAttr(0));
|
||||
for (Value &axes : axesList) {
|
||||
Value isNegative =
|
||||
rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), axes, zero);
|
||||
isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
|
||||
isNegative);
|
||||
Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
|
||||
binder.getLoc(), isNegative, rankVal);
|
||||
axes = rewriter.create<Torch::AtenAddIntOp>(binder.getLoc(), axes,
|
||||
finalOffset);
|
||||
}
|
||||
|
||||
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(), Torch::ListType::get(torchIntTy), axesList);
|
||||
Value keepDimBool =
|
||||
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), keepDims);
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenAmaxOp>(
|
||||
binder.op, resultType, data, dimValueList, keepDimBool);
|
||||
return success();
|
||||
return reduceOpImpl<Torch::AtenAmaxOp>(
|
||||
binder, rewriter, data, resultType,
|
||||
/*storeValue=*/data, keepDims, noop_with_empty_axes, false);
|
||||
});
|
||||
|
||||
patterns.onOp(
|
||||
|
@ -1501,7 +1348,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
// AtenAminOp allows us to pass a list of dims
|
||||
Torch::ValueTensorType resultType;
|
||||
Value data;
|
||||
Value axes;
|
||||
int64_t keepDims;
|
||||
int64_t noop_with_empty_axes;
|
||||
if (binder.tensorOperandAtIndex(data, 0) ||
|
||||
|
@ -1565,87 +1411,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
return success();
|
||||
}
|
||||
|
||||
// Previous version of the operation had the axes as an attribute:
|
||||
SmallVector<Value> axesList;
|
||||
llvm::SmallVector<int64_t> axesAttr;
|
||||
if (!binder.s64IntegerArrayAttr(axesAttr, "axes", {})) {
|
||||
for (int i = 0, s = axesAttr.size(); i < s; ++i) {
|
||||
axesList.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), torchIntTy,
|
||||
rewriter.getI64IntegerAttr(axesAttr[i])));
|
||||
}
|
||||
}
|
||||
|
||||
// Extract the axes values from the axes operand:
|
||||
if (!binder.tensorOperandAtIndex(axes, 1)) {
|
||||
Torch::BaseTensorType axesType =
|
||||
cast<Torch::BaseTensorType>(axes.getType());
|
||||
SmallVector<int64_t> selectSizes{1};
|
||||
Type selectResultType = axesType.getWithSizesAndDtype(
|
||||
selectSizes, axesType.getOptionalDtype());
|
||||
auto sizes = axesType.getSizes();
|
||||
|
||||
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
|
||||
|
||||
// Extract the value of each axes:
|
||||
for (int i = 0; i < sizes[0]; i++) {
|
||||
// Go through the axes list and get each dim in the list
|
||||
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
|
||||
Value extract = rewriter.create<Torch::AtenSelectIntOp>(
|
||||
binder.getLoc(), selectResultType, axes, zero, selectIndex);
|
||||
Value dim = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
|
||||
axesList.push_back(dim);
|
||||
}
|
||||
}
|
||||
|
||||
// Handle the noop case:
|
||||
if (axesList.empty() && noop_with_empty_axes) {
|
||||
rewriter.replaceOp(binder.op, data);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Deal with case when no axes arg is passed but not a noop:
|
||||
if (axesList.empty()) {
|
||||
int64_t numDims = dyn_cast<Torch::ValueTensorType>(data.getType())
|
||||
.getSizes()
|
||||
.size();
|
||||
for (int i = 0; i < numDims; i++) {
|
||||
Value curr = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
|
||||
axesList.push_back(curr);
|
||||
}
|
||||
}
|
||||
|
||||
// Handle negative axis:
|
||||
Value rankVal = rewriter.create<Torch::AtenDimOp>(binder.getLoc(),
|
||||
torchIntTy, data);
|
||||
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getI64IntegerAttr(0));
|
||||
for (Value &axes : axesList) {
|
||||
Value isNegative =
|
||||
rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), axes, zero);
|
||||
isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
|
||||
isNegative);
|
||||
Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
|
||||
binder.getLoc(), isNegative, rankVal);
|
||||
axes = rewriter.create<Torch::AtenAddIntOp>(binder.getLoc(), axes,
|
||||
finalOffset);
|
||||
}
|
||||
|
||||
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(), Torch::ListType::get(torchIntTy), axesList);
|
||||
Value keepDimBool =
|
||||
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), keepDims);
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenAminOp>(
|
||||
binder.op, resultType, data, dimValueList, keepDimBool);
|
||||
return success();
|
||||
return reduceOpImpl<Torch::AtenAminOp>(
|
||||
binder, rewriter, data, resultType,
|
||||
/*storeValue=*/data, keepDims, noop_with_empty_axes, false);
|
||||
});
|
||||
|
||||
patterns.onOp(
|
||||
|
|
|
@ -707,17 +707,8 @@ func.func @test_reduce_max_empty_set_int(%arg0: !torch.vtensor<[2,0,4],si32>, %a
|
|||
|
||||
// CHECK-LABEL: func.func @test_reduce_max_bool_inputs
|
||||
func.func @test_reduce_max_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[IDX:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[SZ:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]]
|
||||
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]]
|
||||
// CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int
|
||||
// CHECK: %[[C0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int
|
||||
// CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[I1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
|
||||
// CHECK: %[[AMAX:.+]] = torch.aten.amax %arg0, %[[LST]], %[[TRUE]] : !torch.vtensor<[4,2],i1>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,1],i1>
|
||||
// CHECK: return %[[AMAX]] : !torch.vtensor<[4,1],i1>
|
||||
|
@ -729,17 +720,8 @@ func.func @test_reduce_max_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !
|
|||
|
||||
// CHECK-LABEL: func.func @test_reduce_max_bool_inputs_nokeepdims
|
||||
func.func @test_reduce_max_bool_inputs_nokeepdims(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[IDX:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[SZ:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]]
|
||||
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]]
|
||||
// CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int
|
||||
// CHECK: %[[C0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int
|
||||
// CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[I1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
|
||||
// CHECK: %[[AMAX:.+]] = torch.aten.amax %arg0, %[[LST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4],i1>
|
||||
// CHECK: return %[[AMAX]] : !torch.vtensor<[4],i1>
|
||||
|
@ -751,19 +733,9 @@ func.func @test_reduce_max_bool_inputs_nokeepdims(%arg0: !torch.vtensor<[4,2],i1
|
|||
|
||||
// CHECK-LABEL: func.func @test_reduce_max_all_dims_default
|
||||
func.func @test_reduce_max_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[I0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[I1:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[RANK:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int
|
||||
// CHECK: %[[C0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I0]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int
|
||||
// CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[A0:.+]] = torch.aten.add.int %[[I0]], %[[MUL]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I1]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int
|
||||
// CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[A1:.+]] = torch.aten.add.int %[[I1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[A0]], %[[A1]]
|
||||
// CHECK: %[[I0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[I1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[I0]], %[[I1]]
|
||||
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
|
||||
// CHECK: %[[MAX:.+]] = torch.aten.amax %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list<int>, !torch.bool -> !torch.vtensor<[],i1>
|
||||
// CHECK: return %[[MAX]] : !torch.vtensor<[],i1>
|
||||
|
@ -775,13 +747,7 @@ func.func @test_reduce_max_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) ->
|
|||
|
||||
func.func @test_reduce_max_attr(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int
|
||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[LT:.+]] = torch.aten.lt.int %[[INT1]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int
|
||||
// CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[ADD:.+]] = torch.aten.add.int %[[INT1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
|
||||
// CHECK: %[[AMAX:.+]] = torch.aten.amax %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4],i1>
|
||||
// CHECK: return %[[AMAX]]
|
||||
|
@ -793,9 +759,12 @@ func.func @test_reduce_max_attr(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtens
|
|||
|
||||
// CHECK-LABEL: func.func @test_reduce_l1_default_axes_keepdims_example
|
||||
func.func @test_reduce_l1_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[ABS:.+]] = torch.aten.abs %arg0 : !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
|
||||
// CHECK: %[[ABS:.*]] = torch.aten.abs %arg0
|
||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
// CHECK-DAG: %[[INT0_0:.+]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2
|
||||
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
|
||||
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[ABS]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
|
||||
|
@ -845,8 +814,11 @@ func.func @test_reduce_l1_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2,2],f
|
|||
// CHECK-LABEL: func.func @test_reduce_l2_default_axes_keepdims_example
|
||||
func.func @test_reduce_l2_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
|
||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[INT2:.+]] = torch.constant.int 2
|
||||
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[TRUE_0:.+]] = torch.constant.bool true
|
||||
// CHECK: %[[NONE_0:.+]] = torch.constant.none
|
||||
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE_0]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
|
||||
|
@ -944,7 +916,10 @@ func.func @test_reduce_l2_keep_dims_int_input_example(%arg0: !torch.vtensor<[3,2
|
|||
// CHECK-LABEL: func.func @test_reduce_log_sum_default_axes_keepdims_example
|
||||
func.func @test_reduce_log_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[INT2:.+]] = torch.constant.int 2
|
||||
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
|
||||
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
|
||||
|
@ -1000,7 +975,10 @@ func.func @test_reduce_log_sum_exp_default_axes_keepdims_example(%arg0: !torch.v
|
|||
// CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f64>
|
||||
// CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f64> -> !torch.vtensor<[3,2,2],f64>
|
||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[INT2:.+]] = torch.constant.int 2
|
||||
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
|
||||
// CHECK: %[[NONE_1:.+]] = torch.constant.none
|
||||
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIMS]], %[[TRUE]], %[[NONE_1]] : !torch.vtensor<[3,2,2],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f64>
|
||||
|
@ -1092,7 +1070,10 @@ func.func @test_reduce_log_sum_exp_keep_dims_int_input_example(%arg0: !torch.vte
|
|||
// CHECK-LABEL: func.func @test_reduce_sum_default_axes_keepdims_example
|
||||
func.func @test_reduce_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[INT2:.+]] = torch.constant.int 2
|
||||
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
|
||||
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||
// CHECK: torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
|
||||
|
@ -1177,7 +1158,10 @@ func.func @test_reduce_sum_negative_axes_keepdims_example(%arg0: !torch.vtensor<
|
|||
func.func @test_reduce_sum_square_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32>
|
||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[INT2:.+]] = torch.constant.int 2
|
||||
// CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
|
||||
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
|
||||
|
@ -1385,17 +1369,8 @@ func.func @test_reduce_min_empty_set_int(%arg0: !torch.vtensor<[2,0,4],si32>, %a
|
|||
|
||||
// CHECK-LABEL: func.func @test_reduce_min_bool_inputs
|
||||
func.func @test_reduce_min_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[IDX:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[SZ:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]]
|
||||
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]]
|
||||
// CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int
|
||||
// CHECK: %[[C0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int
|
||||
// CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[LST:.+]] = torch.prim.ListConstruct %6 : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[I1:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
|
||||
// CHECK: %[[AMIN:.+]] = torch.aten.amin %arg0, %[[LST]], %[[TRUE]] : !torch.vtensor<[4,2],i1>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,1],i1>
|
||||
// CHECK: return %[[AMIN]] : !torch.vtensor<[4,1],i1>
|
||||
|
@ -1407,17 +1382,8 @@ func.func @test_reduce_min_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !
|
|||
|
||||
// CHECK-LABEL: func.func @test_reduce_min_bool_inputs_nokeepdims
|
||||
func.func @test_reduce_min_bool_inputs_nokeepdims(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[IDX:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[SZ:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]]
|
||||
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]]
|
||||
// CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int
|
||||
// CHECK: %[[C0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int
|
||||
// CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[LST:.+]] = torch.prim.ListConstruct %6 : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[I1:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
|
||||
// CHECK: %[[AMIN:.+]] = torch.aten.amin %arg0, %[[LST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4],i1>
|
||||
// CHECK: return %[[AMIN]] : !torch.vtensor<[4],i1>
|
||||
|
@ -1431,17 +1397,7 @@ func.func @test_reduce_min_bool_inputs_nokeepdims(%arg0: !torch.vtensor<[4,2],i1
|
|||
func.func @test_reduce_min_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[I0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[I1:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[RANK:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int
|
||||
// CHECK: %[[C0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I0]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int
|
||||
// CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[A0:.+]] = torch.aten.add.int %[[I0]], %[[MUL]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I1]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int
|
||||
// CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[A1:.+]] = torch.aten.add.int %[[I1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[A0]], %[[A1]]
|
||||
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[I0]], %[[I1]]
|
||||
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
|
||||
// CHECK: %[[MIN:.+]] = torch.aten.amin %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list<int>, !torch.bool -> !torch.vtensor<[],i1>
|
||||
// CHECK: return %[[MIN]] : !torch.vtensor<[],i1>
|
||||
|
@ -1453,13 +1409,7 @@ func.func @test_reduce_min_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) ->
|
|||
|
||||
func.func @test_reduce_min_attr(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int
|
||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[LT:.+]] = torch.aten.lt.int %[[INT1]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int
|
||||
// CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[ADD:.+]] = torch.aten.add.int %[[INT1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
|
||||
// CHECK: %[[AMIN:.+]] = torch.aten.amin %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4],i1>
|
||||
// CHECK: return %[[AMIN]]
|
||||
|
|
Loading…
Reference in New Issue