mirror of https://github.com/llvm/torch-mlir
[Torch] support AtenNllLossForwardOp decomposition (#3833)
parent 70e089802a
commit 2f33f31724

@@ -9154,6 +9154,12 @@ public:
      return rewriter.notifyMatchFailure(
          op, "Unimplemented: unranked target tensor");
    unsigned targetRank = maybeRank.value();
    Value reduction = op.getReduction();
    int64_t reductionInt;
    if (!matchPattern(reduction, m_TorchConstantInt(&reductionInt))) {
      return rewriter.notifyMatchFailure(op,
                                         "reduction should be a constant int!");
    }

    // When the input is 2-d i.e. of the form [minibatch, C] and target is 1-d
    // of the form [minibatch] the cross entropy loss decomposes to the
@@ -9184,10 +9190,19 @@ public:
        loc, rewriter.getI64IntegerAttr(1));
    Value logSoftmax = rewriter.create<AtenLogSoftmaxIntOp>(
        loc, self.getType(), self, dim, /*dtype=*/noneVal);

    Type secondType;
    if (reductionInt == 0) {
      secondType = target.getType();
    } else {
      auto targetType = dyn_cast<BaseTensorType>(target.getType());
      secondType = targetType.getWithSizesAndDtype({}, targetType.getDtype());
    }

    Value nllLoss =
        rewriter
            .create<AtenNllLossForwardOp>(
-               loc, op.getType(), target.getType(), logSoftmax, target,
+               loc, op.getType(), secondType, logSoftmax, target,
                op.getWeight(), op.getReduction(), op.getIgnoreIndex())
            ->getResult(0);
    rewriter.replaceOp(op, nllLoss);
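For context, the two hunks above make DecomposeAtenCrossEntropyLossOp lower aten.cross_entropy_loss into aten.log_softmax followed by aten.nll_loss_forward, and adjust the type passed for the op's second result (total_weight) based on the reduction mode. The PyTorch-level identity the rewrite relies on can be sanity-checked with a small standalone snippet; this is an illustration only, not part of the commit, and assumes a stock PyTorch install:

# cross_entropy(x, t) == nll_loss_forward(log_softmax(x, dim=1), t)[0]
import torch
import torch.nn.functional as F

x = torch.randn(2, 3)
t = torch.randint(0, 3, (2,))

reference = F.cross_entropy(x, t)  # reduction defaults to 'mean'
# aten::nll_loss_forward returns (output, total_weight); the rewrite above only
# uses result 0, matching ->getResult(0).
decomposed, total_weight = torch.ops.aten.nll_loss_forward(
    F.log_softmax(x, dim=1), t, None, 1, -100
)
assert torch.allclose(reference, decomposed)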
@@ -9196,6 +9211,235 @@ public:
};
} // namespace

namespace {
// Decompose aten::nll_loss_forward according to:
// torch/_decomp/decompositions.py and
// https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html.
// The (self, target) shapes can be:
// 1. [N, C] and [N],
// or
// 2. [C] and [].
// The weight must be None or a 1-d tensor whose numel matches the number of
// classes.
class DecomposeAtenNllLossForwardOp
    : public OpRewritePattern<AtenNllLossForwardOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenNllLossForwardOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto ctx = op.getContext();

    auto self = op.getSelf();
    auto target = op.getTarget();

    auto selfType = dyn_cast<BaseTensorType>(self.getType());
    auto targetType = dyn_cast<BaseTensorType>(target.getType());

    // constraints.
    if (!selfType.hasSizes() || !targetType.hasSizes()) {
      return rewriter.notifyMatchFailure(
          op, "require self and target having sizes!");
    }

    if (!selfType.hasDtype() || !targetType.hasDtype()) {
      return rewriter.notifyMatchFailure(
          op, "require self and target having dtype!");
    }

    auto selfSizes = selfType.getSizes();
    auto targetSizes = targetType.getSizes();
    int64_t selfRank = selfSizes.size();
    int64_t targetRank = targetSizes.size();
    if (selfRank <= 0 or selfRank > 2) {
      return rewriter.notifyMatchFailure(op, "input tensor should be 1D or 2D");
    }
    if (targetRank > 1) {
      return rewriter.notifyMatchFailure(op,
                                         "target tensor should be 0D or 1D!");
    }

    if (selfRank != 1 or targetRank != 0) {
      if (!(selfSizes[0] == kUnknownSize and targetSizes[0] == kUnknownSize) and
          selfSizes[0] != targetSizes[0]) {
        return rewriter.notifyMatchFailure(
            op,
            "input tensor and target tensor should have the same batch size!");
      }
    }

    int64_t numClasses = selfSizes.back();
    auto weight = op.getWeight();
    auto weightT = weight.getType();
    if (!isa<Torch::NoneType>(weightT) && numClasses != kUnknownSize) {
      auto weightType = dyn_cast<BaseTensorType>(weightT);
      if (weightType.areAllSizesKnown()) {
        auto weightSizes = weightType.getSizes();
        int64_t weightNumel = 1;
        for (size_t i = 0; i < weightSizes.size(); i++) {
          weightNumel *= weightSizes[i];
        }
        if (weightNumel != numClasses) {
          return rewriter.notifyMatchFailure(
              op, "weight tensor should be defined either for all classes or "
                  "no classes!");
        }
      }
    }

    Value reductionValue = op.getReduction();
    int64_t reduction;
    if (!matchPattern(reductionValue, m_TorchConstantInt(&reduction))) {
      return rewriter.notifyMatchFailure(op,
                                         "reduction should be a constant int!");
    }

    // Decomposition.
    uint64_t channelDim = 1;
    if (selfRank < 2) {
      channelDim = 0;
    }
    Value channelDimValue = rewriter.create<ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(channelDim));

    auto ignoreIndex = op.getIgnoreIndex();
    Value w;
    if (!isa<Torch::NoneType>(weightT)) {
      if (selfRank > 1) {
        auto weightType = dyn_cast<BaseTensorType>(weightT);
        auto weightSizes = weightType.getSizes();
        SmallVector<int64_t> newShapeList(selfRank, 1);
        newShapeList[channelDim] = weightSizes[0];
        SmallVector<Value> newShapeListValue;
        for (size_t i = 0; i < newShapeList.size(); ++i) {
          newShapeListValue.push_back(rewriter.create<ConstantIntOp>(
              loc, rewriter.getI64IntegerAttr(newShapeList[i])));
        }
        Value newShape = rewriter.create<PrimListConstructOp>(
            loc,
            rewriter.getType<Torch::ListType>(
                rewriter.getType<Torch::IntType>()),
            newShapeListValue);
        auto newType = weightType.getWithSizesAndDtype(newShapeList,
                                                       weightType.getDtype());
        w = rewriter.create<AtenViewOp>(loc, newType, weight, newShape);
      } else {
        w = weight;
      }

      self = rewriter.create<AtenMulTensorOp>(loc, self.getType(), self, w);
    }

    SmallVector<int64_t> targetDimSizes(targetSizes);
    Value zero =
        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
    auto condType =
        ValueTensorType::get(ctx, targetDimSizes, rewriter.getI1Type());
    auto unequalCond =
        rewriter.create<AtenNeScalarOp>(loc, condType, target, ignoreIndex);
    auto zeroTensorType =
        ValueTensorType::get(ctx, {}, rewriter.getIntegerType(64, true));
    Value zeroTensor =
        rewriter.create<PrimNumToTensorScalarOp>(loc, zeroTensorType, zero);
    auto safeTarget = rewriter.create<AtenWhereSelfOp>(
        loc, target.getType(), unequalCond, target, zeroTensor);

    SmallVector<int64_t> safeTargetShape;
    for (size_t i = 0; i < targetSizes.size(); ++i) {
      if (channelDim == i) {
        safeTargetShape.push_back(1);
      }
      safeTargetShape.push_back(targetSizes[i]);
    }
    if (channelDim == safeTargetShape.size()) {
      safeTargetShape.push_back(1);
    }

    auto gatherType =
        ValueTensorType::get(ctx, safeTargetShape, targetType.getDtype());
    auto safeTarget_ = rewriter.create<AtenUnsqueezeOp>(
        loc, gatherType, safeTarget, channelDimValue);
    auto falseValue =
        rewriter.create<ConstantBoolOp>(loc, rewriter.getBoolAttr(false));
    auto none = rewriter.create<ConstantNoneOp>(loc);
    auto _gather = rewriter.create<AtenGatherOp>(
        loc, ValueTensorType::get(ctx, safeTargetShape, selfType.getDtype()),
        self, channelDimValue, safeTarget_, falseValue);
    Value gather = rewriter.create<AtenNegOp>(loc, _gather.getType(), _gather);
    auto unequalCondType = cast<BaseTensorType>(unequalCond.getType());
    auto result = rewriter.create<AtenWhereSelfOp>(
        loc,
        unequalCondType.getWithSizesAndDtype(unequalCondType.getSizes(),
                                             selfType.getDtype()),
        unequalCond,
        rewriter.create<AtenSqueezeDimOp>(
            loc, ValueTensorType::get(ctx, targetSizes, selfType.getDtype()),
            gather, channelDimValue),
        zeroTensor);

    Value totalWeight;
    if (reduction == 0 and selfRank > 1) {
      auto zeroFloat =
          rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
      Value twSize = rewriter.create<PrimListConstructOp>(
          loc,
          rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
          ValueRange({}));

      totalWeight = rewriter.create<AtenNewFullOp>(
          loc, op.getType(1), self, twSize, zeroFloat, none, none, none, none);
      rewriter.replaceOp(op, {result, totalWeight});

      return success();
    }

    if (!isa<Torch::NoneType>(weightT)) {
      auto wType = cast<BaseTensorType>(w.getType());
      auto newWType = wType.getWithSizesAndDtype(selfSizes, wType.getDtype());
      SmallVector<Value> selfSizesValue;
      for (size_t i = 0; i < selfSizes.size(); ++i) {
        selfSizesValue.push_back(rewriter.create<ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(selfSizes[i])));
      }
      auto wSize = rewriter.create<PrimListConstructOp>(
          loc,
          rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
          selfSizesValue);
      w = rewriter.create<AtenExpandOp>(loc, newWType, w, wSize, falseValue);
      auto wSumGather = rewriter.create<AtenGatherOp>(
          loc, ValueTensorType::get(ctx, safeTargetShape, wType.getDtype()), w,
          channelDimValue, safeTarget_, falseValue);
      auto wSumSq = rewriter.create<AtenSqueezeDimOp>(
          loc, ValueTensorType::get(ctx, targetSizes, wType.getDtype()),
          wSumGather, channelDimValue);
      auto wSum = rewriter.create<AtenWhereSelfOp>(
          loc,
          ValueTensorType::get(ctx, unequalCondType.getSizes(),
                               wType.getDtype()),
          unequalCond, wSumSq, zeroTensor);

      totalWeight = rewriter.create<AtenSumOp>(loc, op.getType(1), wSum, none);
    } else {
      totalWeight =
          rewriter.create<AtenSumOp>(loc, op.getType(1), unequalCond, none);
    }

    auto resultSum =
        rewriter.create<AtenSumOp>(loc, op.getType(0), result, none);
    if (reduction == 1) {
      auto resultMean = rewriter.create<AtenDivTensorOp>(
          loc, op.getType(0), resultSum, totalWeight);
      rewriter.replaceOp(op, {resultMean, totalWeight});

      return success();
    }

    rewriter.replaceOp(op, {resultSum, totalWeight});
    return success();
  }
};
} // namespace

namespace {
class DecomposeAtenBinaryCrossEntropyWithLogitsOp
    : public OpRewritePattern<AtenBinaryCrossEntropyWithLogitsOp> {
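The new pattern follows the reference decomposition in torch/_decomp/decompositions.py: targets equal to ignore_index are masked out, the (optionally weight-scaled) input is negated and gathered at the target indices along the channel dimension, and the result is either returned per element (reduction 0), summed (reduction 2), or summed and divided by the accumulated per-sample weights (reduction 1). Below is a pure-PyTorch sketch of the same math for the [N, C] input / [N] target case; it is an illustration under those assumptions, and the helper name nll_loss_forward_ref is ours, not part of the commit:

import torch

def nll_loss_forward_ref(inp, target, weight=None, reduction=1, ignore_index=-100):
    # Sketch of the emitted graph for inp of shape [N, C], target of shape [N].
    n, c = inp.shape
    w = weight if weight is not None else torch.ones(c, dtype=inp.dtype)
    scaled = inp * w.view(1, c)                       # AtenViewOp + AtenMulTensorOp
    keep = target != ignore_index                     # AtenNeScalarOp
    safe_target = torch.where(keep, target, torch.zeros_like(target))  # AtenWhereSelfOp
    gathered = -scaled.gather(1, safe_target.unsqueeze(1)).squeeze(1)  # Unsqueeze/Gather/Neg/SqueezeDim
    per_elem = torch.where(keep, gathered, torch.zeros_like(gathered))
    if reduction == 0:  # 'none': per-element losses plus a zero total_weight
        return per_elem, inp.new_full([], 0.0)        # AtenNewFullOp
    total_weight = torch.where(keep, w[safe_target], torch.zeros_like(gathered)).sum()
    loss = per_elem.sum()                             # AtenSumOp
    if reduction == 1:  # 'mean'
        loss = loss / total_weight                    # AtenDivTensorOp
    return loss, total_weight

x, t = torch.randn(4, 5), torch.randint(0, 5, (4,))
out, tw = torch.ops.aten.nll_loss_forward(x, t, None, 1, -100)
ref_out, ref_tw = nll_loss_forward_ref(x, t)
assert torch.allclose(out, ref_out) and torch.allclose(tw, ref_tw)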
@@ -10437,6 +10681,7 @@ public:
    addPatternIfTargetOpIsIllegal<DecomposeAtenExp2Op>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenNllLossForwardOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenBinaryCrossEntropyWithLogitsOp>(
        patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);

@@ -3022,9 +3022,13 @@ ONNX_XFAIL_SET = {
    "NllLossModuleBackward_ignore_index",
    "NllLossModule_1D_basic",
    "NllLossModule_basic",
    "NllLossStaticModule_basic",
    "NllLossStaticModule_weight_basic",
    "NllLossModule_ignore_index_out_of_bounds_basic",
    "NllLossModule_mean_basic",
    "NllLossStaticModule_mean_basic",
    "NllLossModule_sum_basic",
    "NllLossStaticModule_sum_basic",
    "NormScalarComplexModule_basic",
    "NormScalarModule_basic",
    "NormScalarOptDimKeepDimComplexModule_basic",

@@ -36,6 +36,57 @@ def NllLossModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))


class NllLossStaticModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([2, 3], torch.float32, True),
            ([2], torch.int64, True),
        ]
    )
    # Here the 2nd index is ignored.
    def forward(self, x, y):
        return torch.ops.aten.nll_loss_forward(
            x, target=y, weight=None, reduction=0, ignore_index=2
        )


@register_test_case(module_factory=lambda: NllLossStaticModule())
def NllLossStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))


class NllLossStaticModule_weight(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([2, 3], torch.float32, True),
            ([2], torch.int64, True),
            ([3], torch.float32, True),
        ]
    )
    # Here the 2nd index is ignored.
    def forward(self, x, y, z):
        return torch.ops.aten.nll_loss_forward(
            x, target=y, weight=z, reduction=2, ignore_index=2
        )


@register_test_case(module_factory=lambda: NllLossStaticModule_weight())
def NllLossStaticModule_weight_basic(module, tu: TestUtils):
    module.forward(
        tu.rand(2, 3), tu.randint(2, low=0, high=3), torch.tensor([0.3, 0.3, 0.4])
    )


class NllLossModule_mean(torch.nn.Module):
    def __init__(self):
        super().__init__()
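In these tests the integer reduction argument follows the aten convention (0 = 'none', 1 = 'mean', 2 = 'sum'), and ignore_index=2 means samples whose target equals 2 contribute neither to the loss nor to total_weight. A quick eager-mode check of the weighted sum variant, with hand-picked inputs that are not part of the test suite:

import torch

x = torch.tensor([[0.1, 0.2, 0.7], [0.5, 0.3, 0.2]])
t = torch.tensor([0, 2])  # the second sample hits ignore_index=2
w = torch.tensor([0.3, 0.3, 0.4])
out, total_weight = torch.ops.aten.nll_loss_forward(x, t, w, 2, 2)
# Only the first sample contributes: out == -w[0] * x[0, 0], total_weight == w[0].
assert torch.allclose(out, torch.tensor(-0.03))
assert torch.allclose(total_weight, torch.tensor(0.3))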
@@ -60,6 +111,30 @@ def NllLossModule_mean_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))


class NllLossStaticModule_mean(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([2, 3], torch.float32, True),
            ([2], torch.int64, True),
        ]
    )
    # Here the 2nd index is ignored.
    def forward(self, x, y):
        return torch.ops.aten.nll_loss_forward(
            x, target=y, weight=None, reduction=1, ignore_index=2
        )


@register_test_case(module_factory=lambda: NllLossStaticModule_mean())
def NllLossStaticModule_mean_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))


class NllLossModule_sum(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -84,6 +159,30 @@ def NllLossModule_sum_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))


class NllLossStaticModule_sum(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([2, 3], torch.float32, True),
            ([2], torch.int64, True),
        ]
    )
    # Here the 2nd index is ignored.
    def forward(self, x, y):
        return torch.ops.aten.nll_loss_forward(
            x, target=y, weight=None, reduction=2, ignore_index=2
        )


@register_test_case(module_factory=lambda: NllLossStaticModule_sum())
def NllLossStaticModule_sum_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))


class NllLossModule_1D(torch.nn.Module):
    def __init__(self):
        super().__init__()