mirror of https://github.com/llvm/torch-mlir
Add aten::nll_loss_backward op
The lowering of the aten::nll_loss_backward op from the torch dialect to the linalg dialect has been added. The changes have been made as part of the -torch-convert-to-linalg pass. Signed-off-by: Prashant Kumar <prashant@nod-labs.com> pull/558/head
parent
68acc8696e
commit
ccf546f14c
|
@ -60,3 +60,61 @@ class NllLossModule_ignore_index_out_of_bounds(torch.nn.Module):
|
||||||
@register_test_case(module_factory=lambda: NllLossModule_ignore_index_out_of_bounds())
|
@register_test_case(module_factory=lambda: NllLossModule_ignore_index_out_of_bounds())
|
||||||
def NllLossModule_ignore_index(module, tu: TestUtils):
|
def NllLossModule_ignore_index(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(2, 3), torch.tensor([0, 1]))
|
module.forward(tu.rand(2, 3), torch.tensor([0, 1]))
|
||||||
|
|
||||||
|
class NllLossModule_backward(torch.nn.Module):
    """E2E test module that calls `aten.nll_loss_backward` directly.

    The op is invoked without class weights (`weight=None`), with
    `reduction=0` and with `ignore_index=10`.
    """

    def __init__(self):
        super().__init__()

    # Leading None presumably stands for `self`; the remaining entries line
    # up with grad_output, input, target and total_weight.
    @export
    @annotate_args([
        None,
        ([-1], torch.float32, True),
        ([-1, -1], torch.float32, True),
        ([-1], torch.int64, True),
        ([], torch.float32, True),
    ])
    def forward(self, grad_output, input, target, total_weight):
        return torch.ops.aten.nll_loss_backward(
            grad_output=grad_output,
            self=input,
            target=target,
            weight=None,
            reduction=0,
            ignore_index=10,
            total_weight=total_weight,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: NllLossModule_backward())
def NllLossModuleBackward_basic(module, tu: TestUtils):
    """Drive `NllLossModule_backward` with a batch of 3 samples and 4 classes."""
    # Random tensors are drawn in the same order as the positional arguments
    # so the generated values stay unchanged.
    grad_output = tu.rand(3)
    scores = tu.rand(3, 4)
    target = torch.tensor([2, 3, 0])
    total_weight = torch.tensor(3.)
    module.forward(grad_output, scores, target, total_weight)
|
||||||
|
|
||||||
|
|
||||||
|
class NllLossModule_backward_ignore_index(torch.nn.Module):
    """E2E test module calling `aten.nll_loss_backward` with `ignore_index=1`.

    Apart from the ignore index, the op is configured without class weights
    (`weight=None`) and with `reduction=0`.
    """

    def __init__(self):
        super().__init__()

    # Leading None presumably stands for `self`; the remaining entries line
    # up with grad_output, input, target and total_weight.
    @export
    @annotate_args([
        None,
        ([-1], torch.float32, True),
        ([-1, -1], torch.float32, True),
        ([-1], torch.int64, True),
        ([], torch.float32, True),
    ])
    def forward(self, grad_output, input, target, total_weight):
        return torch.ops.aten.nll_loss_backward(
            grad_output=grad_output,
            self=input,
            target=target,
            weight=None,
            reduction=0,
            ignore_index=1,
            total_weight=total_weight,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
    module_factory=lambda: NllLossModule_backward_ignore_index())
def NllLossModuleBackward_ignore_index(module, tu: TestUtils):
    """Run the backward op on a batch where one target hits `ignore_index`.

    The module under test passes `ignore_index=1`, so the middle sample
    (target == 1) exercises the ignore path while the other two samples take
    the regular path. The previous targets `[2, 3, 0]` never matched the
    ignore index, so that path was not covered.
    """
    module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 1, 0]),
                   torch.tensor(3.))
