mirror of https://github.com/llvm/torch-mlir
Add 1D, weight, and reduction support to nll_loss_backward (#729)
This commit adds the following support to the op `nll_loss_backward`:
- `input` tensor can be rank-1
- `weight` parameter
- `reduction` parameter
- `target`, `grad_output`, `total_weight` can be rank-0
- Checks that input tensors are of the expected type
parent 886ad169e5
commit 5620fe030e
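For orientation, here is a minimal Python sketch (not part of the patch; the name `nll_loss_backward_ref` and its exact shape handling are illustrative assumptions) of the semantics the new lowering targets, covering the `weight`, `reduction`, and `ignore_index` behavior listed above:

import torch

def nll_loss_backward_ref(grad_output, input, target, weight, reduction,
                          ignore_index, total_weight):
    # reduction follows torch_upstream::Reduction: 0 = none, 1 = mean, 2 = sum.
    input2d = input.reshape(1, -1) if input.dim() == 1 else input
    target1d = target.reshape(-1)
    grad_out = grad_output.reshape(-1)
    grad_input = torch.zeros_like(input2d)
    for i in range(input2d.size(0)):
        t = int(target1d[i])
        if t == ignore_index:
            continue
        # With no reduction there is one grad_output element per example;
        # with mean/sum reduction there is a single element.
        g = grad_out[i] if reduction == 0 and grad_out.numel() > 1 else grad_out[0]
        if reduction == 1:  # mean: scale by 1 / total_weight
            g = g / total_weight
        w = weight[t] if weight is not None else 1.0
        grad_input[i, t] = -w * g
    return grad_input.reshape(input.shape)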
@@ -28,6 +28,13 @@ using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
 
+// Check if a ranked-tensor has the specified element type.
+template <typename elementType> static bool hasElementType(Value tensor) {
+  auto tensorType = tensor.getType().cast<RankedTensorType>();
+  Type tensorElementType = tensorType.getElementType();
+  return tensorElementType.isa<elementType>();
+}
+
 static Value createElementwiseLinalgGeneric(
     OpBuilder &b, Location loc, ValueRange tensorOperands,
     Type resultElementType,
@@ -1441,11 +1448,6 @@ public:
 };
 } // namespace
 
-// Given `grad_output`, `input`, `target`, `nll_loss_backward` is given by:
-//   for i in range(0, len(input[0])):
-//     for j in range(0, len(input[1])):
-//       nll_loss_backward[i][j] = (j == target[i]) ? -grad_output[i] : 0
-// TODO: `weight` and `reduction` operands are still to be taken care of.
 namespace {
 class ConvertAtenNllLossBackwardOp
     : public OpConversionPattern<AtenNllLossBackwardOp> {
@@ -1456,89 +1458,137 @@ public:
                   ConversionPatternRewriter &rewriter) const override {
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
 
     Location loc = op->getLoc();
+    Value gradOutput = adaptor.grad_output();
     Value input = adaptor.self();
     Value target = adaptor.target();
     Value weight = adaptor.weight();
-    Value gradOutput = adaptor.grad_output();
+    bool weightIsNone = op.weight().getType().isa<Torch::NoneType>();
+    Value ignoreIndex = castIntToIndex(rewriter, loc, adaptor.ignore_index());
+    Value totalWeight = adaptor.total_weight();
+
+    auto inputType = input.getType().cast<RankedTensorType>();
+    int inputRank = inputType.getRank();
+    auto gradOutputType = gradOutput.getType().cast<RankedTensorType>();
+    Type resultElementType = gradOutputType.getElementType();
 
     int64_t reduction;
     if (!matchPattern(op.reduction(), m_TorchConstantInt(&reduction)))
       return rewriter.notifyMatchFailure(op, "dim must be constant");
 
-    // TODO: Handle reduction.
-    if (reduction != torch_upstream::Reduction::None)
-      return rewriter.notifyMatchFailure(
-          op, "reduction along dimensions is not supported.");
-
-    // TODO: Incorporate the weight argument.
-    if (!weight.getType().isa<Torch::NoneType>())
-      return rewriter.notifyMatchFailure(
-          op, "Unimplemented, the weight operand is not incorporated.");
-
-    Value ignoreIndex = adaptor.ignore_index();
-    Value ignoreIndexVal = castIntToIndex(rewriter, loc, ignoreIndex);
-
-    unsigned inputRank = input.getType().cast<RankedTensorType>().getRank();
-    unsigned targetRank = target.getType().cast<RankedTensorType>().getRank();
-
-    // TODO: Cases with targetRank != 1 where `Mean` or `Sum` reduction is
-    // required.
-    if (inputRank != 2 || targetRank != 1) {
-      return rewriter.notifyMatchFailure(
-          op, "expected input and target to be rank 2 and 1 respectively");
-    }
+    if (!hasElementType<mlir::FloatType>(gradOutput) ||
+        !hasElementType<mlir::FloatType>(totalWeight) ||
+        (!weightIsNone && !hasElementType<mlir::FloatType>(weight))) {
+      return rewriter.notifyMatchFailure(
+          op, "`gradOutput`, `weight`, and `totalWeight` must be tensors of "
+              "type float");
+    }
+
+    if (!hasElementType<mlir::IntegerType>(target)) {
+      return rewriter.notifyMatchFailure(
+          op, "`target` must be a tensor of integer type");
+    }
+
+    auto outputSize = getTensorSizes(rewriter, loc, input);
+    Value gradInputTensor =
+        createZeroInitTensor(rewriter, loc, outputSize, resultElementType);
+
+    auto getAffineMapForSingleElementTensor = [&](Value tensor) {
+      auto tensorType = tensor.getType().cast<RankedTensorType>();
+      SmallVector<AffineExpr> affineExprs(tensorType.getRank(),
+                                          rewriter.getAffineConstantExpr(0));
+      return AffineMap::get(inputRank, /*symbolCount=*/0, affineExprs,
+                            op->getContext());
+    };
+
+    AffineMap gradOutMap = AffineMap::get(inputRank, /*symbolCount=*/0,
+                                          rewriter.getAffineDimExpr(0));
+    if (reduction != torch_upstream::Reduction::None || inputRank == 1)
+      gradOutMap = getAffineMapForSingleElementTensor(gradOutput);
+    AffineMap targetMap = AffineMap::get(inputRank, /*symbolCount=*/0,
+                                         rewriter.getAffineDimExpr(0));
+    if (inputRank == 1)
+      targetMap = getAffineMapForSingleElementTensor(target);
+    AffineMap totalWeightMap = getAffineMapForSingleElementTensor(totalWeight);
+    AffineMap resultMap = rewriter.getMultiDimIdentityMap(inputRank);
+
+    SmallVector<AffineMap> indexingMaps{gradOutMap, targetMap, totalWeightMap,
+                                        resultMap};
+    SmallVector<StringRef> iteratorTypes(inputRank,
+                                         getParallelIteratorTypeName());
+
+    // The code generation is equivalent to the following pseudo-code:
+    //
+    // for batch_index in range(input.size(0)):
+    //   for class_index in range(input.size(1)):
+    //     target_elem = target[batch_index]
+    //
+    //     if reduction == None:
+    //       grad_out_elem = grad_output[batch_index]
+    //     else:
+    //       grad_out_elem = grad_output[0]
+    //
+    //     if reduction == Mean:
+    //       total_weight_elem = total_weight[0]
+    //       grad_out_elem /= total_weight_elem
+    //
+    //     weight_elem = weight[target_elem] if weight != None else 1
+    //
+    //     if target_elem == class_index and target_elem != ignore_index:
+    //       grad_input_elem = -weight_elem * grad_out_elem
+    //     else:
+    //       grad_input_elem = 0
+    //     grad_input[batch_index, class_index] = grad_input_elem
+    //
+    // NOTE: In the case of no batch dimension, `batch_index` is effectively
+    // zero.
+    Value gradInput =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, gradInputTensor.getType(),
+                ValueRange{gradOutput, target, totalWeight}, gradInputTensor,
+                indexingMaps, iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value gradOutElem = args[0];
+                  Value targetElem = castIntToIndex(b, loc, args[1]);
+                  Value totalWeightElem = args[2];
+                  Value classIndex =
+                      b.create<linalg::IndexOp>(loc, inputRank - 1);
+
+                  if (reduction == torch_upstream::Reduction::Mean) {
+                    gradOutElem = b.create<arith::DivFOp>(loc, gradOutElem,
+                                                          totalWeightElem);
+                  }
+
+                  Value negGradOutElem =
+                      b.create<arith::NegFOp>(loc, gradOutElem);
+                  Value weightElem = getConstant(b, loc, 1, resultElementType);
+                  if (!weightIsNone) {
+                    weightElem =
+                        b.create<tensor::ExtractOp>(loc, weight, targetElem);
+                  }
+                  Value weightedNegGradOutElem =
+                      b.create<arith::MulFOp>(loc, weightElem, negGradOutElem);
+
+                  Value targetNeqClassIndex = b.create<arith::CmpIOp>(
+                      loc, arith::CmpIPredicate::ne, targetElem, classIndex);
+                  Value targetEqIgnoreIndex = b.create<arith::CmpIOp>(
+                      loc, arith::CmpIPredicate::eq, targetElem, ignoreIndex);
+                  Value gradInputIsZero = b.create<arith::OrIOp>(
+                      loc, targetNeqClassIndex, targetEqIgnoreIndex);
+
+                  Value zero = getConstant(b, loc, 0, resultElementType);
+                  Value gradInElem = b.create<arith::SelectOp>(
+                      loc, gradInputIsZero, zero, weightedNegGradOutElem);
+                  b.create<linalg::YieldOp>(loc, gradInElem);
+                })
+            ->getResult(0);
 
     RankedTensorType resultType = getTypeConverter()
                                       ->convertType(op->getResult(0).getType())
                                       .cast<RankedTensorType>();
-
-    Type elementType = resultType.getElementType();
-
-    // Given there is no reduction `grad_input` size is equal to `input` size.
-    auto outputSize = getTensorSizes(rewriter, loc, input);
-    Value initTensor0 =
-        createZeroInitTensor(rewriter, loc, outputSize, elementType);
-    Value zeroVal = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getZeroAttr(elementType));
-
-    SmallVector<AffineExpr> targetExpr{rewriter.getAffineDimExpr(0)};
-    SmallVector<AffineExpr> resultExpr{rewriter.getAffineDimExpr(0),
-                                       rewriter.getAffineDimExpr(1)};
-    SmallVector<StringRef> iteratorTypes{getParallelIteratorTypeName(),
-                                         getParallelIteratorTypeName()};
-    auto indexingMaps =
-        AffineMap::inferFromExprList({targetExpr, targetExpr, resultExpr});
-    Value finalRes =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, initTensor0.getType(), ValueRange{target, gradOutput},
-                initTensor0,
-                /*indexingMaps=*/indexingMaps,
-                /*iteratorTypes=*/iteratorTypes,
-                [&](OpBuilder &b, Location loc, ValueRange args) {
-                  Value indTarget = rewriter.create<arith::IndexCastOp>(
-                      loc, rewriter.getIndexType(), args[0]);
-                  Value indJ = rewriter.create<linalg::IndexOp>(loc, 1);
-
-                  // The final result is given by:
-                  // grad_input[i][j] = (j == target[i]) ? -grad_output[i] : 0
-                  Value cmpEq = rewriter.create<arith::CmpIOp>(
-                      loc, arith::CmpIPredicate::eq, indJ, indTarget);
-
-                  // The target index shouldn't be equal to `ignoreIndex`.
-                  Value cmpNe = rewriter.create<arith::CmpIOp>(
-                      loc, arith::CmpIPredicate::ne, ignoreIndexVal, indTarget);
-                  Value finalPredicate =
-                      rewriter.create<arith::AndIOp>(loc, cmpEq, cmpNe);
-                  Value negate =
-                      rewriter.create<arith::NegFOp>(loc, elementType, args[1]);
-                  Value selectFinal = rewriter.create<arith::SelectOp>(
-                      loc, finalPredicate, negate, zeroVal);
-                  b.create<linalg::YieldOp>(loc, selectFinal);
-                })
-            .getResult(0);
-
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, finalRes);
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, gradInput);
     return success();
   }
 };
@@ -162,6 +162,37 @@ def NllLossModuleBackward_basic(module, tu: TestUtils):
                    torch.tensor(3.))
 
 
+class NllLossModule_backwardWeight(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1], torch.float32, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, weight, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=weight,
+                                                reduction=0,
+                                                ignore_index=10,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backwardWeight())
+def NllLossModuleBackwardWeight_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]),
+                   torch.rand(4), torch.tensor(3.))
+
+
 class NllLossModule_backward_ignore_index(torch.nn.Module):
 
     def __init__(self):
@@ -190,3 +221,298 @@ class NllLossModule_backward_ignore_index(torch.nn.Module):
 def NllLossModuleBackward_ignore_index(module, tu: TestUtils):
     module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]),
                    torch.tensor(3.))
+
+
+class NllLossModule_backwardMean(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=None,
+                                                reduction=1,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backwardMean())
+def NllLossModuleBackwardMean_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
+                   torch.tensor(3.))
+
+
+class NllLossModule_backwardMeanWeight(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1], torch.float32, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, weight, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=weight,
+                                                reduction=1,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backwardMeanWeight())
+def NllLossModuleBackwardMeanWeight_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
+                   torch.rand(4), torch.tensor(3.))
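
As an illustrative aside (not part of the patch), the mean-plus-weight case exercised above can be cross-checked against PyTorch autograd using the `nll_loss_backward_ref` sketch from the top of this page; here `total_weight` is the sum of the weights of the selected (non-ignored) targets, as produced by the forward op:

import torch
import torch.nn.functional as F

input = torch.rand(3, 4, requires_grad=True)
target = torch.tensor([2, 3, 0])
weight = torch.rand(4)

# Forward with mean reduction; autograd gives the expected grad_input.
loss = F.nll_loss(input, target, weight=weight, reduction='mean')
loss.backward()

# total_weight as produced by the forward op: sum of the selected weights.
total_weight = weight[target].sum()
grad = nll_loss_backward_ref(torch.ones(()), input.detach(), target, weight,
                             reduction=1, ignore_index=-100,
                             total_weight=total_weight)
print(torch.allclose(grad, input.grad))  # expected: True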
+
+
+class NllLossModule_backwardSum(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=None,
+                                                reduction=2,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backwardSum())
+def NllLossModuleBackwardSum_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
+                   torch.tensor(3.))
+
+
+class NllLossModule_backwardSumWeight(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1], torch.float32, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, weight, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=weight,
+                                                reduction=2,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backwardSumWeight())
+def NllLossModuleBackwardSumWeight_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
+                   torch.rand(4), torch.tensor(3.))
+
+
+class NllLossModule_backward1D(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=None,
+                                                reduction=0,
+                                                ignore_index=10,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backward1D())
+def NllLossModuleBackward1D_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
+                   torch.tensor(3.))
+
+
+class NllLossModule_backward1DWeight(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1], torch.float32, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, weight, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=weight,
+                                                reduction=0,
+                                                ignore_index=10,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backward1DWeight())
+def NllLossModuleBackward1DWeight_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
+                   torch.rand(3), torch.tensor(3.))
+
+
+class NllLossModule_backward1DMean(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=None,
+                                                reduction=1,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backward1DMean())
+def NllLossModuleBackward1DMean_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
+                   torch.tensor(3.))
+
+
+class NllLossModule_backward1DMeanWeight(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1], torch.float32, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, weight, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=weight,
+                                                reduction=1,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backward1DMeanWeight())
+def NllLossModuleBackward1DMeanWeight_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
+                   torch.rand(3), torch.tensor(3.))
+
+
+class NllLossModule_backward1DSum(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=None,
+                                                reduction=2,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backward1DSum())
+def NllLossModuleBackward1DSum_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
+                   torch.tensor(3.))
+
+
+class NllLossModule_backward1DSumWeight(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1], torch.float32, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, weight, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=weight,
+                                                reduction=2,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backward1DSumWeight())
+def NllLossModuleBackward1DSumWeight_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
+                   torch.rand(3), torch.tensor(3.))