Add `reduction` support to `torch.nll_loss_forward` (#624)

This commit does a couple of things. First, it fixes a bug in the
`linalg.generic` body of the `nll_loss_forward` lowering where the
`ignoreIndex` was being compared with the loop index rather than the
current element of the `target` tensor. This was not being caught by
the tests because they were not testing the case where `ingnoreIndex`
actually corresponds to a value in `target`. This has been fixed.

Second, this commit adds support for the `reduction` argument in
`torch.nll_loss_forward` as well as support for 1-D inputs. In order
to simplify the lowering code, I've refactored the code that creates
the `linalg.generic` ops for elementwise and reduction ops into static
functions, to avoid having boilerplate code for indexing maps, etc
that can be very error prone.

Note: The function `convertScalarToDtype` was moved to before all the
conversion patterns, but nothing in it was modified.
pull/628/head
Ramiro Leal-Cavazos 2022-02-28 11:01:23 -08:00 committed by GitHub
parent 9b2613533b
commit 58abec5c0a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 397 additions and 300 deletions

View File

@ -34,7 +34,79 @@ class NllLossModule(torch.nn.Module):
@register_test_case(module_factory=lambda: NllLossModule())
def NllLossModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3), torch.tensor([0, 1]))
module.forward(tu.rand(2, 3), torch.randint(0, 3, (2,)))
class NllLossModule_mean(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
([-1], torch.int64, True),
])
# Here the 2nd index is ignored.
def forward(self, x, y):
return torch.ops.aten.nll_loss_forward(x,
target=y,
weight=None,
reduction=1,
ignore_index=2)[0]
@register_test_case(module_factory=lambda: NllLossModule_mean())
def NllLossModule_mean_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3), torch.randint(0, 3, (2,)))
class NllLossModule_sum(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
([-1], torch.int64, True),
])
# Here the 2nd index is ignored.
def forward(self, x, y):
return torch.ops.aten.nll_loss_forward(x,
target=y,
weight=None,
reduction=2,
ignore_index=2)[0]
@register_test_case(module_factory=lambda: NllLossModule_sum())
def NllLossModule_sum_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3), torch.randint(0, 3, (2,)))
class NllLossModule_1D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1], torch.float32, True),
([], torch.int64, True),
])
# Here the 2nd index is ignored.
def forward(self, x, y):
return torch.ops.aten.nll_loss_forward(x,
target=y,
weight=None,
reduction=0,
ignore_index=2)[0]
@register_test_case(module_factory=lambda: NllLossModule_1D())
def NllLossModule_1D_basic(module, tu: TestUtils):
module.forward(tu.rand(3), torch.randint(0, 3, ()))
class NllLossModule_ignore_index_out_of_bounds(torch.nn.Module):
@ -58,8 +130,8 @@ class NllLossModule_ignore_index_out_of_bounds(torch.nn.Module):
@register_test_case(module_factory=lambda: NllLossModule_ignore_index_out_of_bounds())
def NllLossModule_ignore_index(module, tu: TestUtils):
module.forward(tu.rand(2, 3), torch.tensor([0, 1]))
def NllLossModule_ignore_index_out_of_bounds_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3), torch.randint(0, 3, (2,)))
class NllLossModule_backward(torch.nn.Module):

View File

@ -181,6 +181,14 @@ static SmallVector<Value> getTensorSizes(OpBuilder &b, Location loc,
return getTensorSizesUntilDim(b, loc, tensor, type.getRank() - 1);
}
static Value getTensorSize(OpBuilder &b, Location loc, Value tensor) {
SmallVector<Value> sizes(getTensorSizes(b, loc, tensor));
Value productResult = b.create<arith::ConstantOp>(loc, b.getIndexAttr(1));
for (Value size : sizes)
productResult = b.create<arith::MulIOp>(loc, productResult, size);
return castIndexToInt(b, loc, productResult);
}
static Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
Type elemTy) {
Value initTensor = b.create<linalg::InitTensorOp>(loc, sizes, elemTy);
@ -333,6 +341,232 @@ static Value buildUnitNormalCdf(OpBuilder &b, Location &loc, Value x) {
return buildNormalCdf(b, loc, x, zero, one);
}
// Convert a scalar value to the target type. The scalar value can be an element
// from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype
// should be converted builtin types.
static Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
Type dtype) {
Type scalarType = scalar.getType();
if (scalarType == dtype)
return scalar;
// TODO: For the byte(ui8) or char(i8) case, we need the unconverted dtype to
// be able to know if we need signed or unsigned conversion.
auto isByteOrChar = [](Type type) {
if (auto integerTy = type.dyn_cast<mlir::IntegerType>()) {
return integerTy.getWidth() == 8;
}
return false;
};
if (isByteOrChar(scalarType) || isByteOrChar(dtype) ||
dtype.isSignlessInteger(1)) {
// TODO: Handle to-boolean conversion(from-boolean conversion is handled).
mlir::emitError(loc)
<< "unsupported byte, char or bool type for convertScalarToDtype "
<< scalarType << "(scalar type) -> " << dtype << "(dtype)";
return nullptr;
}
if (auto dtypeFloat = dtype.dyn_cast<mlir::FloatType>()) {
if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>()) {
if (scalarFloat.getWidth() > dtypeFloat.getWidth())
return b.create<arith::TruncFOp>(loc, dtype, scalar);
// Only scalarFloat width < dtypeFloat width can reach here.
return b.create<arith::ExtFOp>(loc, dtype, scalar);
}
assert(scalarType.isa<mlir::IntegerType>());
if (scalarType.isSignlessInteger(1))
return b.create<arith::UIToFPOp>(loc, dtype, scalar);
// It's safe to use SIToFPOp because ui8/si8 are the only ones where
// unsigned handling is needed, and we checked for that case above.
return b.create<arith::SIToFPOp>(loc, dtype, scalar);
}
if (auto dtypeInteger = dtype.dyn_cast<mlir::IntegerType>()) {
if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>())
return b.create<arith::FPToSIOp>(loc, dtype, scalar);
assert(scalarType.isa<mlir::IntegerType>());
auto scalarInteger = scalarType.cast<mlir::IntegerType>();
if (scalarInteger.getWidth() > dtypeInteger.getWidth())
return b.create<arith::TruncIOp>(loc, dtype, scalar);
if (scalarType.isSignlessInteger(1))
return b.create<arith::ExtUIOp>(loc, dtype, scalar);
// Only scalarInteger width < dtypeInteger width can reach here.
// It's safe to use ExtSIOp here because ui8/si8 are the only ones where
// unsigned handling is needed, and we checked for that case above.
return b.create<arith::ExtSIOp>(loc, dtype, scalar);
}
llvm_unreachable("convertScalarToDtype should handle all the types");
}
// Create a reduction of `tensorOperand`, reducing along the dimensions
// in `dimSet`. If `keepDim` is true, the output tensor is the same
// rank as the `tensorOperand` and reduced dimensions are set to size 1.
// `initElem` is the element used to initialize the output tensor
// where the reduction will be stored.
static Value createReductionLinalgGeneric(
OpBuilder &b, Location loc, Value tensorOperand,
const DenseSet<int64_t> &dimSet, bool keepDim, Value initElem,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
auto inputType = tensorOperand.getType().cast<RankedTensorType>();
// Get the result shape by obtaining the size of each
// dimension in the input tensor that is not getting reduced.
// If `keepDim` is true, the rank of the output tensor
// is kept the same as the rank of the input tensor, and the
// reduced dimensions are set to have size 1.
auto c1 = b.create<arith::ConstantIndexOp>(loc, /*value=*/1);
SmallVector<Value> resultShape;
for (int64_t i = 0; i < inputType.getRank(); i++) {
auto currentDimSize = b.create<tensor::DimOp>(loc, tensorOperand, i);
if (!dimSet.contains(i))
resultShape.push_back(currentDimSize);
else if (keepDim)
resultShape.push_back(c1);
}
// Create the affine expressions that will be used to
// iterate over the input and output tensors.
// Here we also set the type of iterator: parallel or reduction.
SmallVector<AffineExpr> exprs;
SmallVector<StringRef> iteratorTypes;
SmallVector<AffineExpr> resultExprs;
for (auto size : llvm::enumerate(inputType.getShape())) {
exprs.push_back(b.getAffineDimExpr(size.index()));
if (dimSet.contains(size.index())) {
iteratorTypes.push_back(getReductionIteratorTypeName());
// If `keepDim`, create affine map to the first element
// in the current dimension.
if (keepDim)
resultExprs.push_back(b.getAffineConstantExpr(0));
} else {
iteratorTypes.push_back(getParallelIteratorTypeName());
resultExprs.push_back(b.getAffineDimExpr(size.index()));
}
}
auto indexingMaps = AffineMap::inferFromExprList({exprs, resultExprs});
Value accumulator =
createInitTensor(b, loc, resultShape, initElem.getType(), initElem);
return b
.create<linalg::GenericOp>(
loc, /*resultTensorTypes=*/accumulator.getType(),
/*inputs=*/tensorOperand,
/*outputs=*/accumulator, indexingMaps, iteratorTypes, bodyBuild)
.getResult(0);
}
static Value createElementwiseLinalgGeneric(
OpBuilder &b, Location loc, ValueRange tensorOperands,
Type resultElementType,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
// The overall error handling strategy here is best viewed by thinking about
// what happens for a single result dimension. This loop not structured that
// way because it is hard to create the affine maps for each operand unless
// we structure the loop to iterate over tensor operands as the outer loop
// instead of inner loop. This pseudocode gives better intuition:
// ```
// for each result dimension:
// for each tensor operand:
// if it doesn't even have high enough rank relative to the result:
// continue
// if it is a static size-1 along this result dimension:
// continue
// if this is the first tensor operand that didn't continue above:
// take its dimension size as the size of the non-broadcasted
// traversal along this dimension (this may include a dynamic size-1,
// **non-broadcasted** traversal!)
// emit error check "if the size does not match the non-broadcasted
// traversal size along this dimension, error"
// ```
SmallVector<int64_t> operandRanks;
operandRanks.resize(tensorOperands.size());
llvm::transform(tensorOperands, operandRanks.begin(), [](Value tensor) {
return tensor.getType().dyn_cast<RankedTensorType>().getRank();
});
auto resultRankIt =
std::max_element(operandRanks.begin(), operandRanks.end());
assert(resultRankIt != operandRanks.end() && "Unable to get result rank.");
int64_t resultRank = *resultRankIt;
// Initialize the resultShape to all 1's, as a fallback in case
// all sizes along that result dimension are statically 1.
auto c1 = b.create<arith::ConstantIndexOp>(loc, /*value=*/1);
SmallVector<Value> resultShape(resultRank, c1);
SmallVector<AffineMap> indexingMaps;
for (Value tensorOperand : tensorOperands) {
SmallVector<AffineExpr> exprs;
auto type = tensorOperand.getType().cast<RankedTensorType>();
for (auto size : llvm::enumerate(type.getShape())) {
// If the size is statically known to be 1, we don't want any
// error guards to be spuriously emitted, since we are specifically
// allowing size-1 broadcasts in this case, as they correspond to a
// constant-0 indexing map.
if (size.value() == 1) {
exprs.push_back(b.getAffineConstantExpr(0));
continue;
}
// The rank of this operand might be smaller than the overall rank of
// the broadcast. Add an offset to correlate it to the correct
// dimension of the result.
auto resultDim = size.index() + (resultRank - type.getRank());
// The generated linalg op will now be iterating along the full size
// of this dimension. Record that fact.
exprs.push_back(b.getAffineDimExpr(resultDim));
// Now, we need to ensure that such iteration is not going to trigger
// undefined behavior, by doing appropriate checks against the current
// dimension size.
auto currentDimSize = getDimOp(b, loc, tensorOperand, size.index());
// If the result size of this dimension has so far only hit the
// statically-known-to-be-1 case above (i.e., we have not yet assigned a
// new Value to `resultShape[resultDim]`), then we have no other dynamic
// values to check against, and merely need to record the current
// dimension size.
if (resultShape[resultDim] == c1) {
resultShape[resultDim] = currentDimSize;
continue;
}
// We prohibit the size-1 dynamic broadcasting scenario, so just check
// for exact equality with the running result size.
// This is the check which protects against the undefined behavior of
// the generated linalg op in the case of iterating two operands with
// dimensions sizes that are expected to match.
auto equalToRunning =
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
resultShape[resultDim], currentDimSize);
b.create<cf::AssertOp>(loc, equalToRunning,
"mismatched size for broadcast");
}
indexingMaps.push_back(AffineMap::get(
/*dimCount=*/resultRank, /*symbolCount=*/0, exprs, b.getContext()));
}
SmallVector<StringRef> iteratorTypes(resultRank,
getParallelIteratorTypeName());
// Add the indexing map for the outs init tensor.
indexingMaps.push_back(b.getMultiDimIdentityMap(resultRank));
Value initTensor = b.create<linalg::InitTensorOp>(
loc, getAsOpFoldResult(resultShape), resultElementType);
return b
.create<linalg::GenericOp>(loc,
/*resultTensorTypes=*/initTensor.getType(),
/*inputs=*/tensorOperands,
/*outputs=*/initTensor, indexingMaps,
iteratorTypes, bodyBuild)
.getResult(0);
}
namespace {
class ConvertAtenAdaptiveAvgPool2dOp
: public OpConversionPattern<AtenAdaptiveAvgPool2dOp> {
@ -1237,7 +1471,7 @@ public:
// for i in range(0, len(target)):
// indi = target[i];
// nll_loss_forward[i] = -(input[i][indi]);
// TODO: `weight` and `reduction` operands are still to be taken care of.
// TODO: `weight`operand is still to be taken care of.
namespace {
class ConvertAtenNllLossForwardOp
: public OpConversionPattern<AtenNllLossForwardOp> {
@ -1253,15 +1487,10 @@ public:
Value target = adaptor.target();
Value weight = adaptor.weight();
int64_t reduce_dim;
if (!matchPattern(op.reduction(), m_TorchConstantInt(&reduce_dim)))
int64_t reduction;
if (!matchPattern(op.reduction(), m_TorchConstantInt(&reduction)))
return rewriter.notifyMatchFailure(op, "dim must be constant");
// TODO: Handle reduction.
if (reduce_dim != 0)
return rewriter.notifyMatchFailure(
op, "reduction along dimensions is not supported.");
// TODO: Incorporate the weight argument.
if (!weight.getType().isa<mlir::torch::Torch::NoneType>())
return rewriter.notifyMatchFailure(
@ -1273,52 +1502,67 @@ public:
unsigned inputRank = input.getType().cast<RankedTensorType>().getRank();
unsigned targetRank = target.getType().cast<RankedTensorType>().getRank();
// TODO: Cases with targetRank != 1 where `Mean` reduction is required.
if (inputRank != 2 || targetRank != 1) {
// TODO: Add support for k-dim loss.
if (inputRank > 2) {
return rewriter.notifyMatchFailure(
op, "expected input and target to be rank 2 and 1 respectively");
op, "expected input and target to be rank <= 2");
}
RankedTensorType resultType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
Type elementType = resultType.getElementType();
Value targetDim = getDimOp(rewriter, loc, target, 0);
Value initTensor0 =
createZeroInitTensor(rewriter, loc, {targetDim}, elementType);
Value zeroVal = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(elementType));
SmallVector<AffineExpr> targetExpr;
targetExpr.push_back(rewriter.getAffineDimExpr(0));
SmallVector<StringRef> iteratorTypes{getParallelIteratorTypeName()};
auto indexingMaps = AffineMap::inferFromExprList({targetExpr, targetExpr});
Value finalRes =
rewriter
.create<linalg::GenericOp>(
loc, initTensor0.getType(), ValueRange{target}, initTensor0,
/*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/iteratorTypes,
Value finalRes = createElementwiseLinalgGeneric(
rewriter, loc, {target}, elementType,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value targetVal = args[0];
Value indTarget = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIndexType(), args[0]);
Value indI = rewriter.create<linalg::IndexOp>(loc, 0);
loc, rewriter.getIndexType(), targetVal);
// The final result is given by:
// final_res = (indI == ignoreIndexVal) ? 0 :
// final_res = (indTarget == ignoreIndexVal) ? 0 :
// input[indI][IndTarget]
Value cmpEq = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, indI, ignoreIndexVal);
Value result = rewriter.create<tensor::ExtractOp>(
loc, input, ValueRange{indI, indTarget});
loc, arith::CmpIPredicate::eq, indTarget, ignoreIndexVal);
SmallVector<Value> extractionIndices{indTarget};
if (inputRank == 2) {
Value indI = rewriter.create<linalg::IndexOp>(loc, 0);
extractionIndices.insert(extractionIndices.begin(), indI);
}
Value result =
rewriter.create<tensor::ExtractOp>(loc, input, extractionIndices);
Value negate =
rewriter.create<arith::NegFOp>(loc, elementType, result);
Value selectFinal = rewriter.create<arith::SelectOp>(
loc, cmpEq, zeroVal, negate);
Value selectFinal =
rewriter.create<arith::SelectOp>(loc, cmpEq, zeroVal, negate);
b.create<linalg::YieldOp>(loc, selectFinal);
})
.getResult(0);
});
if (reduction == Reduction::Sum || reduction == Reduction::Mean) {
Value numOfElems = getTensorSize(rewriter, loc, finalRes);
numOfElems = convertScalarToDtype(rewriter, loc, numOfElems, elementType);
llvm::iota_range<int64_t> dimsToReduce(0, targetRank,
/*inclusive=*/false);
DenseSet<int64_t> dimSet(dimsToReduce.begin(), dimsToReduce.end());
finalRes = createReductionLinalgGeneric(
rewriter, loc, finalRes, dimSet, /*keepDim=*/false,
/*initElem=*/zeroVal,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value newVal = args[0];
Value accumulator = args[1];
if (reduction == Reduction::Mean)
newVal = b.create<arith::DivFOp>(loc, newVal, numOfElems);
Value result = b.create<arith::AddFOp>(loc, newVal, accumulator);
b.create<linalg::YieldOp>(loc, result);
});
}
// TODO: Update the second result tensor.
Value weightUpdated =
@ -1588,66 +1832,6 @@ public:
};
} // namespace
// Convert a scalar value to the target type. The scalar value can be an element
// from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype
// should be converted builtin types.
static Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
Type dtype) {
Type scalarType = scalar.getType();
if (scalarType == dtype)
return scalar;
// TODO: For the byte(ui8) or char(i8) case, we need the unconverted dtype to
// be able to know if we need signed or unsigned conversion.
auto isByteOrChar = [](Type type) {
if (auto integerTy = type.dyn_cast<mlir::IntegerType>()) {
return integerTy.getWidth() == 8;
}
return false;
};
if (isByteOrChar(scalarType) || isByteOrChar(dtype) ||
dtype.isSignlessInteger(1)) {
// TODO: Handle to-boolean conversion(from-boolean conversion is handled).
mlir::emitError(loc)
<< "unsupported byte, char or bool type for convertScalarToDtype "
<< scalarType << "(scalar type) -> " << dtype << "(dtype)";
return nullptr;
}
if (auto dtypeFloat = dtype.dyn_cast<mlir::FloatType>()) {
if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>()) {
if (scalarFloat.getWidth() > dtypeFloat.getWidth())
return b.create<arith::TruncFOp>(loc, dtype, scalar);
// Only scalarFloat width < dtypeFloat width can reach here.
return b.create<arith::ExtFOp>(loc, dtype, scalar);
}
assert(scalarType.isa<mlir::IntegerType>());
if (scalarType.isSignlessInteger(1))
return b.create<arith::UIToFPOp>(loc, dtype, scalar);
// It's safe to use SIToFPOp because ui8/si8 are the only ones where
// unsigned handling is needed, and we checked for that case above.
return b.create<arith::SIToFPOp>(loc, dtype, scalar);
}
if (auto dtypeInteger = dtype.dyn_cast<mlir::IntegerType>()) {
if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>())
return b.create<arith::FPToSIOp>(loc, dtype, scalar);
assert(scalarType.isa<mlir::IntegerType>());
auto scalarInteger = scalarType.cast<mlir::IntegerType>();
if (scalarInteger.getWidth() > dtypeInteger.getWidth())
return b.create<arith::TruncIOp>(loc, dtype, scalar);
if (scalarType.isSignlessInteger(1))
return b.create<arith::ExtUIOp>(loc, dtype, scalar);
// Only scalarInteger width < dtypeInteger width can reach here.
// It's safe to use ExtSIOp here because ui8/si8 are the only ones where
// unsigned handling is needed, and we checked for that case above.
return b.create<arith::ExtSIOp>(loc, dtype, scalar);
}
llvm_unreachable("convertScalarToDtype should handle all the types");
}
static Value createLinalgPayloadCalculationForElementwiseOp(
OpBuilder &b, Location loc, TypeConverter *converter,
ValueRange payloadArgs, Operation *op, ArrayRef<Value> operands) {
@ -2546,99 +2730,9 @@ struct ConvertElementwiseOp : ConversionPattern {
auto resultType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
auto resultRank = resultType.getRank();
auto c1 = rewriter.create<arith::ConstantIndexOp>(loc, /*value=*/1);
// The overall error handling strategy here is best viewed by thinking about
// what happens for a single result dimension. This loop not structured that
// way because it is hard to create the affine maps for each operand unless
// we structure the loop to iterate over tensor operands as the outer loop
// instead of inner loop. This pseudocode gives better intuition:
// ```
// for each result dimension:
// for each tensor operand:
// if it doesn't even have high enough rank relative to the result:
// continue
// if it is a static size-1 along this result dimension:
// continue
// if this is the first tensor operand that didn't continue above:
// take its dimension size as the size of the non-broadcasted
// traversal along this dimension (this may include a dynamic size-1,
// **non-broadcasted** traversal!)
// emit error check "if the size does not match the non-broadcasted
// traversal size along this dimension, error"
// ```
// Initialize the resultShape to all 1's, as a fallback in case
// all sizes along that result dimension are statically 1.
SmallVector<Value> resultShape(resultRank, c1);
SmallVector<AffineMap> indexingMaps;
for (Value tensorOperand : tensorOperands) {
SmallVector<AffineExpr> exprs;
auto type = tensorOperand.getType().cast<RankedTensorType>();
for (auto size : llvm::enumerate(type.getShape())) {
// If the size is statically known to be 1, we don't want any
// error guards to be spuriously emitted, since we are specifically
// allowing size-1 broadcasts in this case, as they correspond to a
// constant-0 indexing map.
if (size.value() == 1) {
exprs.push_back(rewriter.getAffineConstantExpr(0));
continue;
}
// The rank of this operand might be smaller than the overall rank of
// the broadcast. Add an offset to correlate it to the correct
// dimension of the result.
auto resultDim = size.index() + (resultRank - type.getRank());
// The generated linalg op will now be iterating along the full size
// of this dimension. Record that fact.
exprs.push_back(rewriter.getAffineDimExpr(resultDim));
// Now, we need to ensure that such iteration is not going to trigger
// undefined behavior, by doing appropriate checks against the current
// dimension size.
auto currentDimSize =
getDimOp(rewriter, loc, tensorOperand, size.index());
// If the result size of this dimension has so far only hit the
// statically-known-to-be-1 case above (i.e., we have not yet assigned a
// new Value to `resultShape[resultDim]`), then we have no other dynamic
// values to check against, and merely need to record the current
// dimension size.
if (resultShape[resultDim] == c1) {
resultShape[resultDim] = currentDimSize;
continue;
}
// We prohibit the size-1 dynamic broadcasting scenario, so just check
// for exact equality with the running result size.
// This is the check which protects against the undefined behavior of
// the generated linalg op in the case of iterating two operands with
// dimensions sizes that are expected to match.
auto equalToRunning = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, resultShape[resultDim],
currentDimSize);
rewriter.create<cf::AssertOp>(loc, equalToRunning,
"mismatched size for broadcast");
}
indexingMaps.push_back(AffineMap::get(
/*dimCount=*/resultRank, /*symbolCount=*/0, exprs, getContext()));
}
SmallVector<StringRef> iteratorTypes(resultRank,
getParallelIteratorTypeName());
// Add the indexing map for the outs init tensor.
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
Value initTensor = rewriter.create<linalg::InitTensorOp>(
loc, getAsOpFoldResult(resultShape), resultType.getElementType());
bool hadErrorCreatingPayload = false;
auto generic = rewriter.create<linalg::GenericOp>(
loc, /*resultTensorTypes=*/initTensor.getType(),
/*inputs=*/tensorOperands,
/*outputs=*/initTensor,
/*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/iteratorTypes,
Value generic = createElementwiseLinalgGeneric(
rewriter, loc, tensorOperands, resultType.getElementType(),
[&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
Value result = createLinalgPayloadCalculationForElementwiseOp(
b, loc, getTypeConverter(), payloadArgs, op, operands);
@ -2650,8 +2744,7 @@ struct ConvertElementwiseOp : ConversionPattern {
});
if (hadErrorCreatingPayload)
return failure();
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
generic.getResult(0));
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, generic);
return success();
}
};
@ -2662,104 +2755,19 @@ struct ConvertReductionOp : ConversionPattern {
ConvertReductionOp(TypeConverter &typeConverter, MLIRContext *context)
: ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
context) {}
// This function is in charge of all the rewriting that will take
// place in `matchAndRewrite`. In particular, it converts
// the reduce operation into an `linalg.generic` operation
// to reduce the input tensor along the dimensions specified in
// `dimeSet`.
LogicalResult
createReductionLinalgGeneric(Operation *op, ArrayRef<Value> operands,
const DenseSet<int64_t> &dimSet, bool keepDim,
ConversionPatternRewriter &rewriter) const {
Location loc = op->getLoc();
auto tensorOperand = operands[0];
auto inputType = tensorOperand.getType().cast<RankedTensorType>();
auto resultType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
// Get the result shape by obtaining the size of each
// dimension in the input tensor that is not getting reduced.
// If `keepDim` is true, the rank of the output tensor
// is kept the same as the rank of the input tensor, and the
// reduced dimensions are set to have size 1.
auto c1 = rewriter.create<arith::ConstantIndexOp>(loc, /*value=*/1);
SmallVector<Value> resultShape;
for (int64_t i = 0; i < inputType.getRank(); i++) {
auto currentDimSize =
rewriter.create<tensor::DimOp>(loc, tensorOperand, i);
if (!dimSet.contains(i))
resultShape.push_back(currentDimSize);
else if (keepDim)
resultShape.push_back(c1);
}
// Create the affine expressions that will be used to
// iterate over the input and output tensors.
// Here we also set the type of iterator: parallel or reduction.
SmallVector<AffineExpr> exprs;
SmallVector<StringRef> iteratorTypes;
SmallVector<AffineExpr> resultExprs;
for (auto size : llvm::enumerate(inputType.getShape())) {
exprs.push_back(rewriter.getAffineDimExpr(size.index()));
if (dimSet.contains(size.index())) {
iteratorTypes.push_back(getReductionIteratorTypeName());
// If `keepDim`, create affine map to the first element
// in the current dimension.
if (keepDim)
resultExprs.push_back(rewriter.getAffineConstantExpr(0));
} else {
iteratorTypes.push_back(getParallelIteratorTypeName());
resultExprs.push_back(rewriter.getAffineDimExpr(size.index()));
}
}
auto indexingMaps = AffineMap::inferFromExprList({exprs, resultExprs});
Value initTensor = rewriter.create<linalg::InitTensorOp>(
loc, resultShape, resultType.getElementType());
Value initValue = createLinalgNeutralElementForReduceOp(
rewriter, loc, op, resultType.getElementType());
Value accumulator =
rewriter.create<linalg::FillOp>(loc, initValue, initTensor)
.getResult(0);
bool hadErrorCreatingPayload = false;
auto generic = rewriter.create<linalg::GenericOp>(
loc, /*resultTensorTypes=*/accumulator.getType(),
/*inputs=*/tensorOperand,
/*outputs=*/accumulator,
/*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
Value result = createLinalgPayloadCalculationForReduceOp(
b, loc, payloadArgs, op, operands, resultType.getElementType());
if (!result) {
hadErrorCreatingPayload = true;
return;
}
b.create<linalg::YieldOp>(loc, result);
});
if (hadErrorCreatingPayload)
return failure();
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
generic.getResult(0));
return success();
}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
// Every reduce operation must set a value for the `dimSet` and
// `keepDim` in accordance with their specification.
// Every reduce operation must set a value for the `dimSet`,
// `tensorOperand`, and `keepDim` in accordance with their specification.
DenseSet<int64_t> dimSet;
Value tensorOperand;
bool keepDim = false;
if (isa<AtenSumOp>(op) || isa<AtenMaxOp>(op)) {
auto tensorOperand = operands[0];
tensorOperand = operands[0];
auto inputType = tensorOperand.getType().cast<RankedTensorType>();
// `AtenSumOp` and `AtenMaxOp` reduces along all the dimensions of the
@ -2767,7 +2775,7 @@ struct ConvertReductionOp : ConversionPattern {
for (int64_t i = 0; i < inputType.getRank(); i++)
dimSet.insert(i);
} else if (auto sumDimIntListOp = dyn_cast<AtenSumDimIntListOp>(op)) {
auto tensorOperand = operands[0];
tensorOperand = operands[0];
auto inputType = tensorOperand.getType().cast<RankedTensorType>();
if (!matchPattern(sumDimIntListOp.keepdim(),
@ -2788,8 +2796,31 @@ struct ConvertReductionOp : ConversionPattern {
} else {
return rewriter.notifyMatchFailure(op, "not a supported reduce op");
}
return createReductionLinalgGeneric(op, operands, dimSet, keepDim,
rewriter);
Location loc = op->getLoc();
auto resultType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
Value initElem = createLinalgNeutralElementForReduceOp(
rewriter, loc, op, resultType.getElementType());
bool hadErrorCreatingPayload = false;
Value generic = createReductionLinalgGeneric(
rewriter, loc, tensorOperand, dimSet, keepDim, initElem,
[&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
Value result = createLinalgPayloadCalculationForReduceOp(
b, loc, payloadArgs, op, operands, resultType.getElementType());
if (!result) {
hadErrorCreatingPayload = true;
return;
}
b.create<linalg::YieldOp>(loc, result);
});
if (hadErrorCreatingPayload)
return failure();
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, generic);
return success();
}
};
} // namespace
@ -4251,14 +4282,8 @@ public:
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
Value self = adaptor.self();
SmallVector<Value> sizes(getTensorSizes(rewriter, loc, self));
Value productResult =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
for (size_t i = 0; i < sizes.size(); i++)
productResult =
rewriter.create<arith::MulIOp>(loc, productResult, sizes[i]);
rewriter.replaceOp(op, castIndexToInt(rewriter, loc, productResult));
Value tensorSize = getTensorSize(rewriter, loc, adaptor.self());
rewriter.replaceOp(op, tensorSize);
return success();
}
};