mirror of https://github.com/llvm/torch-mlir
Add aten::nll_loss_forward op lowering.
The op lowering has been added as part of the `torch-lower-to-linalg` pass. It takes care of `ignore_index`, but the `weight` and `reduction` operands are still to be accounted for.

Signed-off-by: Prashant Kumar <prashant@nod-labs.com>

pull/464/head snapshot-20211207.130
parent 5c7ce45c4e
commit 977b1b03ea
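For context, a minimal sketch (mine, not part of the commit) of the op this commit lowers. The call style mirrors the new e2e tests below; `reduction=0` requests per-element losses, and `total_weight` is the second result that the lowering must also materialize:

import torch

x = torch.rand(2, 3)      # (batch, num_classes) log-probabilities
y = torch.tensor([0, 1])  # one target class per batch element

# aten::nll_loss_forward returns an (output, total_weight) pair.
out, total_weight = torch.ops.aten.nll_loss_forward(
    self=x, target=y, weight=None, reduction=0, ignore_index=-100)

# With reduction=0 and no weight, out[i] == -x[i, y[i]].
print(out, total_weight)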
@@ -43,6 +43,7 @@ from . import view
from . import scalar
from . import squeeze
from . import slice_like
from . import nll_loss

def _get_argparse():
    config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
@@ -0,0 +1,62 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

import torch

from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export

# ==============================================================================


class NllLossModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1], torch.int64, True),
    ])
    # Targets equal to the ignore_index (2) contribute a zero loss.
    def forward(self, x, y):
        return torch.ops.aten.nll_loss_forward(self=x,
                                               target=y,
                                               weight=None,
                                               reduction=0,
                                               ignore_index=2)[0]


@register_test_case(module_factory=lambda: NllLossModule())
def NllLossModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3), torch.tensor([0, 1]))


class NllLossModule_ignore_index_out_of_bounds(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1], torch.int64, True),
    ])
    # No target is ignored here, since the ignored index is out of bounds.
    def forward(self, x, y):
        return torch.ops.aten.nll_loss_forward(self=x,
                                               target=y,
                                               weight=None,
                                               reduction=0,
                                               ignore_index=10)[0]


@register_test_case(
    module_factory=lambda: NllLossModule_ignore_index_out_of_bounds())
def NllLossModule_ignore_index(module, tu: TestUtils):
    module.forward(tu.rand(2, 3), torch.tensor([0, 1]))
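As a standalone sanity check of what these tests expect (my own sketch, independent of the e2e harness): with `ignore_index=2` and targets `[0, 1]`, no target matches the ignored index, so each loss is just the negated selected input entry.

import torch

x = torch.rand(2, 3)
y = torch.tensor([0, 1])
out = torch.ops.aten.nll_loss_forward(self=x, target=y, weight=None,
                                      reduction=0, ignore_index=2)[0]
# Expected semantics: out[i] == -x[i, y[i]] for every non-ignored target.
assert torch.allclose(out, -x[torch.arange(2), y])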
@@ -1573,6 +1573,25 @@ def Torch_AtenMeanOp : Torch_Op<"aten.mean", [
  let assemblyFormat = "$self `,` $dtype attr-dict `:` type($self) `,` type($dtype) `->` type($result)";
}

def Torch_AtenNllLossForwardOp : Torch_Op<"aten.nll_loss_forward", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$target,
    AnyTorchOptionalTensorType:$weight,
    Torch_IntType:$reduction,
    Torch_IntType:$ignore_index
  );
  let results = (outs
    AnyTorchTensorType:$output,
    AnyTorchTensorType:$total_weight
  );
  let assemblyFormat = "$self `,` $target `,` $weight `,` $reduction `,` $ignore_index attr-dict `:` type($self) `,` type($target) `,` type($weight) `,` type($reduction) `,` type($ignore_index) `->` type($output) `,` type($total_weight)";
}

def Torch_AtenUnsqueezeOp : Torch_Op<"aten.unsqueeze", [
    AllowsTypeRefinement
  ]> {
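The summary string above matches the op's TorchScript signature. A quick way to see the op this definition corresponds to (a sketch, not part of the commit) is to script a small function and inspect its graph:

import torch

def f(x, y):
    # Positional order follows the signature in the summary above:
    # (Tensor self, Tensor target, Tensor? weight, int reduction,
    #  int ignore_index) -> (Tensor, Tensor)
    return torch.ops.aten.nll_loss_forward(x, y, None, 0, -100)

# The printed graph contains an `aten::nll_loss_forward` node.
print(torch.jit.script(f).graph)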
@@ -1167,6 +1167,102 @@ public:
};
} // namespace

// Given `input` and `target`, `nll_loss_forward` is computed as:
//   for i in range(0, len(target)):
//     indi = target[i];
//     nll_loss_forward[i] = -(input[i][indi]);
// TODO: The `weight` and `reduction` operands are still to be taken care of.
namespace {
class ConvertAtenNllLossForwardOp
    : public OpConversionPattern<AtenNllLossForwardOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenNllLossForwardOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    Location loc = op->getLoc();
    Value input = adaptor.self();
    Value target = adaptor.target();
    Value weight = adaptor.weight();

    int64_t reduction;
    if (!matchPattern(op.reduction(), m_TorchConstantInt(&reduction)))
      return rewriter.notifyMatchFailure(op, "reduction must be constant");

    // TODO: Handle reduction.
    if (reduction != 0)
      return rewriter.notifyMatchFailure(
          op, "reduction along dimensions is not supported.");

    // TODO: Incorporate the weight argument.
    if (!weight.getType().isa<mlir::torch::Torch::NoneType>())
      return rewriter.notifyMatchFailure(
          op, "unimplemented: the weight operand is not incorporated.");

    Value ignoreIndex = adaptor.ignore_index();
    Value ignoreIndexVal = castIntToIndex(rewriter, loc, ignoreIndex);

    unsigned inputRank = input.getType().cast<RankedTensorType>().getRank();
    unsigned targetRank = target.getType().cast<RankedTensorType>().getRank();

    // TODO: Add support for cases with targetRank != 1, where the `Mean`
    // reduction is required.
    if (inputRank != 2 || targetRank != 1) {
      return rewriter.notifyMatchFailure(
          op, "expected input and target to be rank 2 and 1 respectively");
    }
    RankedTensorType resultType = getTypeConverter()
                                      ->convertType(op->getResult(0).getType())
                                      .cast<RankedTensorType>();

    Type elementType = resultType.getElementType();

    Value targetDim = getDimOp(rewriter, loc, target, 0);
    Value initTensor0 =
        createZeroInitTensor(rewriter, loc, {targetDim}, elementType);
    Value zeroVal = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementType));

    SmallVector<AffineExpr> targetExpr;
    targetExpr.push_back(rewriter.getAffineDimExpr(0));
    SmallVector<StringRef> iteratorTypes{getParallelIteratorTypeName()};
    auto indexingMaps = AffineMap::inferFromExprList({targetExpr, targetExpr});
    Value finalRes =
        rewriter
            .create<linalg::GenericOp>(
                loc, resultType, ValueRange{target}, initTensor0,
                /*indexingMaps=*/indexingMaps,
                /*iteratorTypes=*/iteratorTypes,
                [&](OpBuilder &b, Location loc, ValueRange args) {
                  Value indTarget = rewriter.create<arith::IndexCastOp>(
                      loc, rewriter.getIndexType(), args[0]);
                  Value indI = rewriter.create<linalg::IndexOp>(loc, 0);

                  // The final result is given by:
                  // final_res = (indTarget == ignoreIndexVal)
                  //                 ? 0 : -(input[indI][indTarget])
                  Value cmpEq = rewriter.create<arith::CmpIOp>(
                      loc, arith::CmpIPredicate::eq, indTarget,
                      ignoreIndexVal);
                  Value result = rewriter.create<tensor::ExtractOp>(
                      loc, input, ValueRange{indI, indTarget});
                  Value negate =
                      rewriter.create<arith::NegFOp>(loc, elementType, result);
                  Value selectFinal = rewriter.create<mlir::SelectOp>(
                      loc, cmpEq, zeroVal, negate);
                  b.create<linalg::YieldOp>(loc, selectFinal);
                })
            .getResult(0);

    // TODO: Update the second result tensor, `total_weight`.
    Value totalWeight = createZeroInitTensor(rewriter, loc, {}, elementType);
    rewriter.replaceOp(op, {finalRes, totalWeight});
    return success();
  }
};
} // namespace

namespace {
// See comments in convertMmOp and the heading for this section for general
// considerations. This function needs to be auto-generated.
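Restating the lowering in plain Python (a reference sketch of mine, under the same restrictions the pattern enforces: `reduction == 0`, `weight == None`, rank-2 input, rank-1 target), the `linalg.generic` above computes:

def nll_loss_forward_ref(input, target, ignore_index):
    """Reference for the restricted case the pattern handles."""
    output = [0.0 if t == ignore_index else -input[i][t]
              for i, t in enumerate(target)]
    # The second result is a rank-0 zero placeholder, matching the TODO above.
    total_weight = 0.0
    return output, total_weight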
@@ -3372,6 +3468,8 @@ public:
    patterns.add<ConvertAtenNumelOp>(typeConverter, context);
    target.addIllegalOp<AtenSliceTensorOp>();
    patterns.add<ConvertAtenSliceTensorOp>(typeConverter, context);
    target.addIllegalOp<AtenNllLossForwardOp>();
    patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
@@ -454,8 +454,11 @@ public:
      return visitAtenAddCLikeOp(op, operands);
    } else if (auto scalarOp = dyn_cast<AtenAddIntOp>(op)) {
      return visitBinaryScalarOp(scalarOp);
    } else if (auto nllForwardOp = dyn_cast<AtenNllLossForwardOp>(op)) {
      return visitAtenNllLossForwardOp(nllForwardOp, operands);
    }

    // Otherwise, this is an unknown operation. Just mark all results as
    // having reached a pessimistic fixpoint.
    return markAllPessimisticFixpoint(op->getResults());
@@ -580,6 +583,10 @@ private:
  ChangeResult
  visitAten_SoftmaxOp(Aten_SoftmaxOp op,
                      ArrayRef<LatticeElement<ValueKnowledge> *> operands);

  ChangeResult
  visitAtenNllLossForwardOp(AtenNllLossForwardOp op,
                            ArrayRef<LatticeElement<ValueKnowledge> *> operands);
};
} // namespace
@@ -927,6 +934,38 @@ ChangeResult TypeAnalyzer::visitAtenSqueezeOp(
  return getLatticeElement(op.getResult()).join(knowledge);
}

ChangeResult TypeAnalyzer::visitAtenNllLossForwardOp(
    AtenNllLossForwardOp op,
    ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  auto self = operands[0]->getValue();

  // Knowledge of the shape and dtype of the first result.
  auto outputKnowledge =
      ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
  outputKnowledge.dtype = self.dtype;
  int64_t reduction;
  unsigned resultRank = self.sizes.size();

  // Knowledge of the shape and dtype of the second result, which is always
  // a rank-0 tensor.
  auto totalWeightKnowledge =
      ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
  totalWeightKnowledge.dtype = self.dtype;
  totalWeightKnowledge.sizes.resize(0, kUnknownSize);
  totalWeightKnowledge.hasSizes = true;

  if (self.hasSizes &&
      matchPattern(op.reduction(), m_TorchConstantInt(&reduction))) {
    // reduction == 1 selects the `Mean` reduction, which removes one more
    // dimension from the result.
    resultRank = reduction == 1 ? resultRank - 1 : resultRank;
  }
  outputKnowledge.sizes.resize(resultRank - 1, kUnknownSize);
  outputKnowledge.hasSizes = true;
  auto resultLattice = getLatticeElement(op.getResult(0)).join(outputKnowledge);
  resultLattice |=
      getLatticeElement(op.getResult(1)).join(totalWeightKnowledge);
  return resultLattice;
}

ChangeResult TypeAnalyzer::visitAtenUnsqueezeOp(
    AtenUnsqueezeOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  auto operand = operands[0]->getValue();
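The shape rule implemented above can be restated as a small helper (my own sketch): the first result drops one dimension relative to the input, the `Mean` reduction (`reduction == 1`) folds away one more, and the second result is always rank 0.

def nll_loss_forward_output_rank(input_rank: int, reduction: int) -> int:
    rank = input_rank
    if reduction == 1:  # Mean reduction removes one more dimension.
        rank -= 1
    return rank - 1  # total_weight, the second result, is always rank 0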
@@ -523,6 +523,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
    emit("aten::sqrt : (Tensor) -> (Tensor)")
    emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)")
    emit("aten::mean : (Tensor, int?) -> (Tensor)")
    emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")

    # Misc tensor ops.
    emit("aten::unsqueeze : (Tensor, int) -> (Tensor)")