add argmax lowering

Add argmax lowering from torch to linalg
pull/364/head snapshot-20211013.20
dan 2021-09-07 17:18:10 +00:00 committed by Yi Zhang
parent 19e9fc4ef1
commit 7750d2173a
7 changed files with 277 additions and 12 deletions

View File

@ -0,0 +1,66 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import torch
from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
# ==============================================================================
class ArgmaxModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.argmax(a)
@register_test_case(module_factory=lambda: ArgmaxModule())
def ArgmaxModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
# ==============================================================================
class ArgmaxWithDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, a):
return torch.argmax(a, dim=1)
@register_test_case(module_factory=lambda: ArgmaxWithDimModule())
def ArgmaxModule_with_dim(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))
# ==============================================================================
class ArgmaxKeepDimsModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.argmax(a, 0, True)
@register_test_case(module_factory=lambda: ArgmaxKeepDimsModule())
def ArgmaxModule_keepDim(module, tu: TestUtils):
module.forward(tu.rand(4, 6))

View File

@ -34,6 +34,7 @@ from . import batchnorm
from . import quantized_models from . import quantized_models
from . import elementwise from . import elementwise
from . import reduction from . import reduction
from . import argmax
def _get_argparse(): def _get_argparse():
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external'] config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']

View File

@ -1349,6 +1349,22 @@ def Torch_AtenArangeStartOp : Torch_Op<"aten.arange.start", [
let assemblyFormat = "$start `,` $end `,` $dtype `,` $layout `,` $device `,` $pin_memory attr-dict `:` type($start) `,` type($end) `,` type($dtype) `,` type($layout) `,` type($device) `,` type($pin_memory) `->` type($result)"; let assemblyFormat = "$start `,` $end `,` $dtype `,` $layout `,` $device `,` $pin_memory attr-dict `:` type($start) `,` type($end) `,` type($dtype) `,` type($layout) `,` type($device) `,` type($pin_memory) `->` type($result)";
} }
def Torch_AtenArgmaxOp : Torch_Op<"aten.argmax", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::argmax : (Tensor, int?, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
TorchOptionalIntType:$dim,
Torch_BoolType:$keepdim
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $dim `,` $keepdim attr-dict `:` type($self) `,` type($dim) `,` type($keepdim) `->` type($result)";
}
def Torch_AtenContiguousOp : Torch_Op<"aten.contiguous", [ def Torch_AtenContiguousOp : Torch_Op<"aten.contiguous", [
AllowsTypeRefinement AllowsTypeRefinement
]> { ]> {

View File

@ -800,7 +800,7 @@ public:
// of *internal* compiler invariants, and for a user manifests as a compiler // of *internal* compiler invariants, and for a user manifests as a compiler
// crash in the worst case (such as we try to canonicalize/fold/print the // crash in the worst case (such as we try to canonicalize/fold/print the
// invalid op before the verifier gets to see it -- also release builds of a // invalid op before the verifier gets to see it -- also release builds of a
// mature copmiler usually have the verifier turned off for compile time // mature compiler usually have the verifier turned off for compile time
// reasons). // reasons).
// //
// The compiler cannot crash even if the user wrote an erroneous program! // The compiler cannot crash even if the user wrote an erroneous program!
@ -1141,12 +1141,161 @@ static Value createLinalgPayloadCalculationForReduceOp(
if (isa<AtenSumOp, AtenSumDimIntListOp>(op) && if (isa<AtenSumOp, AtenSumDimIntListOp>(op) &&
elementType.isa<mlir::FloatType>()) elementType.isa<mlir::FloatType>())
return b.create<AddFOp>(loc, payloadArgs); return b.create<AddFOp>(loc, payloadArgs);
op->emitError("unimplemented lowering in " op->emitError("unimplemented lowering in "
"createLinalgPayloadCalculationForReduceOp"); "createLinalgPayloadCalculationForReduceOp");
return nullptr; return nullptr;
} }
namespace {
// Aten argmax lowering represents the ArgMax op as an linalg.indexed_generic
// op, producing two output buffers.
//
// The first output buffer contains the index of the found maximum value. It is
// initialized to 0 and is resulting integer type.
//
// The second output buffer contains the maximum value found. It is initialized
// to the minimum representable value of the input element type. After being
// populated by indexed_generic, this buffer is disgarded as only the index is
// requested.
//
// The indexed_generic op updates both the maximum value and index if the
// current value exceeds the running max.
class ConvertAtenArgmaxOp : public OpConversionPattern<AtenArgmaxOp> {
public:
using OpConversionPattern<AtenArgmaxOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenArgmaxOp argmaxOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Location loc = argmaxOp.getLoc();
AtenArgmaxOp::Adaptor adaptor(operands);
Value input = adaptor.self();
RankedTensorType resultType =
getTypeConverter()
->convertType(argmaxOp.getResult().getType())
.cast<RankedTensorType>();
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
Type outElementType = resultType.getElementType();
if (!outElementType.isa<IntegerType>())
return rewriter.notifyMatchFailure(
argmaxOp,
"aten.arg_max to linalg.* requires integer-like result type");
bool keepDim = false;
if (!matchPattern(argmaxOp.keepdim(), m_TorchConstantBool(&keepDim)))
return failure();
int64_t dim;
if (!matchPattern(argmaxOp.dim(), m_TorchConstantInt(&dim))) {
if (!argmaxOp.dim().getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
argmaxOp,
"aten.arg_max to linalg.* requires int or NoneType value for Dim");
// For pytorch, if the value of Dim is None, argmax
// returns the index of the max value of the flattened input tensor,
// so here we flatten the input tensor.
SmallVector<ReassociationIndices> reassociation(1);
for (auto i : llvm::seq<int64_t>(0, inputType.getRank()))
reassociation[0].push_back(i);
input = rewriter.create<linalg::TensorCollapseShapeOp>(
argmaxOp->getLoc(), input, reassociation);
// Becomes 0 for flattened tensor.
dim = 0;
// Recast to fix shape.
inputType = input.getType().cast<RankedTensorType>();
}
Type inElementType = inputType.getElementType();
if (!inElementType.isa<mlir::FloatType>()) {
return rewriter.notifyMatchFailure(
argmaxOp,
"aten.arg_max to linalg.* requires Float input element type");
}
// Constant op to account for the reduction along dim.
auto c1 = rewriter.create<mlir::ConstantIndexOp>(loc, /*value=*/1);
SmallVector<Value> resultShape;
for (int64_t i = 0; i < inputType.getRank(); i++) {
if (dim != i) {
auto currentDimSize = rewriter.create<tensor::DimOp>(loc, input, i);
resultShape.push_back(currentDimSize);
} else if (keepDim)
resultShape.push_back(c1);
}
// First fill the output buffer for the index.
Value filledTensorIdx =
createZeroInitTensor(rewriter, loc, resultShape, outElementType);
// Second fill the output buffer for the running max.
Value initTensorMax =
rewriter.create<linalg::InitTensorOp>(loc, resultShape, inElementType)
.result();
FloatAttr fillValueMaxAttr = rewriter.getFloatAttr(
inElementType,
APFloat::getLargest(
inElementType.cast<mlir::FloatType>().getFloatSemantics(), true));
Value fillValueMax = rewriter.create<ConstantOp>(loc, fillValueMaxAttr);
Value filledTensorMax =
rewriter.create<linalg::FillOp>(loc, fillValueMax, initTensorMax)
.result();
// Create the affine expressions that will be used to
// iterate over the input and output tensors.
// Here we also set the type of iterator: parallel or reduction.
SmallVector<AffineExpr> exprs;
SmallVector<StringRef> iteratorTypes;
SmallVector<AffineExpr> resultExprs;
for (auto size : llvm::enumerate(inputType.getShape())) {
exprs.push_back(rewriter.getAffineDimExpr(size.index()));
if (unsigned(dim) == size.index()) {
iteratorTypes.push_back(getReductionIteratorTypeName());
// If `keepDim`, create affine map to the first element
// in the current dimension.
if (keepDim)
resultExprs.push_back(rewriter.getAffineConstantExpr(0));
} else {
iteratorTypes.push_back(getParallelIteratorTypeName());
resultExprs.push_back(rewriter.getAffineDimExpr(size.index()));
}
}
auto maps = AffineMap::inferFromExprList({exprs, resultExprs, resultExprs});
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc,
ArrayRef<Type>({filledTensorIdx.getType(), filledTensorMax.getType()}),
input, ValueRange({filledTensorIdx, filledTensorMax}), maps,
iteratorTypes,
[&](OpBuilder &nestedBuilder, Location nestedLoc,
ValueRange blockArgs) {
Value newValue = blockArgs[0];
Value oldIndex = blockArgs[1];
Value oldValue = blockArgs[2];
Value newIndex = rewriter.create<IndexCastOp>(
nestedLoc, oldIndex.getType(),
rewriter.create<linalg::IndexOp>(loc, dim));
Value predicate;
if (inElementType.isa<mlir::FloatType>())
predicate = rewriter.create<mlir::CmpFOp>(
nestedLoc, CmpFPredicate::OGT, newValue, oldValue);
auto resultMax = rewriter.create<mlir::SelectOp>(nestedLoc, predicate,
newValue, oldValue);
auto resultIndex = rewriter.create<mlir::SelectOp>(
nestedLoc, predicate, newIndex, oldIndex);
nestedBuilder.create<linalg::YieldOp>(
nestedLoc, ValueRange({resultIndex, resultMax}));
});
// This cast is required to fix the shape in the case of keepDim=True
rewriter.replaceOpWithNewOp<tensor::CastOp>(argmaxOp, resultType,
linalgOp.getResult(0));
return success();
}
};
} // namespace
namespace { namespace {
// Converts an elementwise op. // Converts an elementwise op.
@ -1896,6 +2045,8 @@ public:
patterns.add<ConvertAtenGatherOp>(typeConverter, context); patterns.add<ConvertAtenGatherOp>(typeConverter, context);
target.addIllegalOp<AtenLayerNormOp>(); target.addIllegalOp<AtenLayerNormOp>();
patterns.add<ConvertAtenLayerNormOp>(typeConverter, context); patterns.add<ConvertAtenLayerNormOp>(typeConverter, context);
target.addIllegalOp<AtenArgmaxOp>();
patterns.add<ConvertAtenArgmaxOp>(typeConverter, context);
if (failed(applyPartialConversion(getOperation(), target, if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) std::move(patterns))))

View File

@ -138,7 +138,6 @@ Type parseTensorType(MLIRContext *context, DialectAsmParser &parser,
sizes.push_back(-1); sizes.push_back(-1);
continue; continue;
} }
int64_t size; int64_t size;
auto optionalInt = parser.parseOptionalInteger(size); auto optionalInt = parser.parseOptionalInteger(size);
if (optionalInt.hasValue()) { if (optionalInt.hasValue()) {

View File

@ -270,6 +270,8 @@ public:
} else if (auto meanDim = dyn_cast<AtenMeanDimOp>(op)) { } else if (auto meanDim = dyn_cast<AtenMeanDimOp>(op)) {
return visitReductionAlongDimIntListOp(meanDim, meanDim.dim(), return visitReductionAlongDimIntListOp(meanDim, meanDim.dim(),
meanDim.keepdim(), operands); meanDim.keepdim(), operands);
} else if (auto argmax = dyn_cast<AtenArgmaxOp>(op)) {
return visitAtenArgmaxOp(argmax, operands);
} else if (auto anyDim = dyn_cast<AtenAnyDimOp>(op)) { } else if (auto anyDim = dyn_cast<AtenAnyDimOp>(op)) {
return visitAtenAnyDimOp(anyDim, operands); return visitAtenAnyDimOp(anyDim, operands);
} else if (auto view = dyn_cast<AtenViewOp>(op)) { } else if (auto view = dyn_cast<AtenViewOp>(op)) {
@ -397,6 +399,9 @@ private:
Operation *op, Value dim, Value keepdim, Operation *op, Value dim, Value keepdim,
ArrayRef<LatticeElement<ValueKnowledge> *> operands); ArrayRef<LatticeElement<ValueKnowledge> *> operands);
ChangeResult ChangeResult
visitAtenArgmaxOp(AtenArgmaxOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
ChangeResult
visitAtenAnyDimOp(AtenAnyDimOp op, visitAtenAnyDimOp(AtenAnyDimOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands); ArrayRef<LatticeElement<ValueKnowledge> *> operands);
template <typename OpTy> template <typename OpTy>
@ -733,8 +738,8 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntListOp(
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
knowledge.dtype = input.dtype; knowledge.dtype = input.dtype;
llvm::SmallVector<int64_t> dimList; llvm::SmallVector<int64_t> dimList;
bool keepdimBool; bool keepDim;
if (matchPattern(keepdim, m_TorchConstantBool(&keepdimBool))) { if (matchPattern(keepdim, m_TorchConstantBool(&keepDim))) {
knowledge.hasSizes = true; knowledge.hasSizes = true;
int64_t inputRank = input.sizes.size(); int64_t inputRank = input.sizes.size();
// TODO: This is not safe. Need to check the list users and use aliasing // TODO: This is not safe. Need to check the list users and use aliasing
@ -745,7 +750,7 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntListOp(
DenseSet<int64_t> dimSet(dimList.begin(), dimList.end()); DenseSet<int64_t> dimSet(dimList.begin(), dimList.end());
for (auto en : llvm::enumerate(input.sizes)) { for (auto en : llvm::enumerate(input.sizes)) {
if (dimSet.contains(en.index())) { if (dimSet.contains(en.index())) {
if (keepdimBool) if (keepDim)
knowledge.sizes.push_back(1); knowledge.sizes.push_back(1);
} else { } else {
knowledge.sizes.push_back(en.value()); knowledge.sizes.push_back(en.value());
@ -753,12 +758,39 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntListOp(
} }
} else if (auto listConstruct = dim.getDefiningOp<PrimListConstructOp>()) { } else if (auto listConstruct = dim.getDefiningOp<PrimListConstructOp>()) {
auto sizes = listConstruct.elements(); auto sizes = listConstruct.elements();
knowledge.sizes.resize(keepdimBool ? inputRank : inputRank - sizes.size(), knowledge.sizes.resize(keepDim ? inputRank : inputRank - sizes.size(),
kUnknownSize); kUnknownSize);
} }
} }
return getLatticeElement(op->getResult(0)).join(knowledge); return getLatticeElement(op->getResult(0)).join(knowledge);
} }
ChangeResult TypeAnalyzer::visitAtenArgmaxOp(
AtenArgmaxOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto input = operands[0]->getValue();
auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
knowledge.dtype = IntegerType::get(op->getContext(), 64, IntegerType::Signed);
int64_t dim;
bool keepDim;
if (matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim))) {
int64_t inputRank = input.sizes.size();
knowledge.hasSizes = true;
if (matchPattern(op.dim(), m_TorchConstantInt(&dim))) {
knowledge.sizes = input.sizes;
dim = toPositiveDim(dim, inputRank);
if (isValidDim(dim, inputRank)) {
if (keepDim)
knowledge.sizes[dim] = 1;
else
knowledge.sizes.erase(knowledge.sizes.begin() + dim);
}
} else if (op.dim().getType().isa<IntegerType>())
knowledge.sizes.resize(keepDim ? inputRank : inputRank - 1,
kUnknownSize);
}
// If dim is no kind of Integer, keepDim is ignored,
// and the result will bea rank-0 tensor
return getLatticeElement(op->getResult(0)).join(knowledge);
}
ChangeResult TypeAnalyzer::visitAtenAnyDimOp( ChangeResult TypeAnalyzer::visitAtenAnyDimOp(
AtenAnyDimOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) { AtenAnyDimOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
@ -767,22 +799,21 @@ ChangeResult TypeAnalyzer::visitAtenAnyDimOp(
ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
knowledge.dtype = input.dtype; knowledge.dtype = input.dtype;
int64_t dim; int64_t dim;
bool keepdimBool; bool keepDim;
if (matchPattern(op.keepdim(), m_TorchConstantBool(&keepdimBool))) { if (matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim))) {
int64_t inputRank = input.sizes.size(); int64_t inputRank = input.sizes.size();
knowledge.hasSizes = true; knowledge.hasSizes = true;
if (matchPattern(op.dim(), m_TorchConstantInt(&dim))) { if (matchPattern(op.dim(), m_TorchConstantInt(&dim))) {
knowledge.sizes = input.sizes; knowledge.sizes = input.sizes;
dim = toPositiveDim(dim, inputRank); dim = toPositiveDim(dim, inputRank);
if (isValidDim(dim, inputRank)) { if (isValidDim(dim, inputRank)) {
if (keepdimBool) if (keepDim)
knowledge.sizes[dim] = 1; knowledge.sizes[dim] = 1;
else else
knowledge.sizes.erase(knowledge.sizes.begin() + dim); knowledge.sizes.erase(knowledge.sizes.begin() + dim);
} }
} else { } else {
knowledge.sizes.resize(keepdimBool ? inputRank : inputRank - 1, knowledge.sizes.resize(keepDim ? inputRank : inputRank - 1, kUnknownSize);
kUnknownSize);
} }
} }
return getLatticeElement(op->getResult(0)).join(knowledge); return getLatticeElement(op->getResult(0)).join(knowledge);

View File

@ -510,6 +510,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
emit("aten::any.dim : (Tensor, int, bool) -> (Tensor)") emit("aten::any.dim : (Tensor, int, bool) -> (Tensor)")
emit("aten::arange : (Scalar, int?, int?, Device?, bool?) -> (Tensor)") emit("aten::arange : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::arange.start : (Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)") emit("aten::arange.start : (Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)")
emit("aten::contiguous : (Tensor, int) -> (Tensor)") emit("aten::contiguous : (Tensor, int) -> (Tensor)")
emit("aten::copy_ : (Tensor, Tensor, bool) -> (Tensor)") emit("aten::copy_ : (Tensor, Tensor, bool) -> (Tensor)")
emit("aten::detach : (Tensor) -> (Tensor)") emit("aten::detach : (Tensor) -> (Tensor)")