From 7750d2173a42a9f2f09b2c87b4a4144a3e63b407 Mon Sep 17 00:00:00 2001
From: dan
Date: Tue, 7 Sep 2021 17:18:10 +0000
Subject: [PATCH] add argmax lowering

Add argmax lowering from torch to linalg
---
 e2e_testing/torchscript/argmax.py             |  66 ++++++++
 e2e_testing/torchscript/main.py               |   1 +
 .../Dialect/Torch/IR/GeneratedAtenOps.td      |  16 ++
 .../TorchToLinalg/TorchToLinalg.cpp           | 155 +++++++++++++++++-
 lib/Dialect/Torch/IR/TorchTypes.cpp           |   1 -
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  |  49 +++++-
 .../jit_ir/build_tools/torch_ods_gen.py       |   1 +
 7 files changed, 277 insertions(+), 12 deletions(-)
 create mode 100644 e2e_testing/torchscript/argmax.py

diff --git a/e2e_testing/torchscript/argmax.py b/e2e_testing/torchscript/argmax.py
new file mode 100644
index 000000000..575af604c
--- /dev/null
+++ b/e2e_testing/torchscript/argmax.py
@@ -0,0 +1,66 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import torch
+
+from torch_mlir_e2e_test.torchscript.framework import TestUtils
+from torch_mlir_e2e_test.torchscript.registry import register_test_case
+from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+
+# ==============================================================================
+
+class ArgmaxModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return torch.argmax(a)
+
+
+@register_test_case(module_factory=lambda: ArgmaxModule())
+def ArgmaxModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
+
+# ==============================================================================
+
+class ArgmaxWithDimModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.argmax(a, dim=1)
+
+@register_test_case(module_factory=lambda: ArgmaxWithDimModule())
+def ArgmaxModule_with_dim(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+# ==============================================================================
+
+class ArgmaxKeepDimsModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.argmax(a, 0, True)
+
+@register_test_case(module_factory=lambda: ArgmaxKeepDimsModule())
+def ArgmaxModule_keepDim(module, tu: TestUtils):
+    module.forward(tu.rand(4, 6))
+
diff --git a/e2e_testing/torchscript/main.py b/e2e_testing/torchscript/main.py
index 0291e3ac2..de5cdcd3e 100644
--- a/e2e_testing/torchscript/main.py
+++ b/e2e_testing/torchscript/main.py
@@ -34,6 +34,7 @@ from . import batchnorm
 from . import quantized_models
 from . import elementwise
 from . import reduction
+from . import argmax
 
 def _get_argparse():
     config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
index 26bd3aebc..5bc50c3e1 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
@@ -1349,6 +1349,22 @@ def Torch_AtenArangeStartOp : Torch_Op<"aten.arange.start", [
   let assemblyFormat = "$start `,` $end `,` $dtype `,` $layout `,` $device `,` $pin_memory attr-dict `:` type($start) `,` type($end) `,` type($dtype) `,` type($layout) `,` type($device) `,` type($pin_memory) `->` type($result)";
 }
 
+def Torch_AtenArgmaxOp : Torch_Op<"aten.argmax", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::argmax : (Tensor, int?, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    TorchOptionalIntType:$dim,
+    Torch_BoolType:$keepdim
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self `,` $dim `,` $keepdim attr-dict `:` type($self) `,` type($dim) `,` type($keepdim) `->` type($result)";
+}
+
 def Torch_AtenContiguousOp : Torch_Op<"aten.contiguous", [
     AllowsTypeRefinement
   ]> {
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index 62e697889..33acfc3f6 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -800,7 +800,7 @@ public:
   // of *internal* compiler invariants, and for a user manifests as a compiler
   // crash in the worst case (such as we try to canonicalize/fold/print the
   // invalid op before the verifier gets to see it -- also release builds of a
-  // mature copmiler usually have the verifier turned off for compile time
+  // mature compiler usually have the verifier turned off for compile time
   // reasons).
   //
   // The compiler cannot crash even if the user wrote an erroneous program!
@@ -1141,12 +1141,161 @@ static Value createLinalgPayloadCalculationForReduceOp(
   if (isa(op) && elementType.isa())
     return b.create(loc, payloadArgs);
-
   op->emitError("unimplemented lowering in "
                 "createLinalgPayloadCalculationForReduceOp");
   return nullptr;
 }
 
+namespace {
+// The aten::argmax lowering represents the argmax op as a linalg.generic
+// op, producing two output tensors.
+//
+// The first output tensor contains the index of the found maximum value. It is
+// initialized to 0 and has the resulting integer type.
+//
+// The second output tensor contains the maximum value found. It is initialized
+// to the minimum representable value of the input element type. After being
+// populated by the generic op, this tensor is discarded, as only the index is
+// requested.
+//
+// The generic op updates both the maximum value and the index if the
+// current value exceeds the running max.
+class ConvertAtenArgmaxOp : public OpConversionPattern<AtenArgmaxOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AtenArgmaxOp argmaxOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Location loc = argmaxOp.getLoc();
+    AtenArgmaxOp::Adaptor adaptor(operands);
+    Value input = adaptor.self();
+    RankedTensorType resultType =
+        getTypeConverter()
+            ->convertType(argmaxOp.getResult().getType())
+            .cast<RankedTensorType>();
+    RankedTensorType inputType = input.getType().cast<RankedTensorType>();
+    Type outElementType = resultType.getElementType();
+    if (!outElementType.isa<IntegerType>())
+      return rewriter.notifyMatchFailure(
+          argmaxOp,
+          "aten.arg_max to linalg.* requires integer-like result type");
+
+    bool keepDim = false;
+    if (!matchPattern(argmaxOp.keepdim(), m_TorchConstantBool(&keepDim)))
+      return failure();
+
+    int64_t dim;
+    if (!matchPattern(argmaxOp.dim(), m_TorchConstantInt(&dim))) {
+      if (!argmaxOp.dim().getType().isa<Torch::NoneType>())
+        return rewriter.notifyMatchFailure(
+            argmaxOp,
+            "aten.arg_max to linalg.* requires int or NoneType value for Dim");
+      // For pytorch, if the value of Dim is None, argmax
+      // returns the index of the max value of the flattened input tensor,
+      // so here we flatten the input tensor.
+      SmallVector<ReassociationIndices> reassociation(1);
+      for (auto i : llvm::seq<int64_t>(0, inputType.getRank()))
+        reassociation[0].push_back(i);
+      input = rewriter.create<linalg::TensorCollapseShapeOp>(
+          argmaxOp->getLoc(), input, reassociation);
+      // Becomes 0 for flattened tensor.
+      dim = 0;
+      // Recast to fix shape.
+      inputType = input.getType().cast<RankedTensorType>();
+    }
+    Type inElementType = inputType.getElementType();
+    if (!inElementType.isa<mlir::FloatType>()) {
+      return rewriter.notifyMatchFailure(
+          argmaxOp,
+          "aten.arg_max to linalg.* requires Float input element type");
+    }
+
+    // Constant op to account for the reduction along dim.
+    auto c1 = rewriter.create<ConstantIndexOp>(loc, /*value=*/1);
+    SmallVector<Value> resultShape;
+    for (int64_t i = 0; i < inputType.getRank(); i++) {
+      if (dim != i) {
+        auto currentDimSize = rewriter.create<tensor::DimOp>(loc, input, i);
+        resultShape.push_back(currentDimSize);
+      } else if (keepDim)
+        resultShape.push_back(c1);
+    }
+    // First fill the output buffer for the index.
+    Value filledTensorIdx =
+        createZeroInitTensor(rewriter, loc, resultShape, outElementType);
+
+    // Second fill the output buffer for the running max.
+    Value initTensorMax =
+        rewriter.create<linalg::InitTensorOp>(loc, resultShape, inElementType)
+            .result();
+
+    FloatAttr fillValueMaxAttr = rewriter.getFloatAttr(
+        inElementType,
+        APFloat::getLargest(
+            inElementType.cast<mlir::FloatType>().getFloatSemantics(), true));
+
+    Value fillValueMax = rewriter.create<ConstantOp>(loc, fillValueMaxAttr);
+    Value filledTensorMax =
+        rewriter.create<linalg::FillOp>(loc, fillValueMax, initTensorMax)
+            .result();
+
+    // Create the affine expressions that will be used to
+    // iterate over the input and output tensors.
+    // Here we also set the type of iterator: parallel or reduction.
+    SmallVector<AffineExpr> exprs;
+    SmallVector<StringRef> iteratorTypes;
+    SmallVector<AffineExpr> resultExprs;
+    for (auto size : llvm::enumerate(inputType.getShape())) {
+      exprs.push_back(rewriter.getAffineDimExpr(size.index()));
+
+      if (unsigned(dim) == size.index()) {
+        iteratorTypes.push_back(getReductionIteratorTypeName());
+        // If `keepDim`, create an affine map to the first element
+        // in the current dimension.
+        if (keepDim)
+          resultExprs.push_back(rewriter.getAffineConstantExpr(0));
+      } else {
+        iteratorTypes.push_back(getParallelIteratorTypeName());
+        resultExprs.push_back(rewriter.getAffineDimExpr(size.index()));
+      }
+    }
+    auto maps = AffineMap::inferFromExprList({exprs, resultExprs, resultExprs});
+    auto linalgOp = rewriter.create<linalg::GenericOp>(
+        loc,
+        ArrayRef<Type>({filledTensorIdx.getType(), filledTensorMax.getType()}),
+        input, ValueRange({filledTensorIdx, filledTensorMax}), maps,
+        iteratorTypes,
+        [&](OpBuilder &nestedBuilder, Location nestedLoc,
+            ValueRange blockArgs) {
+          Value newValue = blockArgs[0];
+          Value oldIndex = blockArgs[1];
+          Value oldValue = blockArgs[2];
+
+          Value newIndex = rewriter.create<IndexCastOp>(
+              nestedLoc, oldIndex.getType(),
+              rewriter.create<linalg::IndexOp>(loc, dim));
+
+          Value predicate;
+          if (inElementType.isa<mlir::FloatType>())
+            predicate = rewriter.create<CmpFOp>(
+                nestedLoc, CmpFPredicate::OGT, newValue, oldValue);
+          auto resultMax = rewriter.create<SelectOp>(nestedLoc, predicate,
+                                                     newValue, oldValue);
+          auto resultIndex = rewriter.create<SelectOp>(
+              nestedLoc, predicate, newIndex, oldIndex);
+          nestedBuilder.create<linalg::YieldOp>(
+              nestedLoc, ValueRange({resultIndex, resultMax}));
+        });
+
+    // This cast is required to fix the shape in the case of keepDim=True.
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(argmaxOp, resultType,
+                                                linalgOp.getResult(0));
+    return success();
+  }
+};
+} // namespace
 
 namespace {
 // Converts an elementwise op.
@@ -1896,6 +2045,8 @@ public:
     patterns.add(typeConverter, context);
     target.addIllegalOp();
     patterns.add(typeConverter, context);
+    target.addIllegalOp<AtenArgmaxOp>();
+    patterns.add<ConvertAtenArgmaxOp>(typeConverter, context);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp
index c69042b43..8c6221418 100644
--- a/lib/Dialect/Torch/IR/TorchTypes.cpp
+++ b/lib/Dialect/Torch/IR/TorchTypes.cpp
@@ -138,7 +138,6 @@ Type parseTensorType(MLIRContext *context, DialectAsmParser &parser,
         sizes.push_back(-1);
         continue;
       }
-
       int64_t size;
       auto optionalInt = parser.parseOptionalInteger(size);
       if (optionalInt.hasValue()) {
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 70702a741..9d46ecc06 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -270,6 +270,8 @@ public:
     } else if (auto meanDim = dyn_cast<AtenMeanDimOp>(op)) {
       return visitReductionAlongDimIntListOp(meanDim, meanDim.dim(),
                                              meanDim.keepdim(), operands);
+    } else if (auto argmax = dyn_cast<AtenArgmaxOp>(op)) {
+      return visitAtenArgmaxOp(argmax, operands);
     } else if (auto anyDim = dyn_cast<AtenAnyDimOp>(op)) {
       return visitAtenAnyDimOp(anyDim, operands);
     } else if (auto view = dyn_cast<AtenViewOp>(op)) {
@@ -397,6 +399,9 @@ private:
       Operation *op, Value dim, Value keepdim,
       ArrayRef<LatticeElement<ValueKnowledge> *> operands);
   ChangeResult
+  visitAtenArgmaxOp(AtenArgmaxOp op,
+                    ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+  ChangeResult
   visitAtenAnyDimOp(AtenAnyDimOp op,
                     ArrayRef<LatticeElement<ValueKnowledge> *> operands);
   template
@@ -733,8 +738,8 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntListOp(
       ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
   knowledge.dtype = input.dtype;
   llvm::SmallVector<int64_t> dimList;
-  bool keepdimBool;
-  if (matchPattern(keepdim, m_TorchConstantBool(&keepdimBool))) {
+  bool keepDim;
+  if (matchPattern(keepdim, m_TorchConstantBool(&keepDim))) {
     knowledge.hasSizes = true;
     int64_t inputRank = input.sizes.size();
     // TODO: This is not safe. Need to check the list users and use aliasing
@@ -745,7 +750,7 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntListOp(
     DenseSet<int64_t> dimSet(dimList.begin(), dimList.end());
     for (auto en : llvm::enumerate(input.sizes)) {
       if (dimSet.contains(en.index())) {
-        if (keepdimBool)
+        if (keepDim)
           knowledge.sizes.push_back(1);
       } else {
         knowledge.sizes.push_back(en.value());
@@ -753,12 +758,39 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntListOp(
     }
   } else if (auto listConstruct = dim.getDefiningOp<PrimListConstructOp>()) {
     auto sizes = listConstruct.elements();
-    knowledge.sizes.resize(keepdimBool ? inputRank : inputRank - sizes.size(),
+    knowledge.sizes.resize(keepDim ? inputRank : inputRank - sizes.size(),
                            kUnknownSize);
   }
   return getLatticeElement(op->getResult(0)).join(knowledge);
 }
+ChangeResult TypeAnalyzer::visitAtenArgmaxOp(
+    AtenArgmaxOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  auto input = operands[0]->getValue();
+  auto knowledge = ValueKnowledge::getPessimisticValueState(op->getContext());
+  knowledge.dtype = IntegerType::get(op->getContext(), 64, IntegerType::Signed);
+  int64_t dim;
+  bool keepDim;
+  if (matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim))) {
+    int64_t inputRank = input.sizes.size();
+    knowledge.hasSizes = true;
+    if (matchPattern(op.dim(), m_TorchConstantInt(&dim))) {
+      knowledge.sizes = input.sizes;
+      dim = toPositiveDim(dim, inputRank);
+      if (isValidDim(dim, inputRank)) {
+        if (keepDim)
+          knowledge.sizes[dim] = 1;
+        else
+          knowledge.sizes.erase(knowledge.sizes.begin() + dim);
+      }
+    } else if (op.dim().getType().isa<Torch::NoneType>())
+      knowledge.sizes.resize(keepDim ? inputRank : inputRank - 1,
+                             kUnknownSize);
+  }
+  // If dim is not an integer, keepDim is ignored
+  // and the result will be a rank-0 tensor.
+  return getLatticeElement(op->getResult(0)).join(knowledge);
+}
 
 ChangeResult TypeAnalyzer::visitAtenAnyDimOp(
     AtenAnyDimOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
@@ -767,22 +799,21 @@ ChangeResult TypeAnalyzer::visitAtenAnyDimOp(
       ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
   knowledge.dtype = input.dtype;
   int64_t dim;
-  bool keepdimBool;
-  if (matchPattern(op.keepdim(), m_TorchConstantBool(&keepdimBool))) {
+  bool keepDim;
+  if (matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim))) {
     int64_t inputRank = input.sizes.size();
     knowledge.hasSizes = true;
     if (matchPattern(op.dim(), m_TorchConstantInt(&dim))) {
       knowledge.sizes = input.sizes;
       dim = toPositiveDim(dim, inputRank);
       if (isValidDim(dim, inputRank)) {
-        if (keepdimBool)
+        if (keepDim)
           knowledge.sizes[dim] = 1;
         else
           knowledge.sizes.erase(knowledge.sizes.begin() + dim);
       }
     } else {
-      knowledge.sizes.resize(keepdimBool ? inputRank : inputRank - 1,
-                             kUnknownSize);
+      knowledge.sizes.resize(keepDim ? inputRank : inputRank - 1, kUnknownSize);
     }
   }
   return getLatticeElement(op->getResult(0)).join(knowledge);
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 2c3d0bb4f..c8daa0c9b 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -510,6 +510,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
     emit("aten::any.dim : (Tensor, int, bool) -> (Tensor)")
     emit("aten::arange : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")
     emit("aten::arange.start : (Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)")
-> (Tensor)") + emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)") emit("aten::contiguous : (Tensor, int) -> (Tensor)") emit("aten::copy_ : (Tensor, Tensor, bool) -> (Tensor)") emit("aten::detach : (Tensor) -> (Tensor)")