From 58abec5c0a2da5ebcf5ec9d0dcb070004d2f3a2a Mon Sep 17 00:00:00 2001
From: Ramiro Leal-Cavazos
Date: Mon, 28 Feb 2022 11:01:23 -0800
Subject: [PATCH] Add `reduction` support to `torch.nll_loss_forward` (#624)

This commit does a couple of things. First, it fixes a bug in the
`linalg.generic` body of the `nll_loss_forward` lowering where the
`ignoreIndex` was being compared with the loop index rather than the
current element of the `target` tensor. This was not being caught by
the tests because they were not testing the case where `ignoreIndex`
actually corresponds to a value in `target`. This has been fixed.

Second, this commit adds support for the `reduction` argument in
`torch.nll_loss_forward`, as well as support for 1-D inputs. In order
to simplify the lowering code, I've refactored the code that creates
the `linalg.generic` ops for elementwise and reduction ops into static
functions, to avoid having boilerplate code for indexing maps, etc.,
that can be very error-prone.

Note: The function `convertScalarToDtype` was moved to before all the
conversion patterns, but nothing in it was modified.
---
 e2e_testing/torchscript/nll_loss.py          |  78 ++-
 .../TorchToLinalg/TorchToLinalg.cpp          | 619 +++++++++---------
 2 files changed, 397 insertions(+), 300 deletions(-)

diff --git a/e2e_testing/torchscript/nll_loss.py b/e2e_testing/torchscript/nll_loss.py
index 587ef6eb2..4558fdb68 100644
--- a/e2e_testing/torchscript/nll_loss.py
+++ b/e2e_testing/torchscript/nll_loss.py
@@ -34,7 +34,79 @@ class NllLossModule(torch.nn.Module):
 
 @register_test_case(module_factory=lambda: NllLossModule())
 def NllLossModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(2, 3), torch.tensor([0, 1]))
+    module.forward(tu.rand(2, 3), torch.randint(0, 3, (2,)))
+
+
+class NllLossModule_mean(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+    ])
+    # Here the 2nd index is ignored.
+    def forward(self, x, y):
+        return torch.ops.aten.nll_loss_forward(x,
+                                                target=y,
+                                                weight=None,
+                                                reduction=1,
+                                                ignore_index=2)[0]
+
+
+@register_test_case(module_factory=lambda: NllLossModule_mean())
+def NllLossModule_mean_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3), torch.randint(0, 3, (2,)))
+
+
+class NllLossModule_sum(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+    ])
+    # Here the 2nd index is ignored.
+    def forward(self, x, y):
+        return torch.ops.aten.nll_loss_forward(x,
+                                                target=y,
+                                                weight=None,
+                                                reduction=2,
+                                                ignore_index=2)[0]
+
+
+@register_test_case(module_factory=lambda: NllLossModule_sum())
+def NllLossModule_sum_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3), torch.randint(0, 3, (2,)))
+
+
+class NllLossModule_1D(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([], torch.int64, True),
+    ])
+    # Here the 2nd index is ignored.
+ def forward(self, x, y): + return torch.ops.aten.nll_loss_forward(x, + target=y, + weight=None, + reduction=0, + ignore_index=2)[0] + + +@register_test_case(module_factory=lambda: NllLossModule_1D()) +def NllLossModule_1D_basic(module, tu: TestUtils): + module.forward(tu.rand(3), torch.randint(0, 3, ())) class NllLossModule_ignore_index_out_of_bounds(torch.nn.Module): @@ -58,8 +130,8 @@ class NllLossModule_ignore_index_out_of_bounds(torch.nn.Module): @register_test_case(module_factory=lambda: NllLossModule_ignore_index_out_of_bounds()) -def NllLossModule_ignore_index(module, tu: TestUtils): - module.forward(tu.rand(2, 3), torch.tensor([0, 1])) +def NllLossModule_ignore_index_out_of_bounds_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3), torch.randint(0, 3, (2,))) class NllLossModule_backward(torch.nn.Module): diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp index 16deeb955..207f191cd 100644 --- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp +++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp @@ -181,6 +181,14 @@ static SmallVector getTensorSizes(OpBuilder &b, Location loc, return getTensorSizesUntilDim(b, loc, tensor, type.getRank() - 1); } +static Value getTensorSize(OpBuilder &b, Location loc, Value tensor) { + SmallVector sizes(getTensorSizes(b, loc, tensor)); + Value productResult = b.create(loc, b.getIndexAttr(1)); + for (Value size : sizes) + productResult = b.create(loc, productResult, size); + return castIndexToInt(b, loc, productResult); +} + static Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes, Type elemTy) { Value initTensor = b.create(loc, sizes, elemTy); @@ -333,6 +341,232 @@ static Value buildUnitNormalCdf(OpBuilder &b, Location &loc, Value x) { return buildNormalCdf(b, loc, x, zero, one); } +// Convert a scalar value to the target type. The scalar value can be an element +// from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype +// should be converted builtin types. +static Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, + Type dtype) { + Type scalarType = scalar.getType(); + if (scalarType == dtype) + return scalar; + + // TODO: For the byte(ui8) or char(i8) case, we need the unconverted dtype to + // be able to know if we need signed or unsigned conversion. + auto isByteOrChar = [](Type type) { + if (auto integerTy = type.dyn_cast()) { + return integerTy.getWidth() == 8; + } + return false; + }; + + if (isByteOrChar(scalarType) || isByteOrChar(dtype) || + dtype.isSignlessInteger(1)) { + // TODO: Handle to-boolean conversion(from-boolean conversion is handled). + mlir::emitError(loc) + << "unsupported byte, char or bool type for convertScalarToDtype " + << scalarType << "(scalar type) -> " << dtype << "(dtype)"; + return nullptr; + } + + if (auto dtypeFloat = dtype.dyn_cast()) { + if (auto scalarFloat = scalarType.dyn_cast()) { + if (scalarFloat.getWidth() > dtypeFloat.getWidth()) + return b.create(loc, dtype, scalar); + // Only scalarFloat width < dtypeFloat width can reach here. + return b.create(loc, dtype, scalar); + } + assert(scalarType.isa()); + if (scalarType.isSignlessInteger(1)) + return b.create(loc, dtype, scalar); + // It's safe to use SIToFPOp because ui8/si8 are the only ones where + // unsigned handling is needed, and we checked for that case above. 
+ return b.create(loc, dtype, scalar); + } + + if (auto dtypeInteger = dtype.dyn_cast()) { + if (auto scalarFloat = scalarType.dyn_cast()) + return b.create(loc, dtype, scalar); + assert(scalarType.isa()); + auto scalarInteger = scalarType.cast(); + if (scalarInteger.getWidth() > dtypeInteger.getWidth()) + return b.create(loc, dtype, scalar); + if (scalarType.isSignlessInteger(1)) + return b.create(loc, dtype, scalar); + // Only scalarInteger width < dtypeInteger width can reach here. + // It's safe to use ExtSIOp here because ui8/si8 are the only ones where + // unsigned handling is needed, and we checked for that case above. + return b.create(loc, dtype, scalar); + } + + llvm_unreachable("convertScalarToDtype should handle all the types"); +} + +// Create a reduction of `tensorOperand`, reducing along the dimensions +// in `dimSet`. If `keepDim` is true, the output tensor is the same +// rank as the `tensorOperand` and reduced dimensions are set to size 1. +// `initElem` is the element used to initialize the output tensor +// where the reduction will be stored. +static Value createReductionLinalgGeneric( + OpBuilder &b, Location loc, Value tensorOperand, + const DenseSet &dimSet, bool keepDim, Value initElem, + function_ref bodyBuild) { + auto inputType = tensorOperand.getType().cast(); + + // Get the result shape by obtaining the size of each + // dimension in the input tensor that is not getting reduced. + // If `keepDim` is true, the rank of the output tensor + // is kept the same as the rank of the input tensor, and the + // reduced dimensions are set to have size 1. + auto c1 = b.create(loc, /*value=*/1); + SmallVector resultShape; + for (int64_t i = 0; i < inputType.getRank(); i++) { + auto currentDimSize = b.create(loc, tensorOperand, i); + if (!dimSet.contains(i)) + resultShape.push_back(currentDimSize); + else if (keepDim) + resultShape.push_back(c1); + } + + // Create the affine expressions that will be used to + // iterate over the input and output tensors. + // Here we also set the type of iterator: parallel or reduction. + SmallVector exprs; + SmallVector iteratorTypes; + SmallVector resultExprs; + for (auto size : llvm::enumerate(inputType.getShape())) { + exprs.push_back(b.getAffineDimExpr(size.index())); + + if (dimSet.contains(size.index())) { + iteratorTypes.push_back(getReductionIteratorTypeName()); + // If `keepDim`, create affine map to the first element + // in the current dimension. + if (keepDim) + resultExprs.push_back(b.getAffineConstantExpr(0)); + } else { + iteratorTypes.push_back(getParallelIteratorTypeName()); + resultExprs.push_back(b.getAffineDimExpr(size.index())); + } + } + + auto indexingMaps = AffineMap::inferFromExprList({exprs, resultExprs}); + Value accumulator = + createInitTensor(b, loc, resultShape, initElem.getType(), initElem); + + return b + .create( + loc, /*resultTensorTypes=*/accumulator.getType(), + /*inputs=*/tensorOperand, + /*outputs=*/accumulator, indexingMaps, iteratorTypes, bodyBuild) + .getResult(0); +} + +static Value createElementwiseLinalgGeneric( + OpBuilder &b, Location loc, ValueRange tensorOperands, + Type resultElementType, + function_ref bodyBuild) { + // The overall error handling strategy here is best viewed by thinking about + // what happens for a single result dimension. This loop not structured that + // way because it is hard to create the affine maps for each operand unless + // we structure the loop to iterate over tensor operands as the outer loop + // instead of inner loop. 
This pseudocode gives better intuition: + // ``` + // for each result dimension: + // for each tensor operand: + // if it doesn't even have high enough rank relative to the result: + // continue + // if it is a static size-1 along this result dimension: + // continue + // if this is the first tensor operand that didn't continue above: + // take its dimension size as the size of the non-broadcasted + // traversal along this dimension (this may include a dynamic size-1, + // **non-broadcasted** traversal!) + // emit error check "if the size does not match the non-broadcasted + // traversal size along this dimension, error" + // ``` + SmallVector operandRanks; + operandRanks.resize(tensorOperands.size()); + llvm::transform(tensorOperands, operandRanks.begin(), [](Value tensor) { + return tensor.getType().dyn_cast().getRank(); + }); + + auto resultRankIt = + std::max_element(operandRanks.begin(), operandRanks.end()); + assert(resultRankIt != operandRanks.end() && "Unable to get result rank."); + int64_t resultRank = *resultRankIt; + + // Initialize the resultShape to all 1's, as a fallback in case + // all sizes along that result dimension are statically 1. + auto c1 = b.create(loc, /*value=*/1); + SmallVector resultShape(resultRank, c1); + SmallVector indexingMaps; + for (Value tensorOperand : tensorOperands) { + SmallVector exprs; + auto type = tensorOperand.getType().cast(); + for (auto size : llvm::enumerate(type.getShape())) { + // If the size is statically known to be 1, we don't want any + // error guards to be spuriously emitted, since we are specifically + // allowing size-1 broadcasts in this case, as they correspond to a + // constant-0 indexing map. + if (size.value() == 1) { + exprs.push_back(b.getAffineConstantExpr(0)); + continue; + } + + // The rank of this operand might be smaller than the overall rank of + // the broadcast. Add an offset to correlate it to the correct + // dimension of the result. + auto resultDim = size.index() + (resultRank - type.getRank()); + + // The generated linalg op will now be iterating along the full size + // of this dimension. Record that fact. + exprs.push_back(b.getAffineDimExpr(resultDim)); + + // Now, we need to ensure that such iteration is not going to trigger + // undefined behavior, by doing appropriate checks against the current + // dimension size. + auto currentDimSize = getDimOp(b, loc, tensorOperand, size.index()); + + // If the result size of this dimension has so far only hit the + // statically-known-to-be-1 case above (i.e., we have not yet assigned a + // new Value to `resultShape[resultDim]`), then we have no other dynamic + // values to check against, and merely need to record the current + // dimension size. + if (resultShape[resultDim] == c1) { + resultShape[resultDim] = currentDimSize; + continue; + } + + // We prohibit the size-1 dynamic broadcasting scenario, so just check + // for exact equality with the running result size. + // This is the check which protects against the undefined behavior of + // the generated linalg op in the case of iterating two operands with + // dimensions sizes that are expected to match. 
+ auto equalToRunning = + b.create(loc, arith::CmpIPredicate::eq, + resultShape[resultDim], currentDimSize); + b.create(loc, equalToRunning, + "mismatched size for broadcast"); + } + indexingMaps.push_back(AffineMap::get( + /*dimCount=*/resultRank, /*symbolCount=*/0, exprs, b.getContext())); + } + + SmallVector iteratorTypes(resultRank, + getParallelIteratorTypeName()); + // Add the indexing map for the outs init tensor. + indexingMaps.push_back(b.getMultiDimIdentityMap(resultRank)); + + Value initTensor = b.create( + loc, getAsOpFoldResult(resultShape), resultElementType); + return b + .create(loc, + /*resultTensorTypes=*/initTensor.getType(), + /*inputs=*/tensorOperands, + /*outputs=*/initTensor, indexingMaps, + iteratorTypes, bodyBuild) + .getResult(0); +} + namespace { class ConvertAtenAdaptiveAvgPool2dOp : public OpConversionPattern { @@ -1237,7 +1471,7 @@ public: // for i in range(0, len(target)): // indi = target[i]; // nll_loss_forward[i] = -(input[i][indi]); -// TODO: `weight` and `reduction` operands are still to be taken care of. +// TODO: `weight`operand is still to be taken care of. namespace { class ConvertAtenNllLossForwardOp : public OpConversionPattern { @@ -1253,15 +1487,10 @@ public: Value target = adaptor.target(); Value weight = adaptor.weight(); - int64_t reduce_dim; - if (!matchPattern(op.reduction(), m_TorchConstantInt(&reduce_dim))) + int64_t reduction; + if (!matchPattern(op.reduction(), m_TorchConstantInt(&reduction))) return rewriter.notifyMatchFailure(op, "dim must be constant"); - // TODO: Handle reduction. - if (reduce_dim != 0) - return rewriter.notifyMatchFailure( - op, "reduction along dimensions is not supported."); - // TODO: Incorporate the weight argument. if (!weight.getType().isa()) return rewriter.notifyMatchFailure( @@ -1273,52 +1502,67 @@ public: unsigned inputRank = input.getType().cast().getRank(); unsigned targetRank = target.getType().cast().getRank(); - // TODO: Cases with targetRank != 1 where `Mean` reduction is required. - if (inputRank != 2 || targetRank != 1) { + // TODO: Add support for k-dim loss. + if (inputRank > 2) { return rewriter.notifyMatchFailure( - op, "expected input and target to be rank 2 and 1 respectively"); + op, "expected input and target to be rank <= 2"); } RankedTensorType resultType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); - Type elementType = resultType.getElementType(); - Value targetDim = getDimOp(rewriter, loc, target, 0); - Value initTensor0 = - createZeroInitTensor(rewriter, loc, {targetDim}, elementType); Value zeroVal = rewriter.create( loc, rewriter.getZeroAttr(elementType)); - SmallVector targetExpr; - targetExpr.push_back(rewriter.getAffineDimExpr(0)); - SmallVector iteratorTypes{getParallelIteratorTypeName()}; - auto indexingMaps = AffineMap::inferFromExprList({targetExpr, targetExpr}); - Value finalRes = - rewriter - .create( - loc, initTensor0.getType(), ValueRange{target}, initTensor0, - /*indexingMaps=*/indexingMaps, - /*iteratorTypes=*/iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value indTarget = rewriter.create( - loc, rewriter.getIndexType(), args[0]); - Value indI = rewriter.create(loc, 0); + Value finalRes = createElementwiseLinalgGeneric( + rewriter, loc, {target}, elementType, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value targetVal = args[0]; + Value indTarget = rewriter.create( + loc, rewriter.getIndexType(), targetVal); - // The final result is given by: - // final_res = (indI == ignoreIndexVal) ? 
0 : - // input[indI][IndTarget] - Value cmpEq = rewriter.create( - loc, arith::CmpIPredicate::eq, indI, ignoreIndexVal); - Value result = rewriter.create( - loc, input, ValueRange{indI, indTarget}); - Value negate = - rewriter.create(loc, elementType, result); - Value selectFinal = rewriter.create( - loc, cmpEq, zeroVal, negate); - b.create(loc, selectFinal); - }) - .getResult(0); + // The final result is given by: + // final_res = (indTarget == ignoreIndexVal) ? 0 : + // input[indI][IndTarget] + Value cmpEq = rewriter.create( + loc, arith::CmpIPredicate::eq, indTarget, ignoreIndexVal); + + SmallVector extractionIndices{indTarget}; + if (inputRank == 2) { + Value indI = rewriter.create(loc, 0); + extractionIndices.insert(extractionIndices.begin(), indI); + } + + Value result = + rewriter.create(loc, input, extractionIndices); + + Value negate = + rewriter.create(loc, elementType, result); + Value selectFinal = + rewriter.create(loc, cmpEq, zeroVal, negate); + b.create(loc, selectFinal); + }); + + if (reduction == Reduction::Sum || reduction == Reduction::Mean) { + Value numOfElems = getTensorSize(rewriter, loc, finalRes); + numOfElems = convertScalarToDtype(rewriter, loc, numOfElems, elementType); + llvm::iota_range dimsToReduce(0, targetRank, + /*inclusive=*/false); + DenseSet dimSet(dimsToReduce.begin(), dimsToReduce.end()); + + finalRes = createReductionLinalgGeneric( + rewriter, loc, finalRes, dimSet, /*keepDim=*/false, + /*initElem=*/zeroVal, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value newVal = args[0]; + Value accumulator = args[1]; + if (reduction == Reduction::Mean) + newVal = b.create(loc, newVal, numOfElems); + Value result = b.create(loc, newVal, accumulator); + b.create(loc, result); + }); + } // TODO: Update the second result tensor. Value weightUpdated = @@ -1588,66 +1832,6 @@ public: }; } // namespace -// Convert a scalar value to the target type. The scalar value can be an element -// from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype -// should be converted builtin types. -static Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, - Type dtype) { - Type scalarType = scalar.getType(); - if (scalarType == dtype) - return scalar; - - // TODO: For the byte(ui8) or char(i8) case, we need the unconverted dtype to - // be able to know if we need signed or unsigned conversion. - auto isByteOrChar = [](Type type) { - if (auto integerTy = type.dyn_cast()) { - return integerTy.getWidth() == 8; - } - return false; - }; - - if (isByteOrChar(scalarType) || isByteOrChar(dtype) || - dtype.isSignlessInteger(1)) { - // TODO: Handle to-boolean conversion(from-boolean conversion is handled). - mlir::emitError(loc) - << "unsupported byte, char or bool type for convertScalarToDtype " - << scalarType << "(scalar type) -> " << dtype << "(dtype)"; - return nullptr; - } - - if (auto dtypeFloat = dtype.dyn_cast()) { - if (auto scalarFloat = scalarType.dyn_cast()) { - if (scalarFloat.getWidth() > dtypeFloat.getWidth()) - return b.create(loc, dtype, scalar); - // Only scalarFloat width < dtypeFloat width can reach here. - return b.create(loc, dtype, scalar); - } - assert(scalarType.isa()); - if (scalarType.isSignlessInteger(1)) - return b.create(loc, dtype, scalar); - // It's safe to use SIToFPOp because ui8/si8 are the only ones where - // unsigned handling is needed, and we checked for that case above. 
- return b.create(loc, dtype, scalar); - } - - if (auto dtypeInteger = dtype.dyn_cast()) { - if (auto scalarFloat = scalarType.dyn_cast()) - return b.create(loc, dtype, scalar); - assert(scalarType.isa()); - auto scalarInteger = scalarType.cast(); - if (scalarInteger.getWidth() > dtypeInteger.getWidth()) - return b.create(loc, dtype, scalar); - if (scalarType.isSignlessInteger(1)) - return b.create(loc, dtype, scalar); - // Only scalarInteger width < dtypeInteger width can reach here. - // It's safe to use ExtSIOp here because ui8/si8 are the only ones where - // unsigned handling is needed, and we checked for that case above. - return b.create(loc, dtype, scalar); - } - - llvm_unreachable("convertScalarToDtype should handle all the types"); -} - static Value createLinalgPayloadCalculationForElementwiseOp( OpBuilder &b, Location loc, TypeConverter *converter, ValueRange payloadArgs, Operation *op, ArrayRef operands) { @@ -2546,99 +2730,9 @@ struct ConvertElementwiseOp : ConversionPattern { auto resultType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); - auto resultRank = resultType.getRank(); - - auto c1 = rewriter.create(loc, /*value=*/1); - // The overall error handling strategy here is best viewed by thinking about - // what happens for a single result dimension. This loop not structured that - // way because it is hard to create the affine maps for each operand unless - // we structure the loop to iterate over tensor operands as the outer loop - // instead of inner loop. This pseudocode gives better intuition: - // ``` - // for each result dimension: - // for each tensor operand: - // if it doesn't even have high enough rank relative to the result: - // continue - // if it is a static size-1 along this result dimension: - // continue - // if this is the first tensor operand that didn't continue above: - // take its dimension size as the size of the non-broadcasted - // traversal along this dimension (this may include a dynamic size-1, - // **non-broadcasted** traversal!) - // emit error check "if the size does not match the non-broadcasted - // traversal size along this dimension, error" - // ``` - // Initialize the resultShape to all 1's, as a fallback in case - // all sizes along that result dimension are statically 1. - SmallVector resultShape(resultRank, c1); - SmallVector indexingMaps; - for (Value tensorOperand : tensorOperands) { - SmallVector exprs; - auto type = tensorOperand.getType().cast(); - for (auto size : llvm::enumerate(type.getShape())) { - // If the size is statically known to be 1, we don't want any - // error guards to be spuriously emitted, since we are specifically - // allowing size-1 broadcasts in this case, as they correspond to a - // constant-0 indexing map. - if (size.value() == 1) { - exprs.push_back(rewriter.getAffineConstantExpr(0)); - continue; - } - - // The rank of this operand might be smaller than the overall rank of - // the broadcast. Add an offset to correlate it to the correct - // dimension of the result. - auto resultDim = size.index() + (resultRank - type.getRank()); - - // The generated linalg op will now be iterating along the full size - // of this dimension. Record that fact. - exprs.push_back(rewriter.getAffineDimExpr(resultDim)); - - // Now, we need to ensure that such iteration is not going to trigger - // undefined behavior, by doing appropriate checks against the current - // dimension size. 
- auto currentDimSize = - getDimOp(rewriter, loc, tensorOperand, size.index()); - - // If the result size of this dimension has so far only hit the - // statically-known-to-be-1 case above (i.e., we have not yet assigned a - // new Value to `resultShape[resultDim]`), then we have no other dynamic - // values to check against, and merely need to record the current - // dimension size. - if (resultShape[resultDim] == c1) { - resultShape[resultDim] = currentDimSize; - continue; - } - - // We prohibit the size-1 dynamic broadcasting scenario, so just check - // for exact equality with the running result size. - // This is the check which protects against the undefined behavior of - // the generated linalg op in the case of iterating two operands with - // dimensions sizes that are expected to match. - auto equalToRunning = rewriter.create( - loc, arith::CmpIPredicate::eq, resultShape[resultDim], - currentDimSize); - rewriter.create(loc, equalToRunning, - "mismatched size for broadcast"); - } - indexingMaps.push_back(AffineMap::get( - /*dimCount=*/resultRank, /*symbolCount=*/0, exprs, getContext())); - } - - SmallVector iteratorTypes(resultRank, - getParallelIteratorTypeName()); - // Add the indexing map for the outs init tensor. - indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank)); - - Value initTensor = rewriter.create( - loc, getAsOpFoldResult(resultShape), resultType.getElementType()); bool hadErrorCreatingPayload = false; - auto generic = rewriter.create( - loc, /*resultTensorTypes=*/initTensor.getType(), - /*inputs=*/tensorOperands, - /*outputs=*/initTensor, - /*indexingMaps=*/indexingMaps, - /*iteratorTypes=*/iteratorTypes, + Value generic = createElementwiseLinalgGeneric( + rewriter, loc, tensorOperands, resultType.getElementType(), [&](OpBuilder &b, Location loc, ValueRange payloadArgs) { Value result = createLinalgPayloadCalculationForElementwiseOp( b, loc, getTypeConverter(), payloadArgs, op, operands); @@ -2650,8 +2744,7 @@ struct ConvertElementwiseOp : ConversionPattern { }); if (hadErrorCreatingPayload) return failure(); - rewriter.replaceOpWithNewOp(op, resultType, - generic.getResult(0)); + rewriter.replaceOpWithNewOp(op, resultType, generic); return success(); } }; @@ -2662,104 +2755,19 @@ struct ConvertReductionOp : ConversionPattern { ConvertReductionOp(TypeConverter &typeConverter, MLIRContext *context) : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1, context) {} - - // This function is in charge of all the rewriting that will take - // place in `matchAndRewrite`. In particular, it converts - // the reduce operation into an `linalg.generic` operation - // to reduce the input tensor along the dimensions specified in - // `dimeSet`. - LogicalResult - createReductionLinalgGeneric(Operation *op, ArrayRef operands, - const DenseSet &dimSet, bool keepDim, - ConversionPatternRewriter &rewriter) const { - Location loc = op->getLoc(); - auto tensorOperand = operands[0]; - auto inputType = tensorOperand.getType().cast(); - auto resultType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); - - // Get the result shape by obtaining the size of each - // dimension in the input tensor that is not getting reduced. - // If `keepDim` is true, the rank of the output tensor - // is kept the same as the rank of the input tensor, and the - // reduced dimensions are set to have size 1. 
- auto c1 = rewriter.create(loc, /*value=*/1); - SmallVector resultShape; - for (int64_t i = 0; i < inputType.getRank(); i++) { - auto currentDimSize = - rewriter.create(loc, tensorOperand, i); - if (!dimSet.contains(i)) - resultShape.push_back(currentDimSize); - else if (keepDim) - resultShape.push_back(c1); - } - - // Create the affine expressions that will be used to - // iterate over the input and output tensors. - // Here we also set the type of iterator: parallel or reduction. - SmallVector exprs; - SmallVector iteratorTypes; - SmallVector resultExprs; - for (auto size : llvm::enumerate(inputType.getShape())) { - exprs.push_back(rewriter.getAffineDimExpr(size.index())); - - if (dimSet.contains(size.index())) { - iteratorTypes.push_back(getReductionIteratorTypeName()); - // If `keepDim`, create affine map to the first element - // in the current dimension. - if (keepDim) - resultExprs.push_back(rewriter.getAffineConstantExpr(0)); - } else { - iteratorTypes.push_back(getParallelIteratorTypeName()); - resultExprs.push_back(rewriter.getAffineDimExpr(size.index())); - } - } - - auto indexingMaps = AffineMap::inferFromExprList({exprs, resultExprs}); - Value initTensor = rewriter.create( - loc, resultShape, resultType.getElementType()); - Value initValue = createLinalgNeutralElementForReduceOp( - rewriter, loc, op, resultType.getElementType()); - Value accumulator = - rewriter.create(loc, initValue, initTensor) - .getResult(0); - bool hadErrorCreatingPayload = false; - auto generic = rewriter.create( - loc, /*resultTensorTypes=*/accumulator.getType(), - /*inputs=*/tensorOperand, - /*outputs=*/accumulator, - /*indexingMaps=*/indexingMaps, - /*iteratorTypes=*/iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange payloadArgs) { - Value result = createLinalgPayloadCalculationForReduceOp( - b, loc, payloadArgs, op, operands, resultType.getElementType()); - if (!result) { - hadErrorCreatingPayload = true; - return; - } - b.create(loc, result); - }); - - if (hadErrorCreatingPayload) - return failure(); - rewriter.replaceOpWithNewOp(op, resultType, - generic.getResult(0)); - return success(); - } - LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); - // Every reduce operation must set a value for the `dimSet` and - // `keepDim` in accordance with their specification. + // Every reduce operation must set a value for the `dimSet`, + // `tensorOperand`, and `keepDim` in accordance with their specification. 
DenseSet dimSet; + Value tensorOperand; bool keepDim = false; if (isa(op) || isa(op)) { - auto tensorOperand = operands[0]; + tensorOperand = operands[0]; auto inputType = tensorOperand.getType().cast(); // `AtenSumOp` and `AtenMaxOp` reduces along all the dimensions of the @@ -2767,7 +2775,7 @@ struct ConvertReductionOp : ConversionPattern { for (int64_t i = 0; i < inputType.getRank(); i++) dimSet.insert(i); } else if (auto sumDimIntListOp = dyn_cast(op)) { - auto tensorOperand = operands[0]; + tensorOperand = operands[0]; auto inputType = tensorOperand.getType().cast(); if (!matchPattern(sumDimIntListOp.keepdim(), @@ -2788,8 +2796,31 @@ struct ConvertReductionOp : ConversionPattern { } else { return rewriter.notifyMatchFailure(op, "not a supported reduce op"); } - return createReductionLinalgGeneric(op, operands, dimSet, keepDim, - rewriter); + + Location loc = op->getLoc(); + auto resultType = getTypeConverter() + ->convertType(op->getResult(0).getType()) + .cast(); + Value initElem = createLinalgNeutralElementForReduceOp( + rewriter, loc, op, resultType.getElementType()); + + bool hadErrorCreatingPayload = false; + Value generic = createReductionLinalgGeneric( + rewriter, loc, tensorOperand, dimSet, keepDim, initElem, + [&](OpBuilder &b, Location loc, ValueRange payloadArgs) { + Value result = createLinalgPayloadCalculationForReduceOp( + b, loc, payloadArgs, op, operands, resultType.getElementType()); + if (!result) { + hadErrorCreatingPayload = true; + return; + } + b.create(loc, result); + }); + + if (hadErrorCreatingPayload) + return failure(); + rewriter.replaceOpWithNewOp(op, resultType, generic); + return success(); } }; } // namespace @@ -4251,14 +4282,8 @@ public: if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op.getLoc(); - Value self = adaptor.self(); - SmallVector sizes(getTensorSizes(rewriter, loc, self)); - Value productResult = - rewriter.create(loc, rewriter.getIndexAttr(1)); - for (size_t i = 0; i < sizes.size(); i++) - productResult = - rewriter.create(loc, productResult, sizes[i]); - rewriter.replaceOp(op, castIndexToInt(rewriter, loc, productResult)); + Value tensorSize = getTensorSize(rewriter, loc, adaptor.self()); + rewriter.replaceOp(op, tensorSize); return success(); } };
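
For readers unfamiliar with the `reduction` encoding used by the new e2e tests, the snippet below is a minimal illustrative sketch (not part of the patch) of the op semantics this lowering now covers, assuming a stock PyTorch install. `reduction` follows the standard PyTorch convention of 0 = none, 1 = mean, 2 = sum, and entries whose target equals `ignore_index` contribute a zero loss, which is the behavior the bug fix above exercises.

# Illustrative only -- not part of the patch.
import torch

torch.manual_seed(0)
x = torch.rand(2, 3)           # (N, C) input, as in the 2-D tests
y = torch.randint(0, 3, (2,))  # (N,) class indices

# aten::nll_loss_forward returns (output, total_weight); the tests keep [0].
per_elem = torch.ops.aten.nll_loss_forward(
    x, target=y, weight=None, reduction=0, ignore_index=2)[0]
mean_loss = torch.ops.aten.nll_loss_forward(
    x, target=y, weight=None, reduction=1, ignore_index=2)[0]
sum_loss = torch.ops.aten.nll_loss_forward(
    x, target=y, weight=None, reduction=2, ignore_index=2)[0]

print(per_elem)   # shape (N,); entries with target == ignore_index are 0
print(mean_loss)  # 0-d scalar
print(sum_loss)   # 0-d scalar

The mean and sum paths correspond to the `Reduction::Mean` and `Reduction::Sum` branches added to `ConvertAtenNllLossForwardOp` above, which reuse the new `createReductionLinalgGeneric` helper.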