[Torch] support AtenNllLossForwardOp decomposition (#3833)

yyp0 2024-11-06 11:34:48 +08:00 committed by GitHub
parent 70e089802a
commit 2f33f31724
3 changed files with 349 additions and 1 deletion


@@ -9154,6 +9154,12 @@ public:
return rewriter.notifyMatchFailure(
op, "Unimplemented: unranked target tensor");
unsigned targetRank = maybeRank.value();
Value reduction = op.getReduction();
int64_t reductionInt;
if (!matchPattern(reduction, m_TorchConstantInt(&reductionInt))) {
return rewriter.notifyMatchFailure(op,
"reduction should be a constant int!");
}
// When the input is 2-d, i.e. of the form [minibatch, C], and the target is
// 1-d of the form [minibatch], the cross entropy loss decomposes to the
@@ -9184,10 +9190,19 @@ public:
loc, rewriter.getI64IntegerAttr(1));
Value logSoftmax = rewriter.create<AtenLogSoftmaxIntOp>(
loc, self.getType(), self, dim, /*dtype=*/noneVal);
Type secondType;
if (reductionInt == 0) {
secondType = target.getType();
} else {
auto targetType = dyn_cast<BaseTensorType>(target.getType());
secondType = targetType.getWithSizesAndDtype({}, targetType.getDtype());
}
Value nllLoss =
rewriter
.create<AtenNllLossForwardOp>(
- loc, op.getType(), target.getType(), logSoftmax, target,
+ loc, op.getType(), secondType, logSoftmax, target,
op.getWeight(), op.getReduction(), op.getIgnoreIndex())
->getResult(0);
rewriter.replaceOp(op, nllLoss);
@@ -9196,6 +9211,235 @@ public:
};
} // namespace
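The hunk above changes only the second result type handed to the created op: with reduction == 0 the loss keeps the target's shape, while mean/sum reductions yield a rank-0 tensor, which is what the new secondType switch encodes. As a rough Python sketch of the lowering this pattern performs (assumed shapes: self [N, C], target [N]; the helper name is ours, not part of the patch):

import torch

def cross_entropy_decomposed(self, target, weight=None, reduction=1, ignore_index=-100):
    # Step 1 of the pattern: log-softmax along the class dimension.
    log_probs = torch.ops.aten.log_softmax(self, 1, None)
    # Step 2: aten.nll_loss_forward returns (output, total_weight); the
    # pattern keeps only result 0. The output is target-shaped for
    # reduction == 0 and 0-d otherwise -- hence the secondType fix above.
    return torch.ops.aten.nll_loss_forward(
        log_probs, target, weight, reduction, ignore_index)[0]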
namespace {
// Decompose aten::nll_loss_forward according to:
// torch/_decomp/decompositions.py and
// https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html.
// The (self, target) shapes can be:
// 1. [N, C] and [N],
// or
// 2. [C] and [].
// The weight must be None or a 1-d tensor whose numel matches the number of
// classes.
class DecomposeAtenNllLossForwardOp
: public OpRewritePattern<AtenNllLossForwardOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenNllLossForwardOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto ctx = op.getContext();
auto self = op.getSelf();
auto target = op.getTarget();
auto selfType = dyn_cast<BaseTensorType>(self.getType());
auto targetType = dyn_cast<BaseTensorType>(target.getType());
// Constraints.
if (!selfType.hasSizes() || !targetType.hasSizes()) {
return rewriter.notifyMatchFailure(
op, "requires self and target to have sizes!");
}
if (!selfType.hasDtype() || !targetType.hasDtype()) {
return rewriter.notifyMatchFailure(
op, "requires self and target to have dtype!");
}
auto selfSizes = selfType.getSizes();
auto targetSizes = targetType.getSizes();
int64_t selfRank = selfSizes.size();
int64_t targetRank = targetSizes.size();
if (selfRank <= 0 || selfRank > 2) {
return rewriter.notifyMatchFailure(op, "input tensor should be 1D or 2D");
}
if (targetRank > 1) {
return rewriter.notifyMatchFailure(op,
"target tensor should be 0D or 1D!");
}
if (selfRank != 1 || targetRank != 0) {
if (!(selfSizes[0] == kUnknownSize && targetSizes[0] == kUnknownSize) &&
selfSizes[0] != targetSizes[0]) {
return rewriter.notifyMatchFailure(
op,
"input tensor and target tensor should have the same batch size!");
}
}
int64_t numClasses = selfSizes.back();
auto weight = op.getWeight();
auto weightT = weight.getType();
if (!isa<Torch::NoneType>(weightT) && numClasses != kUnknownSize) {
auto weightType = dyn_cast<BaseTensorType>(weightT);
if (weightType.areAllSizesKnown()) {
auto weightSizes = weightType.getSizes();
int64_t weightNumel = 1;
for (size_t i = 0; i < weightSizes.size(); i++) {
weightNumel *= weightSizes[i];
}
if (weightNumel != numClasses) {
return rewriter.notifyMatchFailure(
op, "weight tensor should be defined either for all classes or "
"no classes!");
}
}
}
Value reductionValue = op.getReduction();
int64_t reduction;
if (!matchPattern(reductionValue, m_TorchConstantInt(&reduction))) {
return rewriter.notifyMatchFailure(op,
"reduction should be a constant int!");
}
// Decomposition.
uint64_t channelDim = 1;
if (selfRank < 2) {
channelDim = 0;
}
Value channelDimValue = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(channelDim));
auto ignoreIndex = op.getIgnoreIndex();
Value w;
if (!isa<Torch::NoneType>(weightT)) {
if (selfRank > 1) {
auto weightType = dyn_cast<BaseTensorType>(weightT);
auto weightSizes = weightType.getSizes();
SmallVector<int64_t> newShapeList(selfRank, 1);
newShapeList[channelDim] = weightSizes[0];
SmallVector<Value> newShapeListValue;
for (size_t i = 0; i < newShapeList.size(); ++i) {
newShapeListValue.push_back(rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(newShapeList[i])));
}
Value newShape = rewriter.create<PrimListConstructOp>(
loc,
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::IntType>()),
newShapeListValue);
auto newType = weightType.getWithSizesAndDtype(newShapeList,
weightType.getDtype());
w = rewriter.create<AtenViewOp>(loc, newType, weight, newShape);
} else {
w = weight;
}
self = rewriter.create<AtenMulTensorOp>(loc, self.getType(), self, w);
}
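// Mask out ignored entries: unequalCond = (target != ignoreIndex), and a
// "safe" copy of target replaces ignored entries with class 0 so the gather
// below never indexes out of bounds; the where afterwards zeroes them out.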
SmallVector<int64_t> targetDimSizes(targetSizes);
Value zero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
auto condType =
ValueTensorType::get(ctx, targetDimSizes, rewriter.getI1Type());
auto unequalCond =
rewriter.create<AtenNeScalarOp>(loc, condType, target, ignoreIndex);
auto zeroTensorType =
ValueTensorType::get(ctx, {}, rewriter.getIntegerType(64, true));
Value zeroTensor =
rewriter.create<PrimNumToTensorScalarOp>(loc, zeroTensorType, zero);
auto safeTarget = rewriter.create<AtenWhereSelfOp>(
loc, target.getType(), unequalCond, target, zeroTensor);
SmallVector<int64_t> safeTargetShape;
for (size_t i = 0; i < targetSizes.size(); ++i) {
if (channelDim == i) {
safeTargetShape.push_back(1);
}
safeTargetShape.push_back(targetSizes[i]);
}
if (channelDim == safeTargetShape.size()) {
safeTargetShape.push_back(1);
}
auto gatherType =
ValueTensorType::get(ctx, safeTargetShape, targetType.getDtype());
auto safeTarget_ = rewriter.create<AtenUnsqueezeOp>(
loc, gatherType, safeTarget, channelDimValue);
auto falseValue =
rewriter.create<ConstantBoolOp>(loc, rewriter.getBoolAttr(false));
auto none = rewriter.create<ConstantNoneOp>(loc);
auto _gather = rewriter.create<AtenGatherOp>(
loc, ValueTensorType::get(ctx, safeTargetShape, selfType.getDtype()),
self, channelDimValue, safeTarget_, falseValue);
Value gather = rewriter.create<AtenNegOp>(loc, _gather.getType(), _gather);
auto unequalCondType = cast<BaseTensorType>(unequalCond.getType());
auto result = rewriter.create<AtenWhereSelfOp>(
loc,
unequalCondType.getWithSizesAndDtype(unequalCondType.getSizes(),
selfType.getDtype()),
unequalCond,
rewriter.create<AtenSqueezeDimOp>(
loc, ValueTensorType::get(ctx, targetSizes, selfType.getDtype()),
gather, channelDimValue),
zeroTensor);
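// totalWeight: a 0-d zero for the unreduced 2-D case below; otherwise the
// summed class weights (or the plain count) of the non-ignored targets.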
Value totalWeight;
if (reduction == 0 && selfRank > 1) {
auto zeroFloat =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
Value twSize = rewriter.create<PrimListConstructOp>(
loc,
rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
ValueRange({}));
totalWeight = rewriter.create<AtenNewFullOp>(
loc, op.getType(1), self, twSize, zeroFloat, none, none, none, none);
rewriter.replaceOp(op, {result, totalWeight});
return success();
}
if (!isa<Torch::NoneType>(weightT)) {
auto wType = cast<BaseTensorType>(w.getType());
auto newWType = wType.getWithSizesAndDtype(selfSizes, wType.getDtype());
SmallVector<Value> selfSizesValue;
for (size_t i = 0; i < selfSizes.size(); ++i) {
selfSizesValue.push_back(rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(selfSizes[i])));
}
auto wSize = rewriter.create<PrimListConstructOp>(
loc,
rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
selfSizesValue);
w = rewriter.create<AtenExpandOp>(loc, newWType, w, wSize, falseValue);
auto wSumGather = rewriter.create<AtenGatherOp>(
loc, ValueTensorType::get(ctx, safeTargetShape, wType.getDtype()), w,
channelDimValue, safeTarget_, falseValue);
auto wSumSq = rewriter.create<AtenSqueezeDimOp>(
loc, ValueTensorType::get(ctx, targetSizes, wType.getDtype()),
wSumGather, channelDimValue);
auto wSum = rewriter.create<AtenWhereSelfOp>(
loc,
ValueTensorType::get(ctx, unequalCondType.getSizes(),
wType.getDtype()),
unequalCond, wSumSq, zeroTensor);
totalWeight = rewriter.create<AtenSumOp>(loc, op.getType(1), wSum, none);
} else {
totalWeight =
rewriter.create<AtenSumOp>(loc, op.getType(1), unequalCond, none);
}
auto resultSum =
rewriter.create<AtenSumOp>(loc, op.getType(0), result, none);
if (reduction == 1) {
auto resultMean = rewriter.create<AtenDivTensorOp>(
loc, op.getType(0), resultSum, totalWeight);
rewriter.replaceOp(op, {resultMean, totalWeight});
return success();
}
rewriter.replaceOp(op, {resultSum, totalWeight});
return success();
}
};
} // namespace
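For orientation, the control flow of this pattern mirrors the Python decomposition it cites. A condensed sketch, paraphrased from torch/_decomp/decompositions.py (aten reduction encoding: 0 = none, 1 = mean, 2 = sum; names and exact phrasing are ours, not the literal upstream code):

import torch

def nll_loss_forward_decomposed(self, target, weight, reduction, ignore_index):
    n_dims = self.dim()
    channel_dim = 1 if n_dims >= 2 else 0
    if weight is not None:
        if n_dims > 1:
            # Reshape the 1-D weight so it broadcasts along the class
            # dimension (AtenViewOp above), then fold it into the input.
            shape = [1] * n_dims
            shape[channel_dim] = weight.shape[0]
            w = weight.view(shape)
        else:
            w = weight
        self = self * w  # AtenMulTensorOp
    keep = target != ignore_index  # unequalCond
    safe_target = torch.where(keep, target, torch.zeros_like(target))
    safe_target_ = safe_target.unsqueeze(channel_dim)
    # Pick -log_prob of the labelled class; ignored positions become 0.
    picked = torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim)
    result = torch.where(keep, -picked, torch.zeros_like(picked))
    if reduction == 0 and n_dims > 1:
        # Unreduced 2-D case: total_weight is a 0-d zero (AtenNewFullOp).
        return result, self.new_full((), 0.0)
    if weight is not None:
        # total_weight = sum of the class weights at non-ignored targets.
        w = w.expand(self.shape)
        wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim)
        total_weight = torch.where(keep, wsum, torch.zeros_like(wsum)).sum()
    else:
        total_weight = keep.sum().to(self.dtype)  # count of kept targets
    if reduction == 1:  # mean
        return result.sum() / total_weight, total_weight
    # reduction == 2 (sum), or the 1-D "none" case (a 0-d result is
    # unchanged by sum).
    return result.sum(), total_weight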
namespace {
class DecomposeAtenBinaryCrossEntropyWithLogitsOp
: public OpRewritePattern<AtenBinaryCrossEntropyWithLogitsOp> {
@@ -10437,6 +10681,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenExp2Op>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNllLossForwardOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBinaryCrossEntropyWithLogitsOp>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);


@@ -3022,9 +3022,13 @@ ONNX_XFAIL_SET = {
"NllLossModuleBackward_ignore_index",
"NllLossModule_1D_basic",
"NllLossModule_basic",
"NllLossStaticModule_basic",
"NllLossStaticModule_weight_basic",
"NllLossModule_ignore_index_out_of_bounds_basic",
"NllLossModule_mean_basic",
"NllLossStaticModule_mean_basic",
"NllLossModule_sum_basic",
"NllLossStaticModule_sum_basic",
"NormScalarComplexModule_basic",
"NormScalarModule_basic",
"NormScalarOptDimKeepDimComplexModule_basic",


@@ -36,6 +36,57 @@ def NllLossModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
class NllLossStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([2, 3], torch.float32, True),
([2], torch.int64, True),
]
)
# Targets equal to 2 are ignored (ignore_index=2).
def forward(self, x, y):
return torch.ops.aten.nll_loss_forward(
x, target=y, weight=None, reduction=0, ignore_index=2
)
@register_test_case(module_factory=lambda: NllLossStaticModule())
def NllLossStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
class NllLossStaticModule_weight(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([2, 3], torch.float32, True),
([2], torch.int64, True),
([3], torch.float32, True),
]
)
# Targets equal to 2 are ignored (ignore_index=2).
def forward(self, x, y, z):
return torch.ops.aten.nll_loss_forward(
x, target=y, weight=z, reduction=2, ignore_index=2
)
@register_test_case(module_factory=lambda: NllLossStaticModule_weight())
def NllLossStaticModule_weight_basic(module, tu: TestUtils):
module.forward(
tu.rand(2, 3), tu.randint(2, low=0, high=3), torch.tensor([0.3, 0.3, 0.4])
)
class NllLossModule_mean(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -60,6 +111,30 @@ def NllLossModule_mean_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
class NllLossStaticModule_mean(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([2, 3], torch.float32, True),
([2], torch.int64, True),
]
)
# Targets equal to 2 are ignored (ignore_index=2).
def forward(self, x, y):
return torch.ops.aten.nll_loss_forward(
x, target=y, weight=None, reduction=1, ignore_index=2
)
@register_test_case(module_factory=lambda: NllLossStaticModule_mean())
def NllLossStaticModule_mean_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
class NllLossModule_sum(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -84,6 +159,30 @@ def NllLossModule_sum_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
class NllLossStaticModule_sum(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([2, 3], torch.float32, True),
([2], torch.int64, True),
]
)
# Targets equal to 2 are ignored (ignore_index=2).
def forward(self, x, y):
return torch.ops.aten.nll_loss_forward(
x, target=y, weight=None, reduction=2, ignore_index=2
)
@register_test_case(module_factory=lambda: NllLossStaticModule_sum())
def NllLossStaticModule_sum_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
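A quick local sanity check for these static cases (not part of the commit; it reuses the nll_loss_forward_decomposed sketch from earlier and keeps targets below ignore_index=2 so the mean path stays finite):

import torch

x = torch.rand(2, 3)
y = torch.randint(0, 2, (2,))  # avoid the ignored class index 2
w = torch.tensor([0.3, 0.3, 0.4])
for weight in (None, w):
    for reduction in (0, 1, 2):
        out, tw = torch.ops.aten.nll_loss_forward(x, y, weight, reduction, 2)
        ref, ref_tw = nll_loss_forward_decomposed(x, y, weight, reduction, 2)
        assert torch.allclose(out, ref) and torch.allclose(tw, ref_tw)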
class NllLossModule_1D(torch.nn.Module):
def __init__(self):
super().__init__()