mirror of https://github.com/llvm/torch-mlir
Add `reduction` support to `torch.nll_loss_forward` (#624)
This commit does a couple of things. First, it fixes a bug in the `linalg.generic` body of the `nll_loss_forward` lowering where `ignoreIndex` was being compared with the loop index rather than with the current element of the `target` tensor. The tests did not catch this because they never exercised the case where `ignoreIndex` actually corresponds to a value in `target`; that case is now covered. Second, this commit adds support for the `reduction` argument in `torch.nll_loss_forward`, as well as support for 1-D inputs. In order to simplify the lowering code, I've refactored the code that creates the `linalg.generic` ops for elementwise and reduction ops into static functions, to avoid having boilerplate code for indexing maps, etc., that can be very error-prone.

Note: The function `convertScalarToDtype` was moved to before all the conversion patterns, but nothing in it was modified.
parent 9b2613533b
commit 58abec5c0a
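For orientation (this note is not part of the commit): the integer `reduction` argument follows PyTorch's `Reduction` enum, where 0 means no reduction, 1 means mean, and 2 means sum, and `aten.nll_loss_forward` returns an `(output, total_weight)` pair, which is why the tests index `[0]`. A minimal plain-PyTorch sketch of the behavior the new e2e tests below exercise, including the case where `ignore_index` actually occurs in `target` (the situation the bug fix addresses):

import torch

x = torch.rand(2, 3)
target = torch.randint(0, 3, (2,))

# reduction=0: one loss value per batch element; 1: mean; 2: sum (scalar outputs).
unreduced = torch.ops.aten.nll_loss_forward(
    x, target=target, weight=None, reduction=0, ignore_index=-100)[0]
mean_loss = torch.ops.aten.nll_loss_forward(
    x, target=target, weight=None, reduction=1, ignore_index=-100)[0]
sum_loss = torch.ops.aten.nll_loss_forward(
    x, target=target, weight=None, reduction=2, ignore_index=-100)[0]

# Entries whose target value equals ignore_index contribute zero loss, so the
# lowering must compare ignore_index against target[i], not the loop index i.
masked = torch.ops.aten.nll_loss_forward(
    x, target=target, weight=None, reduction=0, ignore_index=2)[0]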
@@ -34,7 +34,79 @@ class NllLossModule(torch.nn.Module):
 
 @register_test_case(module_factory=lambda: NllLossModule())
 def NllLossModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(2, 3), torch.tensor([0, 1]))
+    module.forward(tu.rand(2, 3), torch.randint(0, 3, (2,)))
+
+
+class NllLossModule_mean(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+    ])
+    # Here the 2nd index is ignored.
+    def forward(self, x, y):
+        return torch.ops.aten.nll_loss_forward(x,
+                                               target=y,
+                                               weight=None,
+                                               reduction=1,
+                                               ignore_index=2)[0]
+
+
+@register_test_case(module_factory=lambda: NllLossModule_mean())
+def NllLossModule_mean_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3), torch.randint(0, 3, (2,)))
+
+
+class NllLossModule_sum(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+    ])
+    # Here the 2nd index is ignored.
+    def forward(self, x, y):
+        return torch.ops.aten.nll_loss_forward(x,
+                                               target=y,
+                                               weight=None,
+                                               reduction=2,
+                                               ignore_index=2)[0]
+
+
+@register_test_case(module_factory=lambda: NllLossModule_sum())
+def NllLossModule_sum_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3), torch.randint(0, 3, (2,)))
+
+
+class NllLossModule_1D(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([], torch.int64, True),
+    ])
+    # Here the 2nd index is ignored.
+    def forward(self, x, y):
+        return torch.ops.aten.nll_loss_forward(x,
+                                               target=y,
+                                               weight=None,
+                                               reduction=0,
+                                               ignore_index=2)[0]
+
+
+@register_test_case(module_factory=lambda: NllLossModule_1D())
+def NllLossModule_1D_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3), torch.randint(0, 3, ()))
 
 
 class NllLossModule_ignore_index_out_of_bounds(torch.nn.Module):

@@ -58,8 +130,8 @@ class NllLossModule_ignore_index_out_of_bounds(torch.nn.Module):
 
 
 @register_test_case(module_factory=lambda: NllLossModule_ignore_index_out_of_bounds())
-def NllLossModule_ignore_index(module, tu: TestUtils):
-    module.forward(tu.rand(2, 3), torch.tensor([0, 1]))
+def NllLossModule_ignore_index_out_of_bounds_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3), torch.randint(0, 3, (2,)))
 
 
 class NllLossModule_backward(torch.nn.Module):
@@ -181,6 +181,14 @@ static SmallVector<Value> getTensorSizes(OpBuilder &b, Location loc,
   return getTensorSizesUntilDim(b, loc, tensor, type.getRank() - 1);
 }
 
+static Value getTensorSize(OpBuilder &b, Location loc, Value tensor) {
+  SmallVector<Value> sizes(getTensorSizes(b, loc, tensor));
+  Value productResult = b.create<arith::ConstantOp>(loc, b.getIndexAttr(1));
+  for (Value size : sizes)
+    productResult = b.create<arith::MulIOp>(loc, productResult, size);
+  return castIndexToInt(b, loc, productResult);
+}
+
 static Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
                                   Type elemTy) {
   Value initTensor = b.create<linalg::InitTensorOp>(loc, sizes, elemTy);
@@ -333,6 +341,232 @@ static Value buildUnitNormalCdf(OpBuilder &b, Location &loc, Value x) {
   return buildNormalCdf(b, loc, x, zero, one);
 }
 
+// Convert a scalar value to the target type. The scalar value can be an element
+// from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype
+// should be converted builtin types.
+static Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
+                                  Type dtype) {
+  Type scalarType = scalar.getType();
+  if (scalarType == dtype)
+    return scalar;
+
+  // TODO: For the byte(ui8) or char(i8) case, we need the unconverted dtype to
+  // be able to know if we need signed or unsigned conversion.
+  auto isByteOrChar = [](Type type) {
+    if (auto integerTy = type.dyn_cast<mlir::IntegerType>()) {
+      return integerTy.getWidth() == 8;
+    }
+    return false;
+  };
+
+  if (isByteOrChar(scalarType) || isByteOrChar(dtype) ||
+      dtype.isSignlessInteger(1)) {
+    // TODO: Handle to-boolean conversion(from-boolean conversion is handled).
+    mlir::emitError(loc)
+        << "unsupported byte, char or bool type for convertScalarToDtype "
+        << scalarType << "(scalar type) -> " << dtype << "(dtype)";
+    return nullptr;
+  }
+
+  if (auto dtypeFloat = dtype.dyn_cast<mlir::FloatType>()) {
+    if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>()) {
+      if (scalarFloat.getWidth() > dtypeFloat.getWidth())
+        return b.create<arith::TruncFOp>(loc, dtype, scalar);
+      // Only scalarFloat width < dtypeFloat width can reach here.
+      return b.create<arith::ExtFOp>(loc, dtype, scalar);
+    }
+    assert(scalarType.isa<mlir::IntegerType>());
+    if (scalarType.isSignlessInteger(1))
+      return b.create<arith::UIToFPOp>(loc, dtype, scalar);
+    // It's safe to use SIToFPOp because ui8/si8 are the only ones where
+    // unsigned handling is needed, and we checked for that case above.
+    return b.create<arith::SIToFPOp>(loc, dtype, scalar);
+  }
+
+  if (auto dtypeInteger = dtype.dyn_cast<mlir::IntegerType>()) {
+    if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>())
+      return b.create<arith::FPToSIOp>(loc, dtype, scalar);
+    assert(scalarType.isa<mlir::IntegerType>());
+    auto scalarInteger = scalarType.cast<mlir::IntegerType>();
+    if (scalarInteger.getWidth() > dtypeInteger.getWidth())
+      return b.create<arith::TruncIOp>(loc, dtype, scalar);
+    if (scalarType.isSignlessInteger(1))
+      return b.create<arith::ExtUIOp>(loc, dtype, scalar);
+    // Only scalarInteger width < dtypeInteger width can reach here.
+    // It's safe to use ExtSIOp here because ui8/si8 are the only ones where
+    // unsigned handling is needed, and we checked for that case above.
+    return b.create<arith::ExtSIOp>(loc, dtype, scalar);
+  }
+
+  llvm_unreachable("convertScalarToDtype should handle all the types");
+}
+
+// Create a reduction of `tensorOperand`, reducing along the dimensions
+// in `dimSet`. If `keepDim` is true, the output tensor is the same
+// rank as the `tensorOperand` and reduced dimensions are set to size 1.
+// `initElem` is the element used to initialize the output tensor
+// where the reduction will be stored.
+static Value createReductionLinalgGeneric(
+    OpBuilder &b, Location loc, Value tensorOperand,
+    const DenseSet<int64_t> &dimSet, bool keepDim, Value initElem,
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
+  auto inputType = tensorOperand.getType().cast<RankedTensorType>();
+
+  // Get the result shape by obtaining the size of each
+  // dimension in the input tensor that is not getting reduced.
+  // If `keepDim` is true, the rank of the output tensor
+  // is kept the same as the rank of the input tensor, and the
+  // reduced dimensions are set to have size 1.
+  auto c1 = b.create<arith::ConstantIndexOp>(loc, /*value=*/1);
+  SmallVector<Value> resultShape;
+  for (int64_t i = 0; i < inputType.getRank(); i++) {
+    auto currentDimSize = b.create<tensor::DimOp>(loc, tensorOperand, i);
+    if (!dimSet.contains(i))
+      resultShape.push_back(currentDimSize);
+    else if (keepDim)
+      resultShape.push_back(c1);
+  }
+
+  // Create the affine expressions that will be used to
+  // iterate over the input and output tensors.
+  // Here we also set the type of iterator: parallel or reduction.
+  SmallVector<AffineExpr> exprs;
+  SmallVector<StringRef> iteratorTypes;
+  SmallVector<AffineExpr> resultExprs;
+  for (auto size : llvm::enumerate(inputType.getShape())) {
+    exprs.push_back(b.getAffineDimExpr(size.index()));
+
+    if (dimSet.contains(size.index())) {
+      iteratorTypes.push_back(getReductionIteratorTypeName());
+      // If `keepDim`, create affine map to the first element
+      // in the current dimension.
+      if (keepDim)
+        resultExprs.push_back(b.getAffineConstantExpr(0));
+    } else {
+      iteratorTypes.push_back(getParallelIteratorTypeName());
+      resultExprs.push_back(b.getAffineDimExpr(size.index()));
+    }
+  }
+
+  auto indexingMaps = AffineMap::inferFromExprList({exprs, resultExprs});
+  Value accumulator =
+      createInitTensor(b, loc, resultShape, initElem.getType(), initElem);
+
+  return b
+      .create<linalg::GenericOp>(
+          loc, /*resultTensorTypes=*/accumulator.getType(),
+          /*inputs=*/tensorOperand,
+          /*outputs=*/accumulator, indexingMaps, iteratorTypes, bodyBuild)
+      .getResult(0);
+}
+
+static Value createElementwiseLinalgGeneric(
+    OpBuilder &b, Location loc, ValueRange tensorOperands,
+    Type resultElementType,
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
+  // The overall error handling strategy here is best viewed by thinking about
+  // what happens for a single result dimension. This loop not structured that
+  // way because it is hard to create the affine maps for each operand unless
+  // we structure the loop to iterate over tensor operands as the outer loop
+  // instead of inner loop. This pseudocode gives better intuition:
+  // ```
+  // for each result dimension:
+  //   for each tensor operand:
+  //     if it doesn't even have high enough rank relative to the result:
+  //       continue
+  //     if it is a static size-1 along this result dimension:
+  //       continue
+  //     if this is the first tensor operand that didn't continue above:
+  //       take its dimension size as the size of the non-broadcasted
+  //       traversal along this dimension (this may include a dynamic size-1,
+  //       **non-broadcasted** traversal!)
+  //     emit error check "if the size does not match the non-broadcasted
+  //     traversal size along this dimension, error"
+  // ```
+  SmallVector<int64_t> operandRanks;
+  operandRanks.resize(tensorOperands.size());
+  llvm::transform(tensorOperands, operandRanks.begin(), [](Value tensor) {
+    return tensor.getType().dyn_cast<RankedTensorType>().getRank();
+  });
+
+  auto resultRankIt =
+      std::max_element(operandRanks.begin(), operandRanks.end());
+  assert(resultRankIt != operandRanks.end() && "Unable to get result rank.");
+  int64_t resultRank = *resultRankIt;
+
+  // Initialize the resultShape to all 1's, as a fallback in case
+  // all sizes along that result dimension are statically 1.
+  auto c1 = b.create<arith::ConstantIndexOp>(loc, /*value=*/1);
+  SmallVector<Value> resultShape(resultRank, c1);
+  SmallVector<AffineMap> indexingMaps;
+  for (Value tensorOperand : tensorOperands) {
+    SmallVector<AffineExpr> exprs;
+    auto type = tensorOperand.getType().cast<RankedTensorType>();
+    for (auto size : llvm::enumerate(type.getShape())) {
+      // If the size is statically known to be 1, we don't want any
+      // error guards to be spuriously emitted, since we are specifically
+      // allowing size-1 broadcasts in this case, as they correspond to a
+      // constant-0 indexing map.
+      if (size.value() == 1) {
+        exprs.push_back(b.getAffineConstantExpr(0));
+        continue;
+      }
+
+      // The rank of this operand might be smaller than the overall rank of
+      // the broadcast. Add an offset to correlate it to the correct
+      // dimension of the result.
+      auto resultDim = size.index() + (resultRank - type.getRank());
+
+      // The generated linalg op will now be iterating along the full size
+      // of this dimension. Record that fact.
+      exprs.push_back(b.getAffineDimExpr(resultDim));
+
+      // Now, we need to ensure that such iteration is not going to trigger
+      // undefined behavior, by doing appropriate checks against the current
+      // dimension size.
+      auto currentDimSize = getDimOp(b, loc, tensorOperand, size.index());
+
+      // If the result size of this dimension has so far only hit the
+      // statically-known-to-be-1 case above (i.e., we have not yet assigned a
+      // new Value to `resultShape[resultDim]`), then we have no other dynamic
+      // values to check against, and merely need to record the current
+      // dimension size.
+      if (resultShape[resultDim] == c1) {
+        resultShape[resultDim] = currentDimSize;
+        continue;
+      }
+
+      // We prohibit the size-1 dynamic broadcasting scenario, so just check
+      // for exact equality with the running result size.
+      // This is the check which protects against the undefined behavior of
+      // the generated linalg op in the case of iterating two operands with
+      // dimensions sizes that are expected to match.
+      auto equalToRunning =
+          b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
+                                  resultShape[resultDim], currentDimSize);
+      b.create<cf::AssertOp>(loc, equalToRunning,
+                             "mismatched size for broadcast");
+    }
+    indexingMaps.push_back(AffineMap::get(
+        /*dimCount=*/resultRank, /*symbolCount=*/0, exprs, b.getContext()));
+  }
+
+  SmallVector<StringRef> iteratorTypes(resultRank,
+                                       getParallelIteratorTypeName());
+  // Add the indexing map for the outs init tensor.
+  indexingMaps.push_back(b.getMultiDimIdentityMap(resultRank));
+
+  Value initTensor = b.create<linalg::InitTensorOp>(
+      loc, getAsOpFoldResult(resultShape), resultElementType);
+  return b
+      .create<linalg::GenericOp>(loc,
+                                 /*resultTensorTypes=*/initTensor.getType(),
+                                 /*inputs=*/tensorOperands,
+                                 /*outputs=*/initTensor, indexingMaps,
+                                 iteratorTypes, bodyBuild)
+      .getResult(0);
+}
+
 namespace {
 class ConvertAtenAdaptiveAvgPool2dOp
     : public OpConversionPattern<AtenAdaptiveAvgPool2dOp> {
@@ -1237,7 +1471,7 @@ public:
 // for i in range(0, len(target)):
 //     indi = target[i];
 //     nll_loss_forward[i] = -(input[i][indi]);
-// TODO: `weight` and `reduction` operands are still to be taken care of.
+// TODO: `weight`operand is still to be taken care of.
 namespace {
 class ConvertAtenNllLossForwardOp
     : public OpConversionPattern<AtenNllLossForwardOp> {
@@ -1253,15 +1487,10 @@ public:
     Value target = adaptor.target();
     Value weight = adaptor.weight();
 
-    int64_t reduce_dim;
-    if (!matchPattern(op.reduction(), m_TorchConstantInt(&reduce_dim)))
+    int64_t reduction;
+    if (!matchPattern(op.reduction(), m_TorchConstantInt(&reduction)))
       return rewriter.notifyMatchFailure(op, "dim must be constant");
 
-    // TODO: Handle reduction.
-    if (reduce_dim != 0)
-      return rewriter.notifyMatchFailure(
-          op, "reduction along dimensions is not supported.");
-
     // TODO: Incorporate the weight argument.
     if (!weight.getType().isa<mlir::torch::Torch::NoneType>())
       return rewriter.notifyMatchFailure(
@@ -1273,52 +1502,67 @@ public:
     unsigned inputRank = input.getType().cast<RankedTensorType>().getRank();
     unsigned targetRank = target.getType().cast<RankedTensorType>().getRank();
 
-    // TODO: Cases with targetRank != 1 where `Mean` reduction is required.
-    if (inputRank != 2 || targetRank != 1) {
+    // TODO: Add support for k-dim loss.
+    if (inputRank > 2) {
       return rewriter.notifyMatchFailure(
-          op, "expected input and target to be rank 2 and 1 respectively");
+          op, "expected input and target to be rank <= 2");
     }
     RankedTensorType resultType = getTypeConverter()
                                       ->convertType(op->getResult(0).getType())
                                       .cast<RankedTensorType>();
 
     Type elementType = resultType.getElementType();
 
-    Value targetDim = getDimOp(rewriter, loc, target, 0);
-    Value initTensor0 =
-        createZeroInitTensor(rewriter, loc, {targetDim}, elementType);
     Value zeroVal = rewriter.create<arith::ConstantOp>(
         loc, rewriter.getZeroAttr(elementType));
 
-    SmallVector<AffineExpr> targetExpr;
-    targetExpr.push_back(rewriter.getAffineDimExpr(0));
-    SmallVector<StringRef> iteratorTypes{getParallelIteratorTypeName()};
-    auto indexingMaps = AffineMap::inferFromExprList({targetExpr, targetExpr});
-    Value finalRes =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, initTensor0.getType(), ValueRange{target}, initTensor0,
-                /*indexingMaps=*/indexingMaps,
-                /*iteratorTypes=*/iteratorTypes,
-                [&](OpBuilder &b, Location loc, ValueRange args) {
-                  Value indTarget = rewriter.create<arith::IndexCastOp>(
-                      loc, rewriter.getIndexType(), args[0]);
-                  Value indI = rewriter.create<linalg::IndexOp>(loc, 0);
-
-                  // The final result is given by:
-                  // final_res = (indI == ignoreIndexVal) ? 0 :
-                  // input[indI][IndTarget]
-                  Value cmpEq = rewriter.create<arith::CmpIOp>(
-                      loc, arith::CmpIPredicate::eq, indI, ignoreIndexVal);
-                  Value result = rewriter.create<tensor::ExtractOp>(
-                      loc, input, ValueRange{indI, indTarget});
-                  Value negate =
-                      rewriter.create<arith::NegFOp>(loc, elementType, result);
-                  Value selectFinal = rewriter.create<arith::SelectOp>(
-                      loc, cmpEq, zeroVal, negate);
-                  b.create<linalg::YieldOp>(loc, selectFinal);
-                })
-            .getResult(0);
+    Value finalRes = createElementwiseLinalgGeneric(
+        rewriter, loc, {target}, elementType,
+        [&](OpBuilder &b, Location loc, ValueRange args) {
+          Value targetVal = args[0];
+          Value indTarget = rewriter.create<arith::IndexCastOp>(
+              loc, rewriter.getIndexType(), targetVal);
+
+          // The final result is given by:
+          // final_res = (indTarget == ignoreIndexVal) ? 0 :
+          // input[indI][IndTarget]
+          Value cmpEq = rewriter.create<arith::CmpIOp>(
+              loc, arith::CmpIPredicate::eq, indTarget, ignoreIndexVal);
+          SmallVector<Value> extractionIndices{indTarget};
+          if (inputRank == 2) {
+            Value indI = rewriter.create<linalg::IndexOp>(loc, 0);
+            extractionIndices.insert(extractionIndices.begin(), indI);
+          }
+
+          Value result =
+              rewriter.create<tensor::ExtractOp>(loc, input, extractionIndices);
+
+          Value negate =
+              rewriter.create<arith::NegFOp>(loc, elementType, result);
+          Value selectFinal =
+              rewriter.create<arith::SelectOp>(loc, cmpEq, zeroVal, negate);
+          b.create<linalg::YieldOp>(loc, selectFinal);
+        });
+
+    if (reduction == Reduction::Sum || reduction == Reduction::Mean) {
+      Value numOfElems = getTensorSize(rewriter, loc, finalRes);
+      numOfElems = convertScalarToDtype(rewriter, loc, numOfElems, elementType);
+      llvm::iota_range<int64_t> dimsToReduce(0, targetRank,
+                                             /*inclusive=*/false);
+      DenseSet<int64_t> dimSet(dimsToReduce.begin(), dimsToReduce.end());
+
+      finalRes = createReductionLinalgGeneric(
+          rewriter, loc, finalRes, dimSet, /*keepDim=*/false,
+          /*initElem=*/zeroVal,
+          [&](OpBuilder &b, Location loc, ValueRange args) {
+            Value newVal = args[0];
+            Value accumulator = args[1];
+            if (reduction == Reduction::Mean)
+              newVal = b.create<arith::DivFOp>(loc, newVal, numOfElems);
+            Value result = b.create<arith::AddFOp>(loc, newVal, accumulator);
+            b.create<linalg::YieldOp>(loc, result);
+          });
+    }
+
     // TODO: Update the second result tensor.
     Value weightUpdated =
@@ -1588,66 +1832,6 @@ public:
 };
 } // namespace
 
-// Convert a scalar value to the target type. The scalar value can be an element
-// from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype
-// should be converted builtin types.
-static Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
-                                  Type dtype) {
-  Type scalarType = scalar.getType();
-  if (scalarType == dtype)
-    return scalar;
-
-  // TODO: For the byte(ui8) or char(i8) case, we need the unconverted dtype to
-  // be able to know if we need signed or unsigned conversion.
-  auto isByteOrChar = [](Type type) {
-    if (auto integerTy = type.dyn_cast<mlir::IntegerType>()) {
-      return integerTy.getWidth() == 8;
-    }
-    return false;
-  };
-
-  if (isByteOrChar(scalarType) || isByteOrChar(dtype) ||
-      dtype.isSignlessInteger(1)) {
-    // TODO: Handle to-boolean conversion(from-boolean conversion is handled).
-    mlir::emitError(loc)
-        << "unsupported byte, char or bool type for convertScalarToDtype "
-        << scalarType << "(scalar type) -> " << dtype << "(dtype)";
-    return nullptr;
-  }
-
-  if (auto dtypeFloat = dtype.dyn_cast<mlir::FloatType>()) {
-    if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>()) {
-      if (scalarFloat.getWidth() > dtypeFloat.getWidth())
-        return b.create<arith::TruncFOp>(loc, dtype, scalar);
-      // Only scalarFloat width < dtypeFloat width can reach here.
-      return b.create<arith::ExtFOp>(loc, dtype, scalar);
-    }
-    assert(scalarType.isa<mlir::IntegerType>());
-    if (scalarType.isSignlessInteger(1))
-      return b.create<arith::UIToFPOp>(loc, dtype, scalar);
-    // It's safe to use SIToFPOp because ui8/si8 are the only ones where
-    // unsigned handling is needed, and we checked for that case above.
-    return b.create<arith::SIToFPOp>(loc, dtype, scalar);
-  }
-
-  if (auto dtypeInteger = dtype.dyn_cast<mlir::IntegerType>()) {
-    if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>())
-      return b.create<arith::FPToSIOp>(loc, dtype, scalar);
-    assert(scalarType.isa<mlir::IntegerType>());
-    auto scalarInteger = scalarType.cast<mlir::IntegerType>();
-    if (scalarInteger.getWidth() > dtypeInteger.getWidth())
-      return b.create<arith::TruncIOp>(loc, dtype, scalar);
-    if (scalarType.isSignlessInteger(1))
-      return b.create<arith::ExtUIOp>(loc, dtype, scalar);
-    // Only scalarInteger width < dtypeInteger width can reach here.
-    // It's safe to use ExtSIOp here because ui8/si8 are the only ones where
-    // unsigned handling is needed, and we checked for that case above.
-    return b.create<arith::ExtSIOp>(loc, dtype, scalar);
-  }
-
-  llvm_unreachable("convertScalarToDtype should handle all the types");
-}
-
 static Value createLinalgPayloadCalculationForElementwiseOp(
     OpBuilder &b, Location loc, TypeConverter *converter,
     ValueRange payloadArgs, Operation *op, ArrayRef<Value> operands) {
@@ -2546,99 +2730,9 @@ struct ConvertElementwiseOp : ConversionPattern {
     auto resultType = getTypeConverter()
                           ->convertType(op->getResult(0).getType())
                           .cast<RankedTensorType>();
-    auto resultRank = resultType.getRank();
-
-    auto c1 = rewriter.create<arith::ConstantIndexOp>(loc, /*value=*/1);
-    // The overall error handling strategy here is best viewed by thinking about
-    // what happens for a single result dimension. This loop not structured that
-    // way because it is hard to create the affine maps for each operand unless
-    // we structure the loop to iterate over tensor operands as the outer loop
-    // instead of inner loop. This pseudocode gives better intuition:
-    // ```
-    // for each result dimension:
-    //   for each tensor operand:
-    //     if it doesn't even have high enough rank relative to the result:
-    //       continue
-    //     if it is a static size-1 along this result dimension:
-    //       continue
-    //     if this is the first tensor operand that didn't continue above:
-    //       take its dimension size as the size of the non-broadcasted
-    //       traversal along this dimension (this may include a dynamic size-1,
-    //       **non-broadcasted** traversal!)
-    //     emit error check "if the size does not match the non-broadcasted
-    //     traversal size along this dimension, error"
-    // ```
-    // Initialize the resultShape to all 1's, as a fallback in case
-    // all sizes along that result dimension are statically 1.
-    SmallVector<Value> resultShape(resultRank, c1);
-    SmallVector<AffineMap> indexingMaps;
-    for (Value tensorOperand : tensorOperands) {
-      SmallVector<AffineExpr> exprs;
-      auto type = tensorOperand.getType().cast<RankedTensorType>();
-      for (auto size : llvm::enumerate(type.getShape())) {
-        // If the size is statically known to be 1, we don't want any
-        // error guards to be spuriously emitted, since we are specifically
-        // allowing size-1 broadcasts in this case, as they correspond to a
-        // constant-0 indexing map.
-        if (size.value() == 1) {
-          exprs.push_back(rewriter.getAffineConstantExpr(0));
-          continue;
-        }
-
-        // The rank of this operand might be smaller than the overall rank of
-        // the broadcast. Add an offset to correlate it to the correct
-        // dimension of the result.
-        auto resultDim = size.index() + (resultRank - type.getRank());
-
-        // The generated linalg op will now be iterating along the full size
-        // of this dimension. Record that fact.
-        exprs.push_back(rewriter.getAffineDimExpr(resultDim));
-
-        // Now, we need to ensure that such iteration is not going to trigger
-        // undefined behavior, by doing appropriate checks against the current
-        // dimension size.
-        auto currentDimSize =
-            getDimOp(rewriter, loc, tensorOperand, size.index());
-
-        // If the result size of this dimension has so far only hit the
-        // statically-known-to-be-1 case above (i.e., we have not yet assigned a
-        // new Value to `resultShape[resultDim]`), then we have no other dynamic
-        // values to check against, and merely need to record the current
-        // dimension size.
-        if (resultShape[resultDim] == c1) {
-          resultShape[resultDim] = currentDimSize;
-          continue;
-        }
-
-        // We prohibit the size-1 dynamic broadcasting scenario, so just check
-        // for exact equality with the running result size.
-        // This is the check which protects against the undefined behavior of
-        // the generated linalg op in the case of iterating two operands with
-        // dimensions sizes that are expected to match.
-        auto equalToRunning = rewriter.create<arith::CmpIOp>(
-            loc, arith::CmpIPredicate::eq, resultShape[resultDim],
-            currentDimSize);
-        rewriter.create<cf::AssertOp>(loc, equalToRunning,
-                                      "mismatched size for broadcast");
-      }
-      indexingMaps.push_back(AffineMap::get(
-          /*dimCount=*/resultRank, /*symbolCount=*/0, exprs, getContext()));
-    }
-
-    SmallVector<StringRef> iteratorTypes(resultRank,
-                                         getParallelIteratorTypeName());
-    // Add the indexing map for the outs init tensor.
-    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
-
-    Value initTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, getAsOpFoldResult(resultShape), resultType.getElementType());
     bool hadErrorCreatingPayload = false;
-    auto generic = rewriter.create<linalg::GenericOp>(
-        loc, /*resultTensorTypes=*/initTensor.getType(),
-        /*inputs=*/tensorOperands,
-        /*outputs=*/initTensor,
-        /*indexingMaps=*/indexingMaps,
-        /*iteratorTypes=*/iteratorTypes,
+    Value generic = createElementwiseLinalgGeneric(
+        rewriter, loc, tensorOperands, resultType.getElementType(),
         [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
           Value result = createLinalgPayloadCalculationForElementwiseOp(
               b, loc, getTypeConverter(), payloadArgs, op, operands);
@@ -2650,8 +2744,7 @@ struct ConvertElementwiseOp : ConversionPattern {
         });
     if (hadErrorCreatingPayload)
       return failure();
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
-                                                generic.getResult(0));
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, generic);
     return success();
   }
 };
@@ -2662,104 +2755,19 @@ struct ConvertReductionOp : ConversionPattern {
   ConvertReductionOp(TypeConverter &typeConverter, MLIRContext *context)
       : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
                           context) {}
 
-  // This function is in charge of all the rewriting that will take
-  // place in `matchAndRewrite`. In particular, it converts
-  // the reduce operation into an `linalg.generic` operation
-  // to reduce the input tensor along the dimensions specified in
-  // `dimeSet`.
-  LogicalResult
-  createReductionLinalgGeneric(Operation *op, ArrayRef<Value> operands,
-                               const DenseSet<int64_t> &dimSet, bool keepDim,
-                               ConversionPatternRewriter &rewriter) const {
-    Location loc = op->getLoc();
-    auto tensorOperand = operands[0];
-    auto inputType = tensorOperand.getType().cast<RankedTensorType>();
-    auto resultType = getTypeConverter()
-                          ->convertType(op->getResult(0).getType())
-                          .cast<RankedTensorType>();
-
-    // Get the result shape by obtaining the size of each
-    // dimension in the input tensor that is not getting reduced.
-    // If `keepDim` is true, the rank of the output tensor
-    // is kept the same as the rank of the input tensor, and the
-    // reduced dimensions are set to have size 1.
-    auto c1 = rewriter.create<arith::ConstantIndexOp>(loc, /*value=*/1);
-    SmallVector<Value> resultShape;
-    for (int64_t i = 0; i < inputType.getRank(); i++) {
-      auto currentDimSize =
-          rewriter.create<tensor::DimOp>(loc, tensorOperand, i);
-      if (!dimSet.contains(i))
-        resultShape.push_back(currentDimSize);
-      else if (keepDim)
-        resultShape.push_back(c1);
-    }
-
-    // Create the affine expressions that will be used to
-    // iterate over the input and output tensors.
-    // Here we also set the type of iterator: parallel or reduction.
-    SmallVector<AffineExpr> exprs;
-    SmallVector<StringRef> iteratorTypes;
-    SmallVector<AffineExpr> resultExprs;
-    for (auto size : llvm::enumerate(inputType.getShape())) {
-      exprs.push_back(rewriter.getAffineDimExpr(size.index()));
-
-      if (dimSet.contains(size.index())) {
-        iteratorTypes.push_back(getReductionIteratorTypeName());
-        // If `keepDim`, create affine map to the first element
-        // in the current dimension.
-        if (keepDim)
-          resultExprs.push_back(rewriter.getAffineConstantExpr(0));
-      } else {
-        iteratorTypes.push_back(getParallelIteratorTypeName());
-        resultExprs.push_back(rewriter.getAffineDimExpr(size.index()));
-      }
-    }
-
-    auto indexingMaps = AffineMap::inferFromExprList({exprs, resultExprs});
-    Value initTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, resultShape, resultType.getElementType());
-    Value initValue = createLinalgNeutralElementForReduceOp(
-        rewriter, loc, op, resultType.getElementType());
-    Value accumulator =
-        rewriter.create<linalg::FillOp>(loc, initValue, initTensor)
-            .getResult(0);
-    bool hadErrorCreatingPayload = false;
-    auto generic = rewriter.create<linalg::GenericOp>(
-        loc, /*resultTensorTypes=*/accumulator.getType(),
-        /*inputs=*/tensorOperand,
-        /*outputs=*/accumulator,
-        /*indexingMaps=*/indexingMaps,
-        /*iteratorTypes=*/iteratorTypes,
-        [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
-          Value result = createLinalgPayloadCalculationForReduceOp(
-              b, loc, payloadArgs, op, operands, resultType.getElementType());
-          if (!result) {
-            hadErrorCreatingPayload = true;
-            return;
-          }
-          b.create<linalg::YieldOp>(loc, result);
-        });
-
-    if (hadErrorCreatingPayload)
-      return failure();
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
-                                                generic.getResult(0));
-    return success();
-  }
-
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
 
-    // Every reduce operation must set a value for the `dimSet` and
-    // `keepDim` in accordance with their specification.
+    // Every reduce operation must set a value for the `dimSet`,
+    // `tensorOperand`, and `keepDim` in accordance with their specification.
     DenseSet<int64_t> dimSet;
+    Value tensorOperand;
     bool keepDim = false;
     if (isa<AtenSumOp>(op) || isa<AtenMaxOp>(op)) {
-      auto tensorOperand = operands[0];
+      tensorOperand = operands[0];
       auto inputType = tensorOperand.getType().cast<RankedTensorType>();
 
       // `AtenSumOp` and `AtenMaxOp` reduces along all the dimensions of the
@@ -2767,7 +2775,7 @@ struct ConvertReductionOp : ConversionPattern {
       for (int64_t i = 0; i < inputType.getRank(); i++)
         dimSet.insert(i);
     } else if (auto sumDimIntListOp = dyn_cast<AtenSumDimIntListOp>(op)) {
-      auto tensorOperand = operands[0];
+      tensorOperand = operands[0];
       auto inputType = tensorOperand.getType().cast<RankedTensorType>();
 
       if (!matchPattern(sumDimIntListOp.keepdim(),
@@ -2788,8 +2796,31 @@ struct ConvertReductionOp : ConversionPattern {
     } else {
       return rewriter.notifyMatchFailure(op, "not a supported reduce op");
     }
-    return createReductionLinalgGeneric(op, operands, dimSet, keepDim,
-                                        rewriter);
+
+    Location loc = op->getLoc();
+    auto resultType = getTypeConverter()
+                          ->convertType(op->getResult(0).getType())
+                          .cast<RankedTensorType>();
+    Value initElem = createLinalgNeutralElementForReduceOp(
+        rewriter, loc, op, resultType.getElementType());
+
+    bool hadErrorCreatingPayload = false;
+    Value generic = createReductionLinalgGeneric(
+        rewriter, loc, tensorOperand, dimSet, keepDim, initElem,
+        [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
+          Value result = createLinalgPayloadCalculationForReduceOp(
+              b, loc, payloadArgs, op, operands, resultType.getElementType());
+          if (!result) {
+            hadErrorCreatingPayload = true;
+            return;
+          }
+          b.create<linalg::YieldOp>(loc, result);
+        });
+
+    if (hadErrorCreatingPayload)
+      return failure();
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, generic);
+    return success();
   }
 };
 } // namespace
@@ -4251,14 +4282,8 @@ public:
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    Location loc = op.getLoc();
-    Value self = adaptor.self();
-    SmallVector<Value> sizes(getTensorSizes(rewriter, loc, self));
-    Value productResult =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
-    for (size_t i = 0; i < sizes.size(); i++)
-      productResult =
-          rewriter.create<arith::MulIOp>(loc, productResult, sizes[i]);
-    rewriter.replaceOp(op, castIndexToInt(rewriter, loc, productResult));
+    Value tensorSize = getTensorSize(rewriter, loc, adaptor.self());
+    rewriter.replaceOp(op, tensorSize);
     return success();
   }
 };
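As a closing illustration (again, not part of the commit): the new lowering builds the loss in two `linalg.generic` steps, an elementwise gather-and-negate over `target` followed by an optional sum or mean reduction. The hypothetical helper below is a plain-PyTorch sketch that mirrors that structure; in particular, its mean divides by the number of loss elements, matching the `numOfElems` division in the reduction body above.

import torch

def nll_loss_forward_sketch(x, target, reduction, ignore_index):
    # Elementwise step: gather -x[i, target[i]] (or -x[target] for 1-D input),
    # substituting 0 wherever target == ignore_index.
    if x.dim() == 2:
        gathered = x[torch.arange(x.size(0)), target]
    else:
        gathered = x[target]
    losses = torch.where(target == ignore_index,
                         torch.zeros_like(gathered), -gathered)

    # Reduction step: 0 = none, 1 = mean, 2 = sum; the mean divides by the
    # element count of the unreduced loss tensor.
    if reduction == 0:
        return losses
    if reduction == 1:
        return losses.sum() / losses.numel()
    return losses.sum()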