Add 1D, weight, and reduction support to nll_loss_backward (#729)

This commit adds the following support to the op `nll_loss_backward`:
- `input` tensor can be rank-1
- `weight` parameter
- `reduction` parameter
- `target`, `grad_output`, `total_weight` can be rank-0
- Checks that input tensors are of the expected type
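
As an illustration (not part of the change itself), an eager-mode call that exercises the newly supported weight and mean-reduction path looks like this; the shapes are taken from the NllLossModule_backwardMeanWeight e2e test added below:

import torch

# 2D input (batch x classes), per-class weight, mean reduction (reduction=1),
# mirroring the NllLossModule_backwardMeanWeight test in this commit.
grad_output = torch.rand(1)
input = torch.rand(3, 4)
target = torch.tensor([2, 3, 0])
weight = torch.rand(4)
total_weight = torch.tensor(3.)

grad_input = torch.ops.aten.nll_loss_backward(grad_output,
                                              input,
                                              target=target,
                                              weight=weight,
                                              reduction=1,
                                              ignore_index=1,
                                              total_weight=total_weight)
print(grad_input.shape)  # same shape as `input`: torch.Size([3, 4])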
Ramiro Leal-Cavazos 2022-04-04 10:57:49 -07:00 committed by GitHub
parent 886ad169e5
commit 5620fe030e
2 changed files with 450 additions and 74 deletions


@@ -28,6 +28,13 @@ using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
 
+// Check if a ranked-tensor has the specified element type.
+template <typename elementType> static bool hasElementType(Value tensor) {
+  auto tensorType = tensor.getType().cast<RankedTensorType>();
+  Type tensorElementType = tensorType.getElementType();
+  return tensorElementType.isa<elementType>();
+}
+
 static Value createElementwiseLinalgGeneric(
     OpBuilder &b, Location loc, ValueRange tensorOperands,
     Type resultElementType,
@@ -1441,11 +1448,6 @@ public:
 };
 } // namespace
 
-// Given `grad_output`, `input`, `target`, `nll_loss_backward` is given by:
-//   for i in range(0, len(input[0])):
-//     for j in range(0, len(input[1])):
-//       nll_loss_backward[i][j] = (j == target[i]) ? -grad_output[i] : 0
-// TODO: `weight` and `reduction` operands are still to be taken care of.
 namespace {
 class ConvertAtenNllLossBackwardOp
     : public OpConversionPattern<AtenNllLossBackwardOp> {
@@ -1456,89 +1458,137 @@ public:
                   ConversionPatternRewriter &rewriter) const override {
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
     Location loc = op->getLoc();
+    Value gradOutput = adaptor.grad_output();
     Value input = adaptor.self();
     Value target = adaptor.target();
     Value weight = adaptor.weight();
-    Value gradOutput = adaptor.grad_output();
+    bool weightIsNone = op.weight().getType().isa<Torch::NoneType>();
+    Value ignoreIndex = castIntToIndex(rewriter, loc, adaptor.ignore_index());
+    Value totalWeight = adaptor.total_weight();
+
+    auto inputType = input.getType().cast<RankedTensorType>();
+    int inputRank = inputType.getRank();
+    auto gradOutputType = gradOutput.getType().cast<RankedTensorType>();
+    Type resultElementType = gradOutputType.getElementType();
 
     int64_t reduction;
     if (!matchPattern(op.reduction(), m_TorchConstantInt(&reduction)))
       return rewriter.notifyMatchFailure(op, "dim must be constant");
-    // TODO: Handle reduction.
-    if (reduction != torch_upstream::Reduction::None)
-      return rewriter.notifyMatchFailure(
-          op, "reduction along dimensions is not supported.");
 
-    // TODO: Incorporate the weight argument.
-    if (!weight.getType().isa<Torch::NoneType>())
-      return rewriter.notifyMatchFailure(
-          op, "Unimplemented, the weight operand is not incorporated.");
-
-    Value ignoreIndex = adaptor.ignore_index();
-    Value ignoreIndexVal = castIntToIndex(rewriter, loc, ignoreIndex);
-
-    unsigned inputRank = input.getType().cast<RankedTensorType>().getRank();
-    unsigned targetRank = target.getType().cast<RankedTensorType>().getRank();
-
-    // TODO: Cases with targetRank != 1 where `Mean` or `Sum` reduction is
-    // required.
-    if (inputRank != 2 || targetRank != 1) {
-      return rewriter.notifyMatchFailure(
-          op, "expected input and target to be rank 2 and 1 respectively");
+    if (!hasElementType<mlir::FloatType>(gradOutput) ||
+        !hasElementType<mlir::FloatType>(totalWeight) ||
+        (!weightIsNone && !hasElementType<mlir::FloatType>(weight))) {
+      return rewriter.notifyMatchFailure(
+          op, "`gradOutput`, `weight`, and `totalWeight` must be tensors of "
+              "type float");
     }
+
+    if (!hasElementType<mlir::IntegerType>(target)) {
+      return rewriter.notifyMatchFailure(
+          op, "`target` must be a tensor of integer type");
+    }
+
+    auto outputSize = getTensorSizes(rewriter, loc, input);
+    Value gradInputTensor =
+        createZeroInitTensor(rewriter, loc, outputSize, resultElementType);
+
+    // Indexing map for single-element tensors (e.g. rank-0 `total_weight`):
+    // every iteration reads element (0, ..., 0).
+    auto getAffineMapForSingleElementTensor = [&](Value tensor) {
+      auto tensorType = tensor.getType().cast<RankedTensorType>();
+      SmallVector<AffineExpr> affineExprs(tensorType.getRank(),
+                                          rewriter.getAffineConstantExpr(0));
+      return AffineMap::get(inputRank, /*symbolCount=*/0, affineExprs,
+                            op->getContext());
+    };
+
+    AffineMap gradOutMap = AffineMap::get(inputRank, /*symbolCount=*/0,
+                                          rewriter.getAffineDimExpr(0));
+    if (reduction != torch_upstream::Reduction::None || inputRank == 1)
+      gradOutMap = getAffineMapForSingleElementTensor(gradOutput);
+    AffineMap targetMap = AffineMap::get(inputRank, /*symbolCount=*/0,
+                                         rewriter.getAffineDimExpr(0));
+    if (inputRank == 1)
+      targetMap = getAffineMapForSingleElementTensor(target);
+    AffineMap totalWeightMap = getAffineMapForSingleElementTensor(totalWeight);
+    AffineMap resultMap = rewriter.getMultiDimIdentityMap(inputRank);
+
+    SmallVector<AffineMap> indexingMaps{gradOutMap, targetMap, totalWeightMap,
+                                        resultMap};
+    SmallVector<StringRef> iteratorTypes(inputRank,
+                                         getParallelIteratorTypeName());
+
+    // The code generation is equivalent to the following pseudo-code:
+    //
+    // for batch_index in range(input.size(0)):
+    //   for class_index in range(input.size(1)):
+    //     target_elem = target[batch_index]
+    //
+    //     if reduction == None:
+    //       grad_out_elem = grad_output[batch_index]
+    //     else:
+    //       grad_out_elem = grad_output[0]
+    //
+    //     if reduction == Mean:
+    //       total_weight_elem = total_weight[0]
+    //       grad_out_elem /= total_weight_elem
+    //
+    //     weight_elem = weight[target_elem] if weight != None else 1
+    //
+    //     if target_elem != class_index or target_elem == ignore_index:
+    //       grad_input_elem = 0
+    //     else:
+    //       grad_input_elem = -weight_elem * grad_out_elem
+    //     grad_input[batch_index, class_index] = grad_input_elem
+    //
+    // NOTE: In the case of no batch dimension, `batch_index` essentially
+    // becomes zero.
+    Value gradInput =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, gradInputTensor.getType(),
+                ValueRange{gradOutput, target, totalWeight}, gradInputTensor,
+                indexingMaps, iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value gradOutElem = args[0];
+                  Value targetElem = castIntToIndex(b, loc, args[1]);
+                  Value totalWeightElem = args[2];
+                  Value classIndex =
+                      b.create<linalg::IndexOp>(loc, inputRank - 1);
+
+                  if (reduction == torch_upstream::Reduction::Mean) {
+                    gradOutElem = b.create<arith::DivFOp>(loc, gradOutElem,
+                                                          totalWeightElem);
+                  }
+
+                  Value negGradOutElem =
+                      b.create<arith::NegFOp>(loc, gradOutElem);
+                  Value weightElem = getConstant(b, loc, 1, resultElementType);
+                  if (!weightIsNone) {
+                    weightElem =
+                        b.create<tensor::ExtractOp>(loc, weight, targetElem);
+                  }
+                  Value weightedNegGradOutElem =
+                      b.create<arith::MulFOp>(loc, weightElem, negGradOutElem);
+
+                  Value targetNeqClassIndex = b.create<arith::CmpIOp>(
+                      loc, arith::CmpIPredicate::ne, targetElem, classIndex);
+                  Value targetEqIgnoreIndex = b.create<arith::CmpIOp>(
+                      loc, arith::CmpIPredicate::eq, targetElem, ignoreIndex);
+                  Value gradInputIsZero = b.create<arith::OrIOp>(
+                      loc, targetNeqClassIndex, targetEqIgnoreIndex);
+
+                  Value zero = getConstant(b, loc, 0, resultElementType);
+                  Value gradInElem = b.create<arith::SelectOp>(
+                      loc, gradInputIsZero, zero, weightedNegGradOutElem);
+                  b.create<linalg::YieldOp>(loc, gradInElem);
+                })
+            ->getResult(0);
 
     RankedTensorType resultType = getTypeConverter()
                                       ->convertType(op->getResult(0).getType())
                                       .cast<RankedTensorType>();
-
-    Type elementType = resultType.getElementType();
-
-    // Given there is no reduction `grad_input` size is equal to `input` size.
-    auto outputSize = getTensorSizes(rewriter, loc, input);
-    Value initTensor0 =
-        createZeroInitTensor(rewriter, loc, outputSize, elementType);
-    Value zeroVal = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getZeroAttr(elementType));
-
-    SmallVector<AffineExpr> targetExpr{rewriter.getAffineDimExpr(0)};
-    SmallVector<AffineExpr> resultExpr{rewriter.getAffineDimExpr(0),
-                                       rewriter.getAffineDimExpr(1)};
-    SmallVector<StringRef> iteratorTypes{getParallelIteratorTypeName(),
-                                         getParallelIteratorTypeName()};
-    auto indexingMaps =
-        AffineMap::inferFromExprList({targetExpr, targetExpr, resultExpr});
-    Value finalRes =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, initTensor0.getType(), ValueRange{target, gradOutput},
-                initTensor0,
-                /*indexingMaps=*/indexingMaps,
-                /*iteratorTypes=*/iteratorTypes,
-                [&](OpBuilder &b, Location loc, ValueRange args) {
-                  Value indTarget = rewriter.create<arith::IndexCastOp>(
-                      loc, rewriter.getIndexType(), args[0]);
-                  Value indJ = rewriter.create<linalg::IndexOp>(loc, 1);
-                  // The final result is given by:
-                  // grad_input[i][j] = (j == target[i]) ? -grad_output[i] : 0
-                  Value cmpEq = rewriter.create<arith::CmpIOp>(
-                      loc, arith::CmpIPredicate::eq, indJ, indTarget);
-                  // The target index shouldn't be equal to `ignoreIndex`.
-                  Value cmpNe = rewriter.create<arith::CmpIOp>(
-                      loc, arith::CmpIPredicate::ne, ignoreIndexVal, indTarget);
-                  Value finalPredicate =
-                      rewriter.create<arith::AndIOp>(loc, cmpEq, cmpNe);
-                  Value negate =
-                      rewriter.create<arith::NegFOp>(loc, elementType, args[1]);
-                  Value selectFinal = rewriter.create<arith::SelectOp>(
-                      loc, finalPredicate, negate, zeroVal);
-                  b.create<linalg::YieldOp>(loc, selectFinal);
-                })
-            .getResult(0);
-
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, finalRes);
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, gradInput);
     return success();
   }
 };
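
For readers cross-checking the lowering, the pseudo-code in the comment above can be restated as a loop-level Python sketch. This is an illustration under stated assumptions (2D batch-by-classes case only; reduction encoded as 0=none, 1=mean, 2=sum, matching torch_upstream::Reduction), not code from the commit:

import torch

def nll_loss_backward_reference(grad_output, input, target, weight,
                                reduction, ignore_index, total_weight):
    # Loop-level restatement of the pseudo-code in the lowering comment above.
    grad_input = torch.zeros_like(input)
    batch_size, num_classes = input.shape
    for batch_index in range(batch_size):
        target_elem = int(target[batch_index])
        if reduction == 0:
            grad_out_elem = grad_output[batch_index]
        else:
            grad_out_elem = grad_output.reshape(-1)[0]
        if reduction == 1:
            grad_out_elem = grad_out_elem / total_weight
        weight_elem = weight[target_elem] if weight is not None else 1.0
        for class_index in range(num_classes):
            if target_elem == class_index and target_elem != ignore_index:
                grad_input[batch_index, class_index] = -weight_elem * grad_out_elem
    return grad_input

# Cross-check against the aten op itself, using the shapes from the new e2e tests:
# grad_output, input = torch.rand(1), torch.rand(3, 4)
# target, weight = torch.tensor([2, 3, 0]), torch.rand(4)
# total_weight = torch.tensor(3.)
# ref = nll_loss_backward_reference(grad_output, input, target, weight, 1, 1, total_weight)
# out = torch.ops.aten.nll_loss_backward(grad_output, input, target, weight, 1, 1, total_weight)
# torch.testing.assert_close(ref, out)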


@@ -162,6 +162,37 @@ def NllLossModuleBackward_basic(module, tu: TestUtils):
                    torch.tensor(3.))
 
 
+class NllLossModule_backwardWeight(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1], torch.float32, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, weight, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=weight,
+                                                reduction=0,
+                                                ignore_index=10,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backwardWeight())
+def NllLossModuleBackwardWeight_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]),
+                   torch.rand(4), torch.tensor(3.))
+
+
 class NllLossModule_backward_ignore_index(torch.nn.Module):
 
     def __init__(self):
@@ -190,3 +221,298 @@ class NllLossModule_backward_ignore_index(torch.nn.Module):
 def NllLossModuleBackward_ignore_index(module, tu: TestUtils):
     module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]),
                    torch.tensor(3.))
+
+
+class NllLossModule_backwardMean(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=None,
+                                                reduction=1,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backwardMean())
+def NllLossModuleBackwardMean_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
+                   torch.tensor(3.))
+
+
+class NllLossModule_backwardMeanWeight(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1], torch.float32, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, weight, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=weight,
+                                                reduction=1,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backwardMeanWeight())
+def NllLossModuleBackwardMeanWeight_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
+                   torch.rand(4), torch.tensor(3.))
+
+
+class NllLossModule_backwardSum(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=None,
+                                                reduction=2,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backwardSum())
+def NllLossModuleBackwardSum_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
+                   torch.tensor(3.))
+
+
+class NllLossModule_backwardSumWeight(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1], torch.float32, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, weight, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=weight,
+                                                reduction=2,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backwardSumWeight())
+def NllLossModuleBackwardSumWeight_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
+                   torch.rand(4), torch.tensor(3.))
+
+
+class NllLossModule_backward1D(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=None,
+                                                reduction=0,
+                                                ignore_index=10,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backward1D())
+def NllLossModuleBackward1D_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
+                   torch.tensor(3.))
+
+
+class NllLossModule_backward1DWeight(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1], torch.float32, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, weight, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=weight,
+                                                reduction=0,
+                                                ignore_index=10,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backward1DWeight())
+def NllLossModuleBackward1DWeight_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
+                   torch.rand(3), torch.tensor(3.))
+
+
+class NllLossModule_backward1DMean(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=None,
+                                                reduction=1,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backward1DMean())
+def NllLossModuleBackward1DMean_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
+                   torch.tensor(3.))
+
+
+class NllLossModule_backward1DMeanWeight(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1], torch.float32, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, weight, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=weight,
+                                                reduction=1,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backward1DMeanWeight())
+def NllLossModuleBackward1DMeanWeight_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
+                   torch.rand(3), torch.tensor(3.))
+
+
+class NllLossModule_backward1DSum(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=None,
+                                                reduction=2,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backward1DSum())
+def NllLossModuleBackward1DSum_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
+                   torch.tensor(3.))
+
+
+class NllLossModule_backward1DSumWeight(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1], torch.float32, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, weight, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=weight,
+                                                reduction=2,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backward1DSumWeight())
+def NllLossModuleBackward1DSumWeight_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
+                   torch.rand(3), torch.tensor(3.))