mirror of https://github.com/llvm/torch-mlir
Add `reduction` support to `torch.nll_loss_forward` (#624)
This commit does a couple of things. First, it fixes a bug in the `linalg.generic` body of the `nll_loss_forward` lowering where the `ignoreIndex` was being compared with the loop index rather than the current element of the `target` tensor. This was not caught by the tests because they did not exercise the case where `ignoreIndex` actually corresponds to a value in `target`; this has been fixed. Second, this commit adds support for the `reduction` argument in `torch.nll_loss_forward`, as well as support for 1-D inputs. In order to simplify the lowering code, I've refactored the code that creates the `linalg.generic` ops for elementwise and reduction ops into static functions, to avoid having boilerplate code for indexing maps, etc., that can be very error-prone.

Note: The function `convertScalarToDtype` was moved to before all the conversion patterns, but nothing in it was modified.

pull/628/head
parent 9b2613533b
commit 58abec5c0a
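As a quick reference for the new tests and the `Reduction::Mean`/`Reduction::Sum` handling below, here is a minimal PyTorch sketch (not part of the diff) of how the integer `reduction` argument of `aten.nll_loss_forward` behaves; it assumes the default `ignore_index` of -100 rather than the value 2 used in the tests:

import torch

x = torch.rand(2, 3)           # (N, C) per-class scores
y = torch.randint(0, 3, (2,))  # (N,) target class indices

# reduction is an integer enum: 0 = none, 1 = mean, 2 = sum.
# aten.nll_loss_forward returns (output, total_weight); the tests keep [0].
per_sample = torch.ops.aten.nll_loss_forward(
    x, target=y, weight=None, reduction=0, ignore_index=-100)[0]  # shape (2,)
mean_loss = torch.ops.aten.nll_loss_forward(
    x, target=y, weight=None, reduction=1, ignore_index=-100)[0]  # 0-d tensor
sum_loss = torch.ops.aten.nll_loss_forward(
    x, target=y, weight=None, reduction=2, ignore_index=-100)[0]  # 0-d tensor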
@@ -34,7 +34,79 @@ class NllLossModule(torch.nn.Module):

@register_test_case(module_factory=lambda: NllLossModule())
def NllLossModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3), torch.tensor([0, 1]))
    module.forward(tu.rand(2, 3), torch.randint(0, 3, (2,)))


class NllLossModule_mean(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1], torch.int64, True),
    ])
    # Here the 2nd index is ignored.
    def forward(self, x, y):
        return torch.ops.aten.nll_loss_forward(x,
                                               target=y,
                                               weight=None,
                                               reduction=1,
                                               ignore_index=2)[0]


@register_test_case(module_factory=lambda: NllLossModule_mean())
def NllLossModule_mean_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3), torch.randint(0, 3, (2,)))


class NllLossModule_sum(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1], torch.int64, True),
    ])
    # Here the 2nd index is ignored.
    def forward(self, x, y):
        return torch.ops.aten.nll_loss_forward(x,
                                               target=y,
                                               weight=None,
                                               reduction=2,
                                               ignore_index=2)[0]


@register_test_case(module_factory=lambda: NllLossModule_sum())
def NllLossModule_sum_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3), torch.randint(0, 3, (2,)))


class NllLossModule_1D(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.float32, True),
        ([], torch.int64, True),
    ])
    # Here the 2nd index is ignored.
    def forward(self, x, y):
        return torch.ops.aten.nll_loss_forward(x,
                                               target=y,
                                               weight=None,
                                               reduction=0,
                                               ignore_index=2)[0]


@register_test_case(module_factory=lambda: NllLossModule_1D())
def NllLossModule_1D_basic(module, tu: TestUtils):
    module.forward(tu.rand(3), torch.randint(0, 3, ()))


class NllLossModule_ignore_index_out_of_bounds(torch.nn.Module):

@@ -58,8 +130,8 @@ class NllLossModule_ignore_index_out_of_bounds(torch.nn.Module):

@register_test_case(module_factory=lambda: NllLossModule_ignore_index_out_of_bounds())
def NllLossModule_ignore_index(module, tu: TestUtils):
    module.forward(tu.rand(2, 3), torch.tensor([0, 1]))
def NllLossModule_ignore_index_out_of_bounds_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3), torch.randint(0, 3, (2,)))


class NllLossModule_backward(torch.nn.Module):

@@ -181,6 +181,14 @@ static SmallVector<Value> getTensorSizes(OpBuilder &b, Location loc,
  return getTensorSizesUntilDim(b, loc, tensor, type.getRank() - 1);
}

static Value getTensorSize(OpBuilder &b, Location loc, Value tensor) {
  SmallVector<Value> sizes(getTensorSizes(b, loc, tensor));
  Value productResult = b.create<arith::ConstantOp>(loc, b.getIndexAttr(1));
  for (Value size : sizes)
    productResult = b.create<arith::MulIOp>(loc, productResult, size);
  return castIndexToInt(b, loc, productResult);
}

static Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
                                  Type elemTy) {
  Value initTensor = b.create<linalg::InitTensorOp>(loc, sizes, elemTy);

@@ -333,6 +341,232 @@ static Value buildUnitNormalCdf(OpBuilder &b, Location &loc, Value x) {
  return buildNormalCdf(b, loc, x, zero, one);
}

// Convert a scalar value to the target type. The scalar value can be an element
// from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype
// should be converted builtin types.
static Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
                                  Type dtype) {
  Type scalarType = scalar.getType();
  if (scalarType == dtype)
    return scalar;

  // TODO: For the byte(ui8) or char(i8) case, we need the unconverted dtype to
  // be able to know if we need signed or unsigned conversion.
  auto isByteOrChar = [](Type type) {
    if (auto integerTy = type.dyn_cast<mlir::IntegerType>()) {
      return integerTy.getWidth() == 8;
    }
    return false;
  };

  if (isByteOrChar(scalarType) || isByteOrChar(dtype) ||
      dtype.isSignlessInteger(1)) {
    // TODO: Handle to-boolean conversion(from-boolean conversion is handled).
    mlir::emitError(loc)
        << "unsupported byte, char or bool type for convertScalarToDtype "
        << scalarType << "(scalar type) -> " << dtype << "(dtype)";
    return nullptr;
  }

  if (auto dtypeFloat = dtype.dyn_cast<mlir::FloatType>()) {
    if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>()) {
      if (scalarFloat.getWidth() > dtypeFloat.getWidth())
        return b.create<arith::TruncFOp>(loc, dtype, scalar);
      // Only scalarFloat width < dtypeFloat width can reach here.
      return b.create<arith::ExtFOp>(loc, dtype, scalar);
    }
    assert(scalarType.isa<mlir::IntegerType>());
    if (scalarType.isSignlessInteger(1))
      return b.create<arith::UIToFPOp>(loc, dtype, scalar);
    // It's safe to use SIToFPOp because ui8/si8 are the only ones where
    // unsigned handling is needed, and we checked for that case above.
    return b.create<arith::SIToFPOp>(loc, dtype, scalar);
  }

  if (auto dtypeInteger = dtype.dyn_cast<mlir::IntegerType>()) {
    if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>())
      return b.create<arith::FPToSIOp>(loc, dtype, scalar);
    assert(scalarType.isa<mlir::IntegerType>());
    auto scalarInteger = scalarType.cast<mlir::IntegerType>();
    if (scalarInteger.getWidth() > dtypeInteger.getWidth())
      return b.create<arith::TruncIOp>(loc, dtype, scalar);
    if (scalarType.isSignlessInteger(1))
      return b.create<arith::ExtUIOp>(loc, dtype, scalar);
    // Only scalarInteger width < dtypeInteger width can reach here.
    // It's safe to use ExtSIOp here because ui8/si8 are the only ones where
    // unsigned handling is needed, and we checked for that case above.
    return b.create<arith::ExtSIOp>(loc, dtype, scalar);
  }

  llvm_unreachable("convertScalarToDtype should handle all the types");
}

// Create a reduction of `tensorOperand`, reducing along the dimensions
// in `dimSet`. If `keepDim` is true, the output tensor is the same
// rank as the `tensorOperand` and reduced dimensions are set to size 1.
// `initElem` is the element used to initialize the output tensor
// where the reduction will be stored.
static Value createReductionLinalgGeneric(
    OpBuilder &b, Location loc, Value tensorOperand,
    const DenseSet<int64_t> &dimSet, bool keepDim, Value initElem,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
  auto inputType = tensorOperand.getType().cast<RankedTensorType>();

  // Get the result shape by obtaining the size of each
  // dimension in the input tensor that is not getting reduced.
  // If `keepDim` is true, the rank of the output tensor
  // is kept the same as the rank of the input tensor, and the
  // reduced dimensions are set to have size 1.
  auto c1 = b.create<arith::ConstantIndexOp>(loc, /*value=*/1);
  SmallVector<Value> resultShape;
  for (int64_t i = 0; i < inputType.getRank(); i++) {
    auto currentDimSize = b.create<tensor::DimOp>(loc, tensorOperand, i);
    if (!dimSet.contains(i))
      resultShape.push_back(currentDimSize);
    else if (keepDim)
      resultShape.push_back(c1);
  }

  // Create the affine expressions that will be used to
  // iterate over the input and output tensors.
  // Here we also set the type of iterator: parallel or reduction.
  SmallVector<AffineExpr> exprs;
  SmallVector<StringRef> iteratorTypes;
  SmallVector<AffineExpr> resultExprs;
  for (auto size : llvm::enumerate(inputType.getShape())) {
    exprs.push_back(b.getAffineDimExpr(size.index()));

    if (dimSet.contains(size.index())) {
      iteratorTypes.push_back(getReductionIteratorTypeName());
      // If `keepDim`, create affine map to the first element
      // in the current dimension.
      if (keepDim)
        resultExprs.push_back(b.getAffineConstantExpr(0));
    } else {
      iteratorTypes.push_back(getParallelIteratorTypeName());
      resultExprs.push_back(b.getAffineDimExpr(size.index()));
    }
  }

  auto indexingMaps = AffineMap::inferFromExprList({exprs, resultExprs});
  Value accumulator =
      createInitTensor(b, loc, resultShape, initElem.getType(), initElem);

  return b
      .create<linalg::GenericOp>(
          loc, /*resultTensorTypes=*/accumulator.getType(),
          /*inputs=*/tensorOperand,
          /*outputs=*/accumulator, indexingMaps, iteratorTypes, bodyBuild)
      .getResult(0);
}

static Value createElementwiseLinalgGeneric(
    OpBuilder &b, Location loc, ValueRange tensorOperands,
    Type resultElementType,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
  // The overall error handling strategy here is best viewed by thinking about
  // what happens for a single result dimension. This loop not structured that
  // way because it is hard to create the affine maps for each operand unless
  // we structure the loop to iterate over tensor operands as the outer loop
  // instead of inner loop. This pseudocode gives better intuition:
  // ```
  // for each result dimension:
  //   for each tensor operand:
  //     if it doesn't even have high enough rank relative to the result:
  //       continue
  //     if it is a static size-1 along this result dimension:
  //       continue
  //     if this is the first tensor operand that didn't continue above:
  //       take its dimension size as the size of the non-broadcasted
  //       traversal along this dimension (this may include a dynamic size-1,
  //       **non-broadcasted** traversal!)
  //     emit error check "if the size does not match the non-broadcasted
  //     traversal size along this dimension, error"
  // ```
  SmallVector<int64_t> operandRanks;
  operandRanks.resize(tensorOperands.size());
  llvm::transform(tensorOperands, operandRanks.begin(), [](Value tensor) {
    return tensor.getType().dyn_cast<RankedTensorType>().getRank();
  });

  auto resultRankIt =
      std::max_element(operandRanks.begin(), operandRanks.end());
  assert(resultRankIt != operandRanks.end() && "Unable to get result rank.");
  int64_t resultRank = *resultRankIt;

  // Initialize the resultShape to all 1's, as a fallback in case
  // all sizes along that result dimension are statically 1.
  auto c1 = b.create<arith::ConstantIndexOp>(loc, /*value=*/1);
  SmallVector<Value> resultShape(resultRank, c1);
  SmallVector<AffineMap> indexingMaps;
  for (Value tensorOperand : tensorOperands) {
    SmallVector<AffineExpr> exprs;
    auto type = tensorOperand.getType().cast<RankedTensorType>();
    for (auto size : llvm::enumerate(type.getShape())) {
      // If the size is statically known to be 1, we don't want any
      // error guards to be spuriously emitted, since we are specifically
      // allowing size-1 broadcasts in this case, as they correspond to a
      // constant-0 indexing map.
      if (size.value() == 1) {
        exprs.push_back(b.getAffineConstantExpr(0));
        continue;
      }

      // The rank of this operand might be smaller than the overall rank of
      // the broadcast. Add an offset to correlate it to the correct
      // dimension of the result.
      auto resultDim = size.index() + (resultRank - type.getRank());

      // The generated linalg op will now be iterating along the full size
      // of this dimension. Record that fact.
      exprs.push_back(b.getAffineDimExpr(resultDim));

      // Now, we need to ensure that such iteration is not going to trigger
      // undefined behavior, by doing appropriate checks against the current
      // dimension size.
      auto currentDimSize = getDimOp(b, loc, tensorOperand, size.index());

      // If the result size of this dimension has so far only hit the
      // statically-known-to-be-1 case above (i.e., we have not yet assigned a
      // new Value to `resultShape[resultDim]`), then we have no other dynamic
      // values to check against, and merely need to record the current
      // dimension size.
      if (resultShape[resultDim] == c1) {
        resultShape[resultDim] = currentDimSize;
        continue;
      }

      // We prohibit the size-1 dynamic broadcasting scenario, so just check
      // for exact equality with the running result size.
      // This is the check which protects against the undefined behavior of
      // the generated linalg op in the case of iterating two operands with
      // dimensions sizes that are expected to match.
      auto equalToRunning =
          b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                  resultShape[resultDim], currentDimSize);
      b.create<cf::AssertOp>(loc, equalToRunning,
                             "mismatched size for broadcast");
    }
    indexingMaps.push_back(AffineMap::get(
        /*dimCount=*/resultRank, /*symbolCount=*/0, exprs, b.getContext()));
  }

  SmallVector<StringRef> iteratorTypes(resultRank,
                                       getParallelIteratorTypeName());
  // Add the indexing map for the outs init tensor.
  indexingMaps.push_back(b.getMultiDimIdentityMap(resultRank));

  Value initTensor = b.create<linalg::InitTensorOp>(
      loc, getAsOpFoldResult(resultShape), resultElementType);
  return b
      .create<linalg::GenericOp>(loc,
                                 /*resultTensorTypes=*/initTensor.getType(),
                                 /*inputs=*/tensorOperands,
                                 /*outputs=*/initTensor, indexingMaps,
                                 iteratorTypes, bodyBuild)
      .getResult(0);
}

namespace {
class ConvertAtenAdaptiveAvgPool2dOp
    : public OpConversionPattern<AtenAdaptiveAvgPool2dOp> {

@@ -1237,7 +1471,7 @@ public:
// for i in range(0, len(target)):
//     indi = target[i];
//     nll_loss_forward[i] = -(input[i][indi]);
// TODO: `weight` and `reduction` operands are still to be taken care of.
// TODO: `weight`operand is still to be taken care of.
namespace {
class ConvertAtenNllLossForwardOp
    : public OpConversionPattern<AtenNllLossForwardOp> {

@@ -1253,15 +1487,10 @@ public:
    Value target = adaptor.target();
    Value weight = adaptor.weight();

    int64_t reduce_dim;
    if (!matchPattern(op.reduction(), m_TorchConstantInt(&reduce_dim)))
    int64_t reduction;
    if (!matchPattern(op.reduction(), m_TorchConstantInt(&reduction)))
      return rewriter.notifyMatchFailure(op, "dim must be constant");

    // TODO: Handle reduction.
    if (reduce_dim != 0)
      return rewriter.notifyMatchFailure(
          op, "reduction along dimensions is not supported.");

    // TODO: Incorporate the weight argument.
    if (!weight.getType().isa<mlir::torch::Torch::NoneType>())
      return rewriter.notifyMatchFailure(

@@ -1273,52 +1502,67 @@ public:
    unsigned inputRank = input.getType().cast<RankedTensorType>().getRank();
    unsigned targetRank = target.getType().cast<RankedTensorType>().getRank();

    // TODO: Cases with targetRank != 1 where `Mean` reduction is required.
    if (inputRank != 2 || targetRank != 1) {
    // TODO: Add support for k-dim loss.
    if (inputRank > 2) {
      return rewriter.notifyMatchFailure(
          op, "expected input and target to be rank 2 and 1 respectively");
          op, "expected input and target to be rank <= 2");
    }
    RankedTensorType resultType = getTypeConverter()
                                      ->convertType(op->getResult(0).getType())
                                      .cast<RankedTensorType>();

    Type elementType = resultType.getElementType();

    Value targetDim = getDimOp(rewriter, loc, target, 0);
    Value initTensor0 =
        createZeroInitTensor(rewriter, loc, {targetDim}, elementType);
    Value zeroVal = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementType));

    SmallVector<AffineExpr> targetExpr;
    targetExpr.push_back(rewriter.getAffineDimExpr(0));
    SmallVector<StringRef> iteratorTypes{getParallelIteratorTypeName()};
    auto indexingMaps = AffineMap::inferFromExprList({targetExpr, targetExpr});
    Value finalRes =
        rewriter
            .create<linalg::GenericOp>(
                loc, initTensor0.getType(), ValueRange{target}, initTensor0,
                /*indexingMaps=*/indexingMaps,
                /*iteratorTypes=*/iteratorTypes,
    Value finalRes = createElementwiseLinalgGeneric(
        rewriter, loc, {target}, elementType,
        [&](OpBuilder &b, Location loc, ValueRange args) {
          Value targetVal = args[0];
          Value indTarget = rewriter.create<arith::IndexCastOp>(
              loc, rewriter.getIndexType(), args[0]);
          Value indI = rewriter.create<linalg::IndexOp>(loc, 0);
              loc, rewriter.getIndexType(), targetVal);

          // The final result is given by:
          // final_res = (indI == ignoreIndexVal) ? 0 :
          // final_res = (indTarget == ignoreIndexVal) ? 0 :
          // input[indI][IndTarget]
          Value cmpEq = rewriter.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::eq, indI, ignoreIndexVal);
          Value result = rewriter.create<tensor::ExtractOp>(
              loc, input, ValueRange{indI, indTarget});
              loc, arith::CmpIPredicate::eq, indTarget, ignoreIndexVal);

          SmallVector<Value> extractionIndices{indTarget};
          if (inputRank == 2) {
            Value indI = rewriter.create<linalg::IndexOp>(loc, 0);
            extractionIndices.insert(extractionIndices.begin(), indI);
          }

          Value result =
              rewriter.create<tensor::ExtractOp>(loc, input, extractionIndices);

          Value negate =
              rewriter.create<arith::NegFOp>(loc, elementType, result);
          Value selectFinal = rewriter.create<arith::SelectOp>(
              loc, cmpEq, zeroVal, negate);
          Value selectFinal =
              rewriter.create<arith::SelectOp>(loc, cmpEq, zeroVal, negate);
          b.create<linalg::YieldOp>(loc, selectFinal);
        })
        .getResult(0);
        });

    if (reduction == Reduction::Sum || reduction == Reduction::Mean) {
      Value numOfElems = getTensorSize(rewriter, loc, finalRes);
      numOfElems = convertScalarToDtype(rewriter, loc, numOfElems, elementType);
      llvm::iota_range<int64_t> dimsToReduce(0, targetRank,
                                             /*inclusive=*/false);
      DenseSet<int64_t> dimSet(dimsToReduce.begin(), dimsToReduce.end());

      finalRes = createReductionLinalgGeneric(
          rewriter, loc, finalRes, dimSet, /*keepDim=*/false,
          /*initElem=*/zeroVal,
          [&](OpBuilder &b, Location loc, ValueRange args) {
            Value newVal = args[0];
            Value accumulator = args[1];
            if (reduction == Reduction::Mean)
              newVal = b.create<arith::DivFOp>(loc, newVal, numOfElems);
            Value result = b.create<arith::AddFOp>(loc, newVal, accumulator);
            b.create<linalg::YieldOp>(loc, result);
          });
    }

    // TODO: Update the second result tensor.
    Value weightUpdated =

@@ -1588,66 +1832,6 @@ public:
};
} // namespace

// Convert a scalar value to the target type. The scalar value can be an element
// from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype
// should be converted builtin types.
static Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
                                  Type dtype) {
  Type scalarType = scalar.getType();
  if (scalarType == dtype)
    return scalar;

  // TODO: For the byte(ui8) or char(i8) case, we need the unconverted dtype to
  // be able to know if we need signed or unsigned conversion.
  auto isByteOrChar = [](Type type) {
    if (auto integerTy = type.dyn_cast<mlir::IntegerType>()) {
      return integerTy.getWidth() == 8;
    }
    return false;
  };

  if (isByteOrChar(scalarType) || isByteOrChar(dtype) ||
      dtype.isSignlessInteger(1)) {
    // TODO: Handle to-boolean conversion(from-boolean conversion is handled).
    mlir::emitError(loc)
        << "unsupported byte, char or bool type for convertScalarToDtype "
        << scalarType << "(scalar type) -> " << dtype << "(dtype)";
    return nullptr;
  }

  if (auto dtypeFloat = dtype.dyn_cast<mlir::FloatType>()) {
    if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>()) {
      if (scalarFloat.getWidth() > dtypeFloat.getWidth())
        return b.create<arith::TruncFOp>(loc, dtype, scalar);
      // Only scalarFloat width < dtypeFloat width can reach here.
      return b.create<arith::ExtFOp>(loc, dtype, scalar);
    }
    assert(scalarType.isa<mlir::IntegerType>());
    if (scalarType.isSignlessInteger(1))
      return b.create<arith::UIToFPOp>(loc, dtype, scalar);
    // It's safe to use SIToFPOp because ui8/si8 are the only ones where
    // unsigned handling is needed, and we checked for that case above.
    return b.create<arith::SIToFPOp>(loc, dtype, scalar);
  }

  if (auto dtypeInteger = dtype.dyn_cast<mlir::IntegerType>()) {
    if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>())
      return b.create<arith::FPToSIOp>(loc, dtype, scalar);
    assert(scalarType.isa<mlir::IntegerType>());
    auto scalarInteger = scalarType.cast<mlir::IntegerType>();
    if (scalarInteger.getWidth() > dtypeInteger.getWidth())
      return b.create<arith::TruncIOp>(loc, dtype, scalar);
    if (scalarType.isSignlessInteger(1))
      return b.create<arith::ExtUIOp>(loc, dtype, scalar);
    // Only scalarInteger width < dtypeInteger width can reach here.
    // It's safe to use ExtSIOp here because ui8/si8 are the only ones where
    // unsigned handling is needed, and we checked for that case above.
    return b.create<arith::ExtSIOp>(loc, dtype, scalar);
  }

  llvm_unreachable("convertScalarToDtype should handle all the types");
}

static Value createLinalgPayloadCalculationForElementwiseOp(
    OpBuilder &b, Location loc, TypeConverter *converter,
    ValueRange payloadArgs, Operation *op, ArrayRef<Value> operands) {

@@ -2546,99 +2730,9 @@ struct ConvertElementwiseOp : ConversionPattern {
    auto resultType = getTypeConverter()
                          ->convertType(op->getResult(0).getType())
                          .cast<RankedTensorType>();
    auto resultRank = resultType.getRank();

    auto c1 = rewriter.create<arith::ConstantIndexOp>(loc, /*value=*/1);
    // The overall error handling strategy here is best viewed by thinking about
    // what happens for a single result dimension. This loop not structured that
    // way because it is hard to create the affine maps for each operand unless
    // we structure the loop to iterate over tensor operands as the outer loop
    // instead of inner loop. This pseudocode gives better intuition:
    // ```
    // for each result dimension:
    //   for each tensor operand:
    //     if it doesn't even have high enough rank relative to the result:
    //       continue
    //     if it is a static size-1 along this result dimension:
    //       continue
    //     if this is the first tensor operand that didn't continue above:
    //       take its dimension size as the size of the non-broadcasted
    //       traversal along this dimension (this may include a dynamic size-1,
    //       **non-broadcasted** traversal!)
    //     emit error check "if the size does not match the non-broadcasted
    //     traversal size along this dimension, error"
    // ```
    // Initialize the resultShape to all 1's, as a fallback in case
    // all sizes along that result dimension are statically 1.
    SmallVector<Value> resultShape(resultRank, c1);
    SmallVector<AffineMap> indexingMaps;
    for (Value tensorOperand : tensorOperands) {
      SmallVector<AffineExpr> exprs;
      auto type = tensorOperand.getType().cast<RankedTensorType>();
      for (auto size : llvm::enumerate(type.getShape())) {
        // If the size is statically known to be 1, we don't want any
        // error guards to be spuriously emitted, since we are specifically
        // allowing size-1 broadcasts in this case, as they correspond to a
        // constant-0 indexing map.
        if (size.value() == 1) {
          exprs.push_back(rewriter.getAffineConstantExpr(0));
          continue;
        }

        // The rank of this operand might be smaller than the overall rank of
        // the broadcast. Add an offset to correlate it to the correct
        // dimension of the result.
        auto resultDim = size.index() + (resultRank - type.getRank());

        // The generated linalg op will now be iterating along the full size
        // of this dimension. Record that fact.
        exprs.push_back(rewriter.getAffineDimExpr(resultDim));

        // Now, we need to ensure that such iteration is not going to trigger
        // undefined behavior, by doing appropriate checks against the current
        // dimension size.
        auto currentDimSize =
            getDimOp(rewriter, loc, tensorOperand, size.index());

        // If the result size of this dimension has so far only hit the
        // statically-known-to-be-1 case above (i.e., we have not yet assigned a
        // new Value to `resultShape[resultDim]`), then we have no other dynamic
        // values to check against, and merely need to record the current
        // dimension size.
        if (resultShape[resultDim] == c1) {
          resultShape[resultDim] = currentDimSize;
          continue;
        }

        // We prohibit the size-1 dynamic broadcasting scenario, so just check
        // for exact equality with the running result size.
        // This is the check which protects against the undefined behavior of
        // the generated linalg op in the case of iterating two operands with
        // dimensions sizes that are expected to match.
        auto equalToRunning = rewriter.create<arith::CmpIOp>(
            loc, arith::CmpIPredicate::eq, resultShape[resultDim],
            currentDimSize);
        rewriter.create<cf::AssertOp>(loc, equalToRunning,
                                      "mismatched size for broadcast");
      }
      indexingMaps.push_back(AffineMap::get(
          /*dimCount=*/resultRank, /*symbolCount=*/0, exprs, getContext()));
    }

    SmallVector<StringRef> iteratorTypes(resultRank,
                                         getParallelIteratorTypeName());
    // Add the indexing map for the outs init tensor.
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));

    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, getAsOpFoldResult(resultShape), resultType.getElementType());
    bool hadErrorCreatingPayload = false;
    auto generic = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensorTypes=*/initTensor.getType(),
        /*inputs=*/tensorOperands,
        /*outputs=*/initTensor,
        /*indexingMaps=*/indexingMaps,
        /*iteratorTypes=*/iteratorTypes,
    Value generic = createElementwiseLinalgGeneric(
        rewriter, loc, tensorOperands, resultType.getElementType(),
        [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
          Value result = createLinalgPayloadCalculationForElementwiseOp(
              b, loc, getTypeConverter(), payloadArgs, op, operands);

@@ -2650,8 +2744,7 @@ struct ConvertElementwiseOp : ConversionPattern {
        });
    if (hadErrorCreatingPayload)
      return failure();
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
                                                generic.getResult(0));
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, generic);
    return success();
  }
};

@@ -2662,104 +2755,19 @@ struct ConvertReductionOp : ConversionPattern {
  ConvertReductionOp(TypeConverter &typeConverter, MLIRContext *context)
      : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
                          context) {}

  // This function is in charge of all the rewriting that will take
  // place in `matchAndRewrite`. In particular, it converts
  // the reduce operation into an `linalg.generic` operation
  // to reduce the input tensor along the dimensions specified in
  // `dimeSet`.
  LogicalResult
  createReductionLinalgGeneric(Operation *op, ArrayRef<Value> operands,
                               const DenseSet<int64_t> &dimSet, bool keepDim,
                               ConversionPatternRewriter &rewriter) const {
    Location loc = op->getLoc();
    auto tensorOperand = operands[0];
    auto inputType = tensorOperand.getType().cast<RankedTensorType>();
    auto resultType = getTypeConverter()
                          ->convertType(op->getResult(0).getType())
                          .cast<RankedTensorType>();

    // Get the result shape by obtaining the size of each
    // dimension in the input tensor that is not getting reduced.
    // If `keepDim` is true, the rank of the output tensor
    // is kept the same as the rank of the input tensor, and the
    // reduced dimensions are set to have size 1.
    auto c1 = rewriter.create<arith::ConstantIndexOp>(loc, /*value=*/1);
    SmallVector<Value> resultShape;
    for (int64_t i = 0; i < inputType.getRank(); i++) {
      auto currentDimSize =
          rewriter.create<tensor::DimOp>(loc, tensorOperand, i);
      if (!dimSet.contains(i))
        resultShape.push_back(currentDimSize);
      else if (keepDim)
        resultShape.push_back(c1);
    }

    // Create the affine expressions that will be used to
    // iterate over the input and output tensors.
    // Here we also set the type of iterator: parallel or reduction.
    SmallVector<AffineExpr> exprs;
    SmallVector<StringRef> iteratorTypes;
    SmallVector<AffineExpr> resultExprs;
    for (auto size : llvm::enumerate(inputType.getShape())) {
      exprs.push_back(rewriter.getAffineDimExpr(size.index()));

      if (dimSet.contains(size.index())) {
        iteratorTypes.push_back(getReductionIteratorTypeName());
        // If `keepDim`, create affine map to the first element
        // in the current dimension.
        if (keepDim)
          resultExprs.push_back(rewriter.getAffineConstantExpr(0));
      } else {
        iteratorTypes.push_back(getParallelIteratorTypeName());
        resultExprs.push_back(rewriter.getAffineDimExpr(size.index()));
      }
    }

    auto indexingMaps = AffineMap::inferFromExprList({exprs, resultExprs});
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, resultShape, resultType.getElementType());
    Value initValue = createLinalgNeutralElementForReduceOp(
        rewriter, loc, op, resultType.getElementType());
    Value accumulator =
        rewriter.create<linalg::FillOp>(loc, initValue, initTensor)
            .getResult(0);
    bool hadErrorCreatingPayload = false;
    auto generic = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensorTypes=*/accumulator.getType(),
        /*inputs=*/tensorOperand,
        /*outputs=*/accumulator,
        /*indexingMaps=*/indexingMaps,
        /*iteratorTypes=*/iteratorTypes,
        [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
          Value result = createLinalgPayloadCalculationForReduceOp(
              b, loc, payloadArgs, op, operands, resultType.getElementType());
          if (!result) {
            hadErrorCreatingPayload = true;
            return;
          }
          b.create<linalg::YieldOp>(loc, result);
        });

    if (hadErrorCreatingPayload)
      return failure();
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
                                                generic.getResult(0));
    return success();
  }

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();

    // Every reduce operation must set a value for the `dimSet` and
    // `keepDim` in accordance with their specification.
    // Every reduce operation must set a value for the `dimSet`,
    // `tensorOperand`, and `keepDim` in accordance with their specification.
    DenseSet<int64_t> dimSet;
    Value tensorOperand;
    bool keepDim = false;
    if (isa<AtenSumOp>(op) || isa<AtenMaxOp>(op)) {
      auto tensorOperand = operands[0];
      tensorOperand = operands[0];
      auto inputType = tensorOperand.getType().cast<RankedTensorType>();

      // `AtenSumOp` and `AtenMaxOp` reduces along all the dimensions of the

@@ -2767,7 +2775,7 @@ struct ConvertReductionOp : ConversionPattern {
      for (int64_t i = 0; i < inputType.getRank(); i++)
        dimSet.insert(i);
    } else if (auto sumDimIntListOp = dyn_cast<AtenSumDimIntListOp>(op)) {
      auto tensorOperand = operands[0];
      tensorOperand = operands[0];
      auto inputType = tensorOperand.getType().cast<RankedTensorType>();

      if (!matchPattern(sumDimIntListOp.keepdim(),

@@ -2788,8 +2796,31 @@ struct ConvertReductionOp : ConversionPattern {
    } else {
      return rewriter.notifyMatchFailure(op, "not a supported reduce op");
    }
    return createReductionLinalgGeneric(op, operands, dimSet, keepDim,
                                        rewriter);

    Location loc = op->getLoc();
    auto resultType = getTypeConverter()
                          ->convertType(op->getResult(0).getType())
                          .cast<RankedTensorType>();
    Value initElem = createLinalgNeutralElementForReduceOp(
        rewriter, loc, op, resultType.getElementType());

    bool hadErrorCreatingPayload = false;
    Value generic = createReductionLinalgGeneric(
        rewriter, loc, tensorOperand, dimSet, keepDim, initElem,
        [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
          Value result = createLinalgPayloadCalculationForReduceOp(
              b, loc, payloadArgs, op, operands, resultType.getElementType());
          if (!result) {
            hadErrorCreatingPayload = true;
            return;
          }
          b.create<linalg::YieldOp>(loc, result);
        });

    if (hadErrorCreatingPayload)
      return failure();
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, generic);
    return success();
  }
};
} // namespace

@@ -4251,14 +4282,8 @@ public:
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    Location loc = op.getLoc();
    Value self = adaptor.self();
    SmallVector<Value> sizes(getTensorSizes(rewriter, loc, self));
    Value productResult =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
    for (size_t i = 0; i < sizes.size(); i++)
      productResult =
          rewriter.create<arith::MulIOp>(loc, productResult, sizes[i]);
    rewriter.replaceOp(op, castIndexToInt(rewriter, loc, productResult));
    Value tensorSize = getTensorSize(rewriter, loc, adaptor.self());
    rewriter.replaceOp(op, tensorSize);
    return success();
  }
};