mirror of https://github.com/llvm/torch-mlir
[Torch] support AtenNllLossForwardOp decomposition (#3833)
parent 70e089802a
commit 2f33f31724
@@ -9154,6 +9154,12 @@ public:
       return rewriter.notifyMatchFailure(
           op, "Unimplemented: unranked target tensor");
     unsigned targetRank = maybeRank.value();
+    Value reduction = op.getReduction();
+    int64_t reductionInt;
+    if (!matchPattern(reduction, m_TorchConstantInt(&reductionInt))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "reduction should be a constant int!");
+    }
 
     // When the input is 2-d i.e. of the form [minibatch, C] and target is 1-d
     // of the form [minibatch] the cross entropy loss decomposes to the
@@ -9184,10 +9190,19 @@ public:
         loc, rewriter.getI64IntegerAttr(1));
     Value logSoftmax = rewriter.create<AtenLogSoftmaxIntOp>(
         loc, self.getType(), self, dim, /*dtype=*/noneVal);
+
+    Type secondType;
+    if (reductionInt == 0) {
+      secondType = target.getType();
+    } else {
+      auto targetType = dyn_cast<BaseTensorType>(target.getType());
+      secondType = targetType.getWithSizesAndDtype({}, targetType.getDtype());
+    }
+
     Value nllLoss =
         rewriter
             .create<AtenNllLossForwardOp>(
-                loc, op.getType(), target.getType(), logSoftmax, target,
+                loc, op.getType(), secondType, logSoftmax, target,
                 op.getWeight(), op.getReduction(), op.getIgnoreIndex())
             ->getResult(0);
     rewriter.replaceOp(op, nllLoss);
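The two hunks above tighten DecomposeAtenCrossEntropyLossOp: the reduction argument must now be a constant integer, and the type of nll_loss_forward's second result (total_weight) is chosen per reduction (it stays at the target's type only for reduction=0, otherwise a rank-0 tensor). The eager-mode sketch below is not part of the patch; it only illustrates the identity the pattern relies on, namely that cross_entropy(x, t) equals nll_loss applied to log_softmax(x, dim=1) for every reduction mode.

# Eager-mode sketch (not part of the patch) of the identity the pattern uses:
# cross_entropy(x, t) == nll_loss(log_softmax(x, dim=1), t) for every reduction.
import torch
import torch.nn.functional as F

x = torch.randn(4, 5)            # [minibatch, C]
t = torch.randint(0, 5, (4,))    # [minibatch]
log_probs = torch.ops.aten.log_softmax(x, 1)
for reduction in (0, 1, 2):      # none, mean, sum
    out, total_weight = torch.ops.aten.nll_loss_forward(
        log_probs, t, None, reduction, -100)
    ref = F.cross_entropy(x, t, reduction={0: "none", 1: "mean", 2: "sum"}[reduction])
    assert torch.allclose(out, ref)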
@@ -9196,6 +9211,235 @@ public:
 };
 } // namespace
 
+namespace {
+// Decompose aten::nll_loss_forward according to:
+// torch/_decomp/decompositions.py and
+// https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html.
+// The (self, target) shapes can be:
+// 1. [N, C] and [N],
+// or
+// 2. [C] and [].
+// The weight must be None or a 1-D tensor whose numel matches the number of
+// classes.
+class DecomposeAtenNllLossForwardOp
+    : public OpRewritePattern<AtenNllLossForwardOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenNllLossForwardOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto ctx = op.getContext();
+
+    auto self = op.getSelf();
+    auto target = op.getTarget();
+
+    auto selfType = dyn_cast<BaseTensorType>(self.getType());
+    auto targetType = dyn_cast<BaseTensorType>(target.getType());
+
+    // Constraints.
+    if (!selfType.hasSizes() || !targetType.hasSizes()) {
+      return rewriter.notifyMatchFailure(
+          op, "require self and target to have sizes!");
+    }
+
+    if (!selfType.hasDtype() || !targetType.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "require self and target to have dtype!");
+    }
+
+    auto selfSizes = selfType.getSizes();
+    auto targetSizes = targetType.getSizes();
+    int64_t selfRank = selfSizes.size();
+    int64_t targetRank = targetSizes.size();
+    if (selfRank <= 0 || selfRank > 2) {
+      return rewriter.notifyMatchFailure(op, "input tensor should be 1D or 2D");
+    }
+    if (targetRank > 1) {
+      return rewriter.notifyMatchFailure(op,
+                                         "target tensor should be 0D or 1D!");
+    }
+
+    if (selfRank != 1 || targetRank != 0) {
+      if (!(selfSizes[0] == kUnknownSize && targetSizes[0] == kUnknownSize) &&
+          selfSizes[0] != targetSizes[0]) {
+        return rewriter.notifyMatchFailure(
+            op,
+            "input tensor and target tensor should have the same batch size!");
+      }
+    }
+
+    int64_t numClasses = selfSizes.back();
+    auto weight = op.getWeight();
+    auto weightT = weight.getType();
+    if (!isa<Torch::NoneType>(weightT) && numClasses != kUnknownSize) {
+      auto weightType = dyn_cast<BaseTensorType>(weightT);
+      if (weightType.areAllSizesKnown()) {
+        auto weightSizes = weightType.getSizes();
+        int64_t weightNumel = 1;
+        for (size_t i = 0; i < weightSizes.size(); i++) {
+          weightNumel *= weightSizes[i];
+        }
+        if (weightNumel != numClasses) {
+          return rewriter.notifyMatchFailure(
+              op, "weight tensor should be defined either for all classes or "
+                  "no classes!");
+        }
+      }
+    }
+
+    Value reductionValue = op.getReduction();
+    int64_t reduction;
+    if (!matchPattern(reductionValue, m_TorchConstantInt(&reduction))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "reduction should be a constant int!");
+    }
+
+    // Decomposition.
+    uint64_t channelDim = 1;
+    if (selfRank < 2) {
+      channelDim = 0;
+    }
+    Value channelDimValue = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(channelDim));
+
+    auto ignoreIndex = op.getIgnoreIndex();
+    Value w;
+    if (!isa<Torch::NoneType>(weightT)) {
+      if (selfRank > 1) {
+        auto weightType = dyn_cast<BaseTensorType>(weightT);
+        auto weightSizes = weightType.getSizes();
+        SmallVector<int64_t> newShapeList(selfRank, 1);
+        newShapeList[channelDim] = weightSizes[0];
+        SmallVector<Value> newShapeListValue;
+        for (size_t i = 0; i < newShapeList.size(); ++i) {
+          newShapeListValue.push_back(rewriter.create<ConstantIntOp>(
+              loc, rewriter.getI64IntegerAttr(newShapeList[i])));
+        }
+        Value newShape = rewriter.create<PrimListConstructOp>(
+            loc,
+            rewriter.getType<Torch::ListType>(
+                rewriter.getType<Torch::IntType>()),
+            newShapeListValue);
+        auto newType = weightType.getWithSizesAndDtype(newShapeList,
+                                                       weightType.getDtype());
+        w = rewriter.create<AtenViewOp>(loc, newType, weight, newShape);
+      } else {
+        w = weight;
+      }
+
+      self = rewriter.create<AtenMulTensorOp>(loc, self.getType(), self, w);
+    }
+
+    SmallVector<int64_t> targetDimSizes(targetSizes);
+    Value zero =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
+    auto condType =
+        ValueTensorType::get(ctx, targetDimSizes, rewriter.getI1Type());
+    auto unequalCond =
+        rewriter.create<AtenNeScalarOp>(loc, condType, target, ignoreIndex);
+    auto zeroTensorType =
+        ValueTensorType::get(ctx, {}, rewriter.getIntegerType(64, true));
+    Value zeroTensor =
+        rewriter.create<PrimNumToTensorScalarOp>(loc, zeroTensorType, zero);
+    auto safeTarget = rewriter.create<AtenWhereSelfOp>(
+        loc, target.getType(), unequalCond, target, zeroTensor);
+
+    SmallVector<int64_t> safeTargetShape;
+    for (size_t i = 0; i < targetSizes.size(); ++i) {
+      if (channelDim == i) {
+        safeTargetShape.push_back(1);
+      }
+      safeTargetShape.push_back(targetSizes[i]);
+    }
+    if (channelDim == safeTargetShape.size()) {
+      safeTargetShape.push_back(1);
+    }
+
+    auto gatherType =
+        ValueTensorType::get(ctx, safeTargetShape, targetType.getDtype());
+    auto safeTarget_ = rewriter.create<AtenUnsqueezeOp>(
+        loc, gatherType, safeTarget, channelDimValue);
+    auto falseValue =
+        rewriter.create<ConstantBoolOp>(loc, rewriter.getBoolAttr(false));
+    auto none = rewriter.create<ConstantNoneOp>(loc);
+    auto _gather = rewriter.create<AtenGatherOp>(
+        loc, ValueTensorType::get(ctx, safeTargetShape, selfType.getDtype()),
+        self, channelDimValue, safeTarget_, falseValue);
+    Value gather = rewriter.create<AtenNegOp>(loc, _gather.getType(), _gather);
+    auto unequalCondType = cast<BaseTensorType>(unequalCond.getType());
+    auto result = rewriter.create<AtenWhereSelfOp>(
+        loc,
+        unequalCondType.getWithSizesAndDtype(unequalCondType.getSizes(),
+                                             selfType.getDtype()),
+        unequalCond,
+        rewriter.create<AtenSqueezeDimOp>(
+            loc, ValueTensorType::get(ctx, targetSizes, selfType.getDtype()),
+            gather, channelDimValue),
+        zeroTensor);
+
+    Value totalWeight;
+    if (reduction == 0 && selfRank > 1) {
+      auto zeroFloat =
+          rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
+      Value twSize = rewriter.create<PrimListConstructOp>(
+          loc,
+          rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
+          ValueRange({}));
+
+      totalWeight = rewriter.create<AtenNewFullOp>(
+          loc, op.getType(1), self, twSize, zeroFloat, none, none, none, none);
+      rewriter.replaceOp(op, {result, totalWeight});
+
+      return success();
+    }
+
+    if (!isa<Torch::NoneType>(weightT)) {
+      auto wType = cast<BaseTensorType>(w.getType());
+      auto newWType = wType.getWithSizesAndDtype(selfSizes, wType.getDtype());
+      SmallVector<Value> selfSizesValue;
+      for (size_t i = 0; i < selfSizes.size(); ++i) {
+        selfSizesValue.push_back(rewriter.create<ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(selfSizes[i])));
+      }
+      auto wSize = rewriter.create<PrimListConstructOp>(
+          loc,
+          rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
+          selfSizesValue);
+      w = rewriter.create<AtenExpandOp>(loc, newWType, w, wSize, falseValue);
+      auto wSumGather = rewriter.create<AtenGatherOp>(
+          loc, ValueTensorType::get(ctx, safeTargetShape, wType.getDtype()), w,
+          channelDimValue, safeTarget_, falseValue);
+      auto wSumSq = rewriter.create<AtenSqueezeDimOp>(
+          loc, ValueTensorType::get(ctx, targetSizes, wType.getDtype()),
+          wSumGather, channelDimValue);
+      auto wSum = rewriter.create<AtenWhereSelfOp>(
+          loc,
+          ValueTensorType::get(ctx, unequalCondType.getSizes(),
+                               wType.getDtype()),
+          unequalCond, wSumSq, zeroTensor);
+
+      totalWeight = rewriter.create<AtenSumOp>(loc, op.getType(1), wSum, none);
+    } else {
+      totalWeight =
+          rewriter.create<AtenSumOp>(loc, op.getType(1), unequalCond, none);
+    }
+
+    auto resultSum =
+        rewriter.create<AtenSumOp>(loc, op.getType(0), result, none);
+    if (reduction == 1) {
+      auto resultMean = rewriter.create<AtenDivTensorOp>(
+          loc, op.getType(0), resultSum, totalWeight);
+      rewriter.replaceOp(op, {resultMean, totalWeight});
+
+      return success();
+    }
+
+    rewriter.replaceOp(op, {resultSum, totalWeight});
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenBinaryCrossEntropyWithLogitsOp
     : public OpRewritePattern<AtenBinaryCrossEntropyWithLogitsOp> {
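The new DecomposeAtenNllLossForwardOp pattern follows the reference decomposition cited in its comment: targets equal to ignore_index are masked to a safe value, the (optionally weighted) log-probabilities are gathered at the target class along the channel dimension and negated, and total_weight is computed for the mean and sum reductions. The eager-PyTorch sketch below restates that math for readability; the function name and dtype handling are illustrative assumptions, not code from this patch.

# A hedged eager-PyTorch sketch of the decomposition implemented by the pattern
# above; `log_probs` plays the role of `self` (already log-probabilities), as
# in aten.nll_loss_forward.
import torch

def nll_loss_forward_decomposed(log_probs, target, weight, reduction, ignore_index):
    channel_dim = 1 if log_probs.dim() >= 2 else 0
    if weight is not None:
        shape = [1] * log_probs.dim()
        shape[channel_dim] = weight.numel()
        w = weight.view(shape)                    # AtenViewOp in the pattern
        log_probs = log_probs * w                 # AtenMulTensorOp
    keep = target != ignore_index                 # unequalCond (AtenNeScalarOp)
    safe_target = torch.where(keep, target, torch.zeros_like(target))
    idx = safe_target.unsqueeze(channel_dim)      # AtenUnsqueezeOp
    picked = (-log_probs.gather(channel_dim, idx)).squeeze(channel_dim)
    result = torch.where(keep, picked, picked.new_zeros(()))
    if reduction == 0 and log_probs.dim() > 1:    # "none" on a batched input
        return result, log_probs.new_full((), 0.0)
    if weight is not None:                        # total_weight sums picked weights
        w_picked = w.expand(log_probs.shape).gather(channel_dim, idx)
        w_picked = w_picked.squeeze(channel_dim)
        total_weight = torch.where(keep, w_picked, w_picked.new_zeros(())).sum()
    else:                                         # ... or just counts kept targets
        total_weight = keep.sum().to(log_probs.dtype)
    result_sum = result.sum()
    if reduction == 1:                            # mean
        return result_sum / total_weight, total_weight
    return result_sum, total_weight               # sum (and the 1-D "none" case)

For reduction=1 this is expected to agree with torch.ops.aten.nll_loss_forward up to numerical precision; the e2e tests added below check the same behaviour through the compiled path.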
@@ -10437,6 +10681,7 @@ public:
     addPatternIfTargetOpIsIllegal<DecomposeAtenExp2Op>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenNllLossForwardOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenBinaryCrossEntropyWithLogitsOp>(
         patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
@@ -3022,9 +3022,13 @@ ONNX_XFAIL_SET = {
     "NllLossModuleBackward_ignore_index",
     "NllLossModule_1D_basic",
     "NllLossModule_basic",
+    "NllLossStaticModule_basic",
+    "NllLossStaticModule_weight_basic",
     "NllLossModule_ignore_index_out_of_bounds_basic",
     "NllLossModule_mean_basic",
+    "NllLossStaticModule_mean_basic",
     "NllLossModule_sum_basic",
+    "NllLossStaticModule_sum_basic",
     "NormScalarComplexModule_basic",
     "NormScalarModule_basic",
     "NormScalarOptDimKeepDimComplexModule_basic",
@@ -36,6 +36,57 @@ def NllLossModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
 
 
+class NllLossStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 3], torch.float32, True),
+            ([2], torch.int64, True),
+        ]
+    )
+    # Here the 2nd index is ignored.
+    def forward(self, x, y):
+        return torch.ops.aten.nll_loss_forward(
+            x, target=y, weight=None, reduction=0, ignore_index=2
+        )
+
+
+@register_test_case(module_factory=lambda: NllLossStaticModule())
+def NllLossStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
+
+
+class NllLossStaticModule_weight(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 3], torch.float32, True),
+            ([2], torch.int64, True),
+            ([3], torch.float32, True),
+        ]
+    )
+    # Here the 2nd index is ignored.
+    def forward(self, x, y, z):
+        return torch.ops.aten.nll_loss_forward(
+            x, target=y, weight=z, reduction=2, ignore_index=2
+        )
+
+
+@register_test_case(module_factory=lambda: NllLossStaticModule_weight())
+def NllLossStaticModule_weight_basic(module, tu: TestUtils):
+    module.forward(
+        tu.rand(2, 3), tu.randint(2, low=0, high=3), torch.tensor([0.3, 0.3, 0.4])
+    )
+
+
 class NllLossModule_mean(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -60,6 +111,30 @@ def NllLossModule_mean_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
 
 
+class NllLossStaticModule_mean(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 3], torch.float32, True),
+            ([2], torch.int64, True),
+        ]
+    )
+    # Here the 2nd index is ignored.
+    def forward(self, x, y):
+        return torch.ops.aten.nll_loss_forward(
+            x, target=y, weight=None, reduction=1, ignore_index=2
+        )
+
+
+@register_test_case(module_factory=lambda: NllLossStaticModule_mean())
+def NllLossStaticModule_mean_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
+
+
 class NllLossModule_sum(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -84,6 +159,30 @@ def NllLossModule_sum_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
 
 
+class NllLossStaticModule_sum(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 3], torch.float32, True),
+            ([2], torch.int64, True),
+        ]
+    )
+    # Here the 2nd index is ignored.
+    def forward(self, x, y):
+        return torch.ops.aten.nll_loss_forward(
+            x, target=y, weight=None, reduction=2, ignore_index=2
+        )
+
+
+@register_test_case(module_factory=lambda: NllLossStaticModule_sum())
+def NllLossStaticModule_sum_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
+
+
 class NllLossModule_1D(torch.nn.Module):
     def __init__(self):
         super().__init__()
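The added NllLossStatic* modules exercise aten.nll_loss_forward with static shapes for reduction 0, 1, and 2, with and without a weight tensor. As a reminder of what the op (and hence the decomposition) must produce, the small eager-mode snippet below, which is not part of the test suite, shows its two results and the effect of ignore_index.

# Eager-mode illustration (not part of the test suite) of the two results of
# aten.nll_loss_forward and of the ignore_index behaviour the tests exercise.
import torch

x = torch.rand(2, 3)
y = torch.tensor([0, 2])         # the second target equals ignore_index=2
out, total_weight = torch.ops.aten.nll_loss_forward(x, y, None, 1, 2)
# With reduction=1 (mean), `out` averages only over non-ignored samples and
# `total_weight` counts them (1.0 here); ignored samples contribute to neither.
print(out, total_weight)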