mirror of https://github.com/llvm/torch-mlir
Add 1D, weight, and reduction support to nll_loss_backward (#729)
This commit adds the following support to the op `nll_loss_backward`:
- `input` tensor can be rank-1
- `weight` parameter
- `reduction` parameter
- `target`, `grad_output`, `total_weight` can be rank-0
- Checks that input tensors are of the expected type
parent 886ad169e5
commit 5620fe030e
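For reference, the op takes `(grad_output, self, target, weight, reduction, ignore_index, total_weight)`. A minimal PyTorch sketch of a call that exercises the newly supported `weight` and `reduction` arguments (illustrative shapes and values, not taken from this commit):

    import torch

    # Illustrative data: 3 samples, 4 classes.
    input = torch.rand(3, 4)             # log-probabilities per sample/class
    target = torch.tensor([2, 3, 0])     # class index for each sample
    weight = torch.rand(4)               # per-class rescaling weights
    grad_output = torch.tensor(1.0)      # rank-0, because a reduction is used
    total_weight = weight[target].sum()  # as accumulated by the forward op

    grad_input = torch.ops.aten.nll_loss_backward(
        grad_output, input, target,
        weight=weight,
        reduction=1,          # 0 = none, 1 = mean, 2 = sum
        ignore_index=-100,
        total_weight=total_weight)

    assert grad_input.shape == input.shape  # gradient w.r.t. `input`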
@@ -28,6 +28,13 @@ using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;

+// Check if a ranked-tensor has the specified element type.
+template <typename elementType> static bool hasElementType(Value tensor) {
+  auto tensorType = tensor.getType().cast<RankedTensorType>();
+  Type tensorElementType = tensorType.getElementType();
+  return tensorElementType.isa<elementType>();
+}
+
 static Value createElementwiseLinalgGeneric(
     OpBuilder &b, Location loc, ValueRange tensorOperands,
     Type resultElementType,
@@ -1441,11 +1448,6 @@ public:
 };
 } // namespace

-// Given `grad_output`, `input`, `target`, `nll_loss_backward` is given by:
-//   for i in range(0, len(input[0])):
-//     for j in range(0, len(input[1])):
-//       nll_loss_backward[i][j] = (j == target[i]) ? -grad_output[i] : 0
-// TODO: `weight` and `reduction` operands are still to be taken care of.
 namespace {
 class ConvertAtenNllLossBackwardOp
     : public OpConversionPattern<AtenNllLossBackwardOp> {
@@ -1456,89 +1458,137 @@ public:
                   ConversionPatternRewriter &rewriter) const override {
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();

     Location loc = op->getLoc();
+    Value gradOutput = adaptor.grad_output();
     Value input = adaptor.self();
     Value target = adaptor.target();
     Value weight = adaptor.weight();
-    Value gradOutput = adaptor.grad_output();
+    bool weightIsNone = op.weight().getType().isa<Torch::NoneType>();
+    Value ignoreIndex = castIntToIndex(rewriter, loc, adaptor.ignore_index());
+    Value totalWeight = adaptor.total_weight();
+
+    auto inputType = input.getType().cast<RankedTensorType>();
+    int inputRank = inputType.getRank();
+    auto gradOutputType = gradOutput.getType().cast<RankedTensorType>();
+    Type resultElementType = gradOutputType.getElementType();

     int64_t reduction;
     if (!matchPattern(op.reduction(), m_TorchConstantInt(&reduction)))
       return rewriter.notifyMatchFailure(op, "dim must be constant");

-    // TODO: Handle reduction.
-    if (reduction != torch_upstream::Reduction::None)
+    if (!hasElementType<mlir::FloatType>(gradOutput) ||
+        !hasElementType<mlir::FloatType>(totalWeight) ||
+        (!weightIsNone && !hasElementType<mlir::FloatType>(weight))) {
       return rewriter.notifyMatchFailure(
-          op, "reduction along dimensions is not supported.");
-
-    // TODO: Incorporate the weight argument.
-    if (!weight.getType().isa<Torch::NoneType>())
-      return rewriter.notifyMatchFailure(
-          op, "Unimplemented, the weight operand is not incorporated.");
-
-    Value ignoreIndex = adaptor.ignore_index();
-    Value ignoreIndexVal = castIntToIndex(rewriter, loc, ignoreIndex);
-
-    unsigned inputRank = input.getType().cast<RankedTensorType>().getRank();
-    unsigned targetRank = target.getType().cast<RankedTensorType>().getRank();
-
-    // TODO: Cases with targetRank != 1 where `Mean` or `Sum` reduction is
-    // required.
-    if (inputRank != 2 || targetRank != 1) {
-      return rewriter.notifyMatchFailure(
-          op, "expected input and target to be rank 2 and 1 respectively");
+          op, "`gradOutput`, `weight`, and `totalWeight` must be tensors of "
+              "type float");
     }
+
+    if (!hasElementType<mlir::IntegerType>(target)) {
+      return rewriter.notifyMatchFailure(
+          op, "`target` must be a tensor of integer type");
+    }
+
+    auto outputSize = getTensorSizes(rewriter, loc, input);
+    Value gradInputTensor =
+        createZeroInitTensor(rewriter, loc, outputSize, resultElementType);
+
+    auto getAffineMapForSingleElementTensor = [&](Value tensor) {
+      auto tensorType = tensor.getType().cast<RankedTensorType>();
+      SmallVector<AffineExpr> affineExprs(tensorType.getRank(),
+                                          rewriter.getAffineConstantExpr(0));
+      return AffineMap::get(inputRank, /*symbolCount=*/0, affineExprs,
+                            op->getContext());
+    };
+
+    AffineMap gradOutMap = AffineMap::get(inputRank, /*symbolCount=*/0,
+                                          rewriter.getAffineDimExpr(0));
+    if (reduction != torch_upstream::Reduction::None || inputRank == 1)
+      gradOutMap = getAffineMapForSingleElementTensor(gradOutput);
+    AffineMap targetMap = AffineMap::get(inputRank, /*symbolCount=*/0,
+                                         rewriter.getAffineDimExpr(0));
+    if (inputRank == 1)
+      targetMap = getAffineMapForSingleElementTensor(target);
+    AffineMap totalWeightMap = getAffineMapForSingleElementTensor(totalWeight);
+    AffineMap resultMap = rewriter.getMultiDimIdentityMap(inputRank);
+
+    SmallVector<AffineMap> indexingMaps{gradOutMap, targetMap, totalWeightMap,
+                                        resultMap};
+    SmallVector<StringRef> iteratorTypes(inputRank,
+                                         getParallelIteratorTypeName());
+
+    // The code generation is equivalent to the following pseudo-code:
+    //
+    // for batch_index in range(input.size(0)):
+    //     for class_index in range(input.size(1)):
+    //         target_elem = target[batch_index]
+    //
+    //         if reduction == None:
+    //             grad_out_elem = grad_output[batch_index]
+    //         else:
+    //             grad_out_elem = grad_output[0]
+    //
+    //         if reduction == Mean:
+    //             total_weight_elem = total_weight[0]
+    //             grad_out_elem /= total_weight_elem
+    //
+    //         weight_elem = weight[target_elem] if weight != None else 1
+    //
+    //         if target_elem == class_index and target_elem != ignore_index:
+    //             grad_input_elem = -weight_elem * grad_out_elem
+    //         else:
+    //             grad_input_elem = 0
+    //         grad_input[batch_index, class_index] = grad_input_elem
+    //
+    // NOTE: When there is no batch dimension, `batch_index` is effectively
+    // always zero.
+    Value gradInput =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, gradInputTensor.getType(),
+                ValueRange{gradOutput, target, totalWeight}, gradInputTensor,
+                indexingMaps, iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value gradOutElem = args[0];
+                  Value targetElem = castIntToIndex(b, loc, args[1]);
+                  Value totalWeightElem = args[2];
+                  Value classIndex =
+                      b.create<linalg::IndexOp>(loc, inputRank - 1);
+
+                  if (reduction == torch_upstream::Reduction::Mean) {
+                    gradOutElem = b.create<arith::DivFOp>(loc, gradOutElem,
+                                                          totalWeightElem);
+                  }
+
+                  Value negGradOutElem =
+                      b.create<arith::NegFOp>(loc, gradOutElem);
+                  Value weightElem = getConstant(b, loc, 1, resultElementType);
+                  if (!weightIsNone) {
+                    weightElem =
+                        b.create<tensor::ExtractOp>(loc, weight, targetElem);
+                  }
+                  Value weightedNegGradOutElem =
+                      b.create<arith::MulFOp>(loc, weightElem, negGradOutElem);
+
+                  Value targetNeqClassIndex = b.create<arith::CmpIOp>(
+                      loc, arith::CmpIPredicate::ne, targetElem, classIndex);
+                  Value targetEqIgnoreIndex = b.create<arith::CmpIOp>(
+                      loc, arith::CmpIPredicate::eq, targetElem, ignoreIndex);
+                  Value gradInputIsZero = b.create<arith::OrIOp>(
+                      loc, targetNeqClassIndex, targetEqIgnoreIndex);
+
+                  Value zero = getConstant(b, loc, 0, resultElementType);
+                  Value gradInElem = b.create<arith::SelectOp>(
+                      loc, gradInputIsZero, zero, weightedNegGradOutElem);
+                  b.create<linalg::YieldOp>(loc, gradInElem);
+                })
+            ->getResult(0);
+
     RankedTensorType resultType = getTypeConverter()
                                       ->convertType(op->getResult(0).getType())
                                       .cast<RankedTensorType>();

-    Type elementType = resultType.getElementType();
-
-    // Given there is no reduction `grad_input` size is equal to `input` size.
-    auto outputSize = getTensorSizes(rewriter, loc, input);
-    Value initTensor0 =
-        createZeroInitTensor(rewriter, loc, outputSize, elementType);
-    Value zeroVal = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getZeroAttr(elementType));
-
-    SmallVector<AffineExpr> targetExpr{rewriter.getAffineDimExpr(0)};
-    SmallVector<AffineExpr> resultExpr{rewriter.getAffineDimExpr(0),
-                                       rewriter.getAffineDimExpr(1)};
-    SmallVector<StringRef> iteratorTypes{getParallelIteratorTypeName(),
-                                         getParallelIteratorTypeName()};
-    auto indexingMaps =
-        AffineMap::inferFromExprList({targetExpr, targetExpr, resultExpr});
-    Value finalRes =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, initTensor0.getType(), ValueRange{target, gradOutput},
-                initTensor0,
-                /*indexingMaps=*/indexingMaps,
-                /*iteratorTypes=*/iteratorTypes,
-                [&](OpBuilder &b, Location loc, ValueRange args) {
-                  Value indTarget = rewriter.create<arith::IndexCastOp>(
-                      loc, rewriter.getIndexType(), args[0]);
-                  Value indJ = rewriter.create<linalg::IndexOp>(loc, 1);
-
-                  // The final result is given by:
-                  // grad_input[i][j] = (j == target[i]) ? -grad_output[i] : 0
-                  Value cmpEq = rewriter.create<arith::CmpIOp>(
-                      loc, arith::CmpIPredicate::eq, indJ, indTarget);
-
-                  // The target index shouldn't be equal to `ignoreIndex`.
-                  Value cmpNe = rewriter.create<arith::CmpIOp>(
-                      loc, arith::CmpIPredicate::ne, ignoreIndexVal, indTarget);
-                  Value finalPredicate =
-                      rewriter.create<arith::AndIOp>(loc, cmpEq, cmpNe);
-                  Value negate =
-                      rewriter.create<arith::NegFOp>(loc, elementType, args[1]);
-                  Value selectFinal = rewriter.create<arith::SelectOp>(
-                      loc, finalPredicate, negate, zeroVal);
-                  b.create<linalg::YieldOp>(loc, selectFinal);
-                })
-            .getResult(0);
-
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, finalRes);
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, gradInput);
     return success();
   }
 };
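As an editorial aid, the following Python sketch mirrors the pseudo-code in the lowering comment above for the rank-2 `input` / rank-1 `target` case. The helper name `nll_loss_backward_ref` and all values are illustrative, not part of the commit; the final comparison against eager PyTorch is what the e2e tests below do in spirit.

    import torch

    def nll_loss_backward_ref(grad_output, input, target, weight,
                              reduction, ignore_index, total_weight):
        # reduction: 0 = none, 1 = mean, 2 = sum
        grad_input = torch.zeros_like(input)
        for batch_index in range(input.size(0)):
            for class_index in range(input.size(1)):
                target_elem = int(target[batch_index])

                # Without reduction there is one grad_output element per
                # sample; otherwise grad_output holds a single element.
                if reduction == 0:
                    grad_out_elem = grad_output[batch_index]
                else:
                    grad_out_elem = grad_output.reshape(-1)[0]

                if reduction == 1:
                    grad_out_elem = grad_out_elem / total_weight.reshape(-1)[0]

                weight_elem = weight[target_elem] if weight is not None else 1.0

                if target_elem == class_index and target_elem != ignore_index:
                    grad_input[batch_index, class_index] = \
                        -weight_elem * grad_out_elem
        return grad_input

    # Sanity check against eager PyTorch (illustrative values).
    input = torch.rand(3, 4)
    target = torch.tensor([2, 3, 0])
    weight = torch.rand(4)
    total_weight = weight[target].sum()
    grad_output = torch.tensor(1.0)

    expected = torch.ops.aten.nll_loss_backward(
        grad_output, input, target, weight=weight, reduction=1,
        ignore_index=-100, total_weight=total_weight)
    actual = nll_loss_backward_ref(grad_output, input, target, weight,
                                   reduction=1, ignore_index=-100,
                                   total_weight=total_weight)
    assert torch.allclose(expected, actual)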
@@ -162,6 +162,37 @@ def NllLossModuleBackward_basic(module, tu: TestUtils):
                    torch.tensor(3.))


+class NllLossModule_backwardWeight(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1], torch.float32, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, weight, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=weight,
+                                                reduction=0,
+                                                ignore_index=10,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backwardWeight())
+def NllLossModuleBackwardWeight_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]),
+                   torch.rand(4), torch.tensor(3.))
+
+
 class NllLossModule_backward_ignore_index(torch.nn.Module):

     def __init__(self):
@@ -190,3 +221,298 @@ class NllLossModule_backward_ignore_index(torch.nn.Module):
 def NllLossModuleBackward_ignore_index(module, tu: TestUtils):
     module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]),
                    torch.tensor(3.))
+
+
+class NllLossModule_backwardMean(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=None,
+                                                reduction=1,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backwardMean())
+def NllLossModuleBackwardMean_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
+                   torch.tensor(3.))
+
+
+class NllLossModule_backwardMeanWeight(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1], torch.float32, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, weight, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=weight,
+                                                reduction=1,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backwardMeanWeight())
+def NllLossModuleBackwardMeanWeight_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
+                   torch.rand(4), torch.tensor(3.))
+
+
+class NllLossModule_backwardSum(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=None,
+                                                reduction=2,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backwardSum())
+def NllLossModuleBackwardSum_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
+                   torch.tensor(3.))
+
+
+class NllLossModule_backwardSumWeight(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1], torch.float32, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, weight, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=weight,
+                                                reduction=2,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backwardSumWeight())
+def NllLossModuleBackwardSumWeight_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
+                   torch.rand(4), torch.tensor(3.))
+
+
+class NllLossModule_backward1D(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=None,
+                                                reduction=0,
+                                                ignore_index=10,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backward1D())
+def NllLossModuleBackward1D_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
+                   torch.tensor(3.))
+
+
+class NllLossModule_backward1DWeight(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1], torch.float32, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, weight, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=weight,
+                                                reduction=0,
+                                                ignore_index=10,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backward1DWeight())
+def NllLossModuleBackward1DWeight_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
+                   torch.rand(3), torch.tensor(3.))
+
+
+class NllLossModule_backward1DMean(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=None,
+                                                reduction=1,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backward1DMean())
+def NllLossModuleBackward1DMean_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
+                   torch.tensor(3.))
+
+
+class NllLossModule_backward1DMeanWeight(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1], torch.float32, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, weight, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=weight,
+                                                reduction=1,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backward1DMeanWeight())
+def NllLossModuleBackward1DMeanWeight_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
+                   torch.rand(3), torch.tensor(3.))
+
+
+class NllLossModule_backward1DSum(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=None,
+                                                reduction=2,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backward1DSum())
+def NllLossModuleBackward1DSum_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
+                   torch.tensor(3.))
+
+
+class NllLossModule_backward1DSumWeight(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1], torch.float32, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, weight, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=weight,
+                                                reduction=2,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backward1DSumWeight())
+def NllLossModuleBackward1DSumWeight_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
+                   torch.rand(3), torch.tensor(3.))