|
||||||
|
|
|
@ -1852,6 +1852,26 @@ def Torch_AtenNllLossForwardOp : Torch_Op<"aten.nll_loss_forward", [
|
||||||
let assemblyFormat = "$self `,` $target `,` $weight `,` $reduction `,` $ignore_index attr-dict `:` qualified(type($self)) `,` qualified(type($target)) `,` qualified(type($weight)) `,` qualified(type($reduction)) `,` qualified(type($ignore_index)) `->` qualified(type($output)) `,` qualified(type($total_weight))";
|
let assemblyFormat = "$self `,` $target `,` $weight `,` $reduction `,` $ignore_index attr-dict `:` qualified(type($self)) `,` qualified(type($target)) `,` qualified(type($weight)) `,` qualified(type($reduction)) `,` qualified(type($ignore_index)) `->` qualified(type($output)) `,` qualified(type($total_weight))";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Torch dialect op for the `aten::nll_loss_backward` operator. The summary
// string marks this definition as generated, so changes presumably belong in
// the op generator rather than here.
def Torch_AtenNllLossBackwardOp : Torch_Op<"aten.nll_loss_backward", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$grad_output,
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$target,
    AnyTorchOptionalTensorType:$weight,
    Torch_IntType:$reduction,
    Torch_IntType:$ignore_index,
    AnyTorchTensorType:$total_weight
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$grad_output `,` $self `,` $target `,` $weight `,` $reduction `,` $ignore_index `,` $total_weight attr-dict `:` qualified(type($grad_output)) `,` qualified(type($self)) `,` qualified(type($target)) `,` qualified(type($weight)) `,` qualified(type($reduction)) `,` qualified(type($ignore_index)) `,` qualified(type($total_weight)) `->` qualified(type($result))";
}
|
||||||
|
|
||||||
def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
|
def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics
|
HasValueSemantics
|
||||||
|
|
|
@ -70,6 +70,15 @@ struct ResultTypeState {
|
||||||
ScalarType result_type(const ResultTypeState &in_state);
|
ScalarType result_type(const ResultTypeState &in_state);
|
||||||
ScalarType promote_skip_undefined(ScalarType a, ScalarType b);
|
ScalarType promote_skip_undefined(ScalarType a, ScalarType b);
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// These constants control the reduction behavior of the loss functions.
|
||||||
|
// None, Mean and Sum correspond to "do not reduce", "Mean of losses", and "sum
|
||||||
|
// of losses" respectively.
|
||||||
|
// Source:
|
||||||
|
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/Reduction.h
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Enumerator values are implicit: None = 0, Mean = 1, Sum = 2, END = 3.
// NOTE(review): END looks like an end-of-enum sentinel rather than a real
// reduction mode — confirm against the upstream Reduction.h.
enum Reduction { None, Mean, Sum, END };
|
||||||
|
|
||||||
} // namespace torch_upstream
|
} // namespace torch_upstream
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
#include "mlir/IR/Matchers.h"
|
#include "mlir/IR/Matchers.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
|
||||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
|
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
|
||||||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
||||||
|
@ -28,6 +29,7 @@ using namespace mlir;
|
||||||
using namespace mlir::torch;
|
using namespace mlir::torch;
|
||||||
using namespace mlir::torch::Torch;
|
using namespace mlir::torch::Torch;
|
||||||
using namespace mlir::torch::TorchConversion;
|
using namespace mlir::torch::TorchConversion;
|
||||||
|
using namespace mlir::torch::torch_upstream; // For ScalarType, type promotion helpers, and Reduction.
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
// Patterns (as this grows, it should be organized into multiple files)
|
// Patterns (as this grows, it should be organized into multiple files)
|
||||||
|
@ -1323,6 +1325,108 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
// Given `grad_output`, `input`, `target`, `nll_loss_backward` is given by:
|
||||||
|
// for i in range(0, input.shape[0]):
|
||||||
|
// for j in range(0, input.shape[1]):
|
||||||
|
// nll_loss_backward[i][j] = (j == target[i]) ? -grad_output[i] : 0
|
||||||
|
// TODO: `weight` and `reduction` operands are still to be taken care of.
|
||||||
|
namespace {
|
||||||
|
class ConvertAtenNllLossBackwardOp
|
||||||
|
: public OpConversionPattern<AtenNllLossBackwardOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(AtenNllLossBackwardOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||||
|
return failure();
|
||||||
|
Location loc = op->getLoc();
|
||||||
|
Value input = adaptor.self();
|
||||||
|
Value target = adaptor.target();
|
||||||
|
Value weight = adaptor.weight();
|
||||||
|
Value gradOutput = adaptor.grad_output();
|
||||||
|
|
||||||
|
int64_t reduction;
|
||||||
|
if (!matchPattern(op.reduction(), m_TorchConstantInt(&reduction)))
|
||||||
|
return rewriter.notifyMatchFailure(op, "dim must be constant");
|
||||||
|
|
||||||
|
// TODO: Handle reduction.
|
||||||
|
if (reduction != Reduction::None)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "reduction along dimensions is not supported.");
|
||||||
|
|
||||||
|
// TODO: Incorporate the weight argument.
|
||||||
|
if (!weight.getType().isa<Torch::NoneType>())
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Unimplemented, the weight operand is not incorporated.");
|
||||||
|
|
||||||
|
Value ignoreIndex = adaptor.ignore_index();
|
||||||
|
Value ignoreIndexVal = castIntToIndex(rewriter, loc, ignoreIndex);
|
||||||
|
|
||||||
|
unsigned inputRank = input.getType().cast<RankedTensorType>().getRank();
|
||||||
|
unsigned targetRank = target.getType().cast<RankedTensorType>().getRank();
|
||||||
|
|
||||||
|
// TODO: Cases with targetRank != 1 where `Mean` or `Sum` reduction is
|
||||||
|
// required.
|
||||||
|
if (inputRank != 2 || targetRank != 1) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "expected input and target to be rank 2 and 1 respectively");
|
||||||
|
}
|
||||||
|
RankedTensorType resultType = getTypeConverter()
|
||||||
|
->convertType(op->getResult(0).getType())
|
||||||
|
.cast<RankedTensorType>();
|
||||||
|
|
||||||
|
Type elementType = resultType.getElementType();
|
||||||
|
|
||||||
|
// Given there is no reduction `grad_input` size is equal to `input` size.
|
||||||
|
auto outputSize = getTensorSizes(rewriter, loc, input);
|
||||||
|
Value initTensor0 =
|
||||||
|
createZeroInitTensor(rewriter, loc, outputSize, elementType);
|
||||||
|
Value zeroVal = rewriter.create<arith::ConstantOp>(
|
||||||
|
loc, rewriter.getZeroAttr(elementType));
|
||||||
|
|
||||||
|
SmallVector<AffineExpr> targetExpr{rewriter.getAffineDimExpr(0)};
|
||||||
|
SmallVector<AffineExpr> resultExpr{rewriter.getAffineDimExpr(0),
|
||||||
|
rewriter.getAffineDimExpr(1)};
|
||||||
|
SmallVector<StringRef> iteratorTypes{getParallelIteratorTypeName(),
|
||||||
|
getParallelIteratorTypeName()};
|
||||||
|
auto indexingMaps =
|
||||||
|
AffineMap::inferFromExprList({targetExpr, targetExpr, resultExpr});
|
||||||
|
Value finalRes =
|
||||||
|
rewriter
|
||||||
|
.create<linalg::GenericOp>(
|
||||||
|
loc, resultType, ValueRange{target, gradOutput}, initTensor0,
|
||||||
|
/*indexingMaps=*/indexingMaps,
|
||||||
|
/*iteratorTypes=*/iteratorTypes,
|
||||||
|
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||||
|
Value indTarget = rewriter.create<arith::IndexCastOp>(
|
||||||
|
loc, rewriter.getIndexType(), args[0]);
|
||||||
|
Value indJ = rewriter.create<linalg::IndexOp>(loc, 1);
|
||||||
|
|
||||||
|
// The final result is given by:
|
||||||
|
// grad_input[i][j] = (j == target[i]) ? -grad_output[i] : 0
|
||||||
|
Value cmpEq = rewriter.create<arith::CmpIOp>(
|
||||||
|
loc, arith::CmpIPredicate::eq, indJ, indTarget);
|
||||||
|
|
||||||
|
// The target index shouldn't be equal to `ignoreIndex`.
|
||||||
|
Value cmpNe = rewriter.create<arith::CmpIOp>(
|
||||||
|
loc, arith::CmpIPredicate::ne, ignoreIndexVal, indTarget);
|
||||||
|
Value finalPredicate =
|
||||||
|
rewriter.create<arith::AndIOp>(loc, cmpEq, cmpNe);
|
||||||
|
Value negate =
|
||||||
|
rewriter.create<arith::NegFOp>(loc, elementType, args[1]);
|
||||||
|
Value selectFinal = rewriter.create<mlir::SelectOp>(
|
||||||
|
loc, finalPredicate, negate, zeroVal);
|
||||||
|
b.create<linalg::YieldOp>(loc, selectFinal);
|
||||||
|
})
|
||||||
|
.getResult(0);
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, finalRes);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// See comments at in convertMmOp and the heading for this section for general
|
// See comments at in convertMmOp and the heading for this section for general
|
||||||
// considerations. This function needs to be auto-generated.
|
// considerations. This function needs to be auto-generated.
|
||||||
|
@ -4528,6 +4632,8 @@ public:
|
||||||
patterns.add<ConvertAtenSliceTensorOp>(typeConverter, context);
|
patterns.add<ConvertAtenSliceTensorOp>(typeConverter, context);
|
||||||
target.addIllegalOp<AtenNllLossForwardOp>();
|
target.addIllegalOp<AtenNllLossForwardOp>();
|
||||||
patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
|
patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
|
||||||
|
target.addIllegalOp<AtenNllLossBackwardOp>();
|
||||||
|
patterns.add<ConvertAtenNllLossBackwardOp>(typeConverter, context);
|
||||||
target.addIllegalOp<AtenIndexSelectOp>();
|
target.addIllegalOp<AtenIndexSelectOp>();
|
||||||
patterns.add<ConvertAtenIndexSelectOp>(typeConverter, context);
|
patterns.add<ConvertAtenIndexSelectOp>(typeConverter, context);
|
||||||
patterns.add<ConvertAtenScalarToTensorLike>(typeConverter, context);
|
patterns.add<ConvertAtenScalarToTensorLike>(typeConverter, context);
|
||||||
|
|
|
@ -489,6 +489,8 @@ public:
|
||||||
return visitBinaryScalarOp(op, operands);
|
return visitBinaryScalarOp(op, operands);
|
||||||
} else if (auto nllForwardOp = dyn_cast<AtenNllLossForwardOp>(op)) {
|
} else if (auto nllForwardOp = dyn_cast<AtenNllLossForwardOp>(op)) {
|
||||||
return visitAtenNllLossForwardOp(nllForwardOp, operands);
|
return visitAtenNllLossForwardOp(nllForwardOp, operands);
|
||||||
|
} else if (auto nllBackwardOp = dyn_cast<AtenNllLossBackwardOp>(op)) {
|
||||||
|
return visitAtenNllLossBackwardOp(nllBackwardOp, operands);
|
||||||
} else if (auto nativeLayerNormOp = dyn_cast<AtenNativeLayerNormOp>(op)) {
|
} else if (auto nativeLayerNormOp = dyn_cast<AtenNativeLayerNormOp>(op)) {
|
||||||
return visitAtenNativeLayerNormOp(nativeLayerNormOp, operands);
|
return visitAtenNativeLayerNormOp(nativeLayerNormOp, operands);
|
||||||
} else if (auto constantPadNdOp = dyn_cast<AtenConstantPadNdOp>(op)) {
|
} else if (auto constantPadNdOp = dyn_cast<AtenConstantPadNdOp>(op)) {
|
||||||
|
@ -647,6 +649,9 @@ private:
|
||||||
ChangeResult visitAtenNllLossForwardOp(
|
ChangeResult visitAtenNllLossForwardOp(
|
||||||
AtenNllLossForwardOp op,
|
AtenNllLossForwardOp op,
|
||||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||||
|
ChangeResult visitAtenNllLossBackwardOp(
|
||||||
|
AtenNllLossBackwardOp op,
|
||||||
|
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||||
ChangeResult visitAtenNativeLayerNormOp(
|
ChangeResult visitAtenNativeLayerNormOp(
|
||||||
AtenNativeLayerNormOp op,
|
AtenNativeLayerNormOp op,
|
||||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||||
|
@ -1188,8 +1193,8 @@ ChangeResult TypeAnalyzer::visitAtenNllLossForwardOp(
|
||||||
|
|
||||||
if (self.hasSizes &&
|
if (self.hasSizes &&
|
||||||
matchPattern(op.reduction(), m_TorchConstantInt(&reduction))) {
|
matchPattern(op.reduction(), m_TorchConstantInt(&reduction))) {
|
||||||
// reduction == 1 means reduce 1st dim.
|
if (reduction != Reduction::None)
|
||||||
resultRank = reduction == 1 ? resultRank - 1 : resultRank;
|
resultRank -= 1;
|
||||||
}
|
}
|
||||||
outputKnowledge.sizes.resize(resultRank - 1, kUnknownSize);
|
outputKnowledge.sizes.resize(resultRank - 1, kUnknownSize);
|
||||||
outputKnowledge.hasSizes = true;
|
outputKnowledge.hasSizes = true;
|
||||||
|
@ -1199,6 +1204,22 @@ ChangeResult TypeAnalyzer::visitAtenNllLossForwardOp(
|
||||||
return resultLattice;
|
return resultLattice;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Type refinement for `aten.nll_loss_backward`: the result takes the dtype of
// the `self` operand and, when the sizes of `self` are known, the same rank
// with every dimension left unknown.
ChangeResult TypeAnalyzer::visitAtenNllLossBackwardOp(
    AtenNllLossBackwardOp op,
    ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  // Operand 0 is `grad_output`; `self` is operand 1.
  auto inputKnowledge = operands[1]->getValue();
  auto gradInputKnowledge =
      ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
  gradInputKnowledge.dtype = inputKnowledge.dtype;

  if (inputKnowledge.hasSizes) {
    unsigned rank = inputKnowledge.sizes.size();
    gradInputKnowledge.hasSizes = true;
    gradInputKnowledge.sizes.resize(rank, kUnknownSize);
  }

  auto &resultLattice = getLatticeElement(op.getResult());
  return resultLattice.join(gradInputKnowledge);
}
|
||||||
|
|
||||||
ChangeResult TypeAnalyzer::visitAtenSqueezeDimOp(
|
ChangeResult TypeAnalyzer::visitAtenSqueezeDimOp(
|
||||||
AtenSqueezeDimOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
AtenSqueezeDimOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||||
auto operand = operands[0]->getValue();
|
auto operand = operands[0]->getValue();
|
||||||
|
|
|
@ -548,6 +548,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
||||||
emit("aten::std : (Tensor, bool) -> (Tensor)")
|
emit("aten::std : (Tensor, bool) -> (Tensor)")
|
||||||
emit("aten::var : (Tensor, bool) -> (Tensor)")
|
emit("aten::var : (Tensor, bool) -> (Tensor)")
|
||||||
emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
|
emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
|
||||||
|
emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)")
|
||||||
|
|
||||||
# Misc tensor ops.
|
# Misc tensor ops.
|
||||||
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
|
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
|
||||||
|
|
Loading…
Reference in New Issue