From ab7aa01b1eed74c13ef2f4af5e9590e72330bc58 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=AD=A6=E5=AE=B6=E4=BC=9F?= <73166454+Vremold@users.noreply.github.com>
Date: Sun, 25 Sep 2022 22:07:46 +0800
Subject: [PATCH] [MHLO] Add torch-to-mhlo e2e support for aten.gather op
 (#1410)

* Add torch-to-mhlo e2e support for aten.gather op

* Add more e2e tests for torch.aten.gather op
---
 e2e_testing/xfail_sets.py                      |  4 +
 lib/Conversion/TorchToMhlo/Gather.cpp          | 91 +++++++++++++++++++
 .../torch_mlir_e2e_test/test_suite/basic.py    | 42 +++++++++
 3 files changed, 137 insertions(+)

diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index 6ea58d3ab..14d3fd75e 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -24,6 +24,10 @@ EAGER_MODE_XFAIL_SET = {
 
 MHLO_PASS_SET = {
     "BroadcastToIdentityCaseStaticModule_basic",
+    "GatherStaticModule_basic",
+    "GatherModule_basic",
+    "Gather2DInputModdule_basic",
+    "GatherRandomIndexModule_basic",
     "ArangeDtypeFloatModule_basic",
     "ArangeDtypeIntModule_basic",
     "ArangeFalsePinMemoryModule_basic",
diff --git a/lib/Conversion/TorchToMhlo/Gather.cpp b/lib/Conversion/TorchToMhlo/Gather.cpp
index a1185c2c1..1b1863347 100644
--- a/lib/Conversion/TorchToMhlo/Gather.cpp
+++ b/lib/Conversion/TorchToMhlo/Gather.cpp
@@ -10,6 +10,7 @@
 #include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
 
 #include "../PassDetail.h"
+#include "./MhloLegalizeUtils.h"
 #include "./PopulatePatterns.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
@@ -166,6 +167,95 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
   return success();
 }
 
+// AtenGatherOp
+template <>
+LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
+    AtenGatherOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = op->getLoc();
+  Value input = adaptor.self();
+  Value index = adaptor.index();
+  auto inputType = input.getType().cast<RankedTensorType>();
+  auto indexType = index.getType().cast<RankedTensorType>();
+  auto indexElemType = indexType.getElementType();
+
+  if (indexType.getRank() != inputType.getRank()) {
+    return op.emitError("`index` and `input` param should have the same rank");
+  }
+  int64_t dim;
+  if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) {
+    return rewriter.notifyMatchFailure(
+        op, "only constant int `dim` param supported");
+  }
+  dim = toPositiveDim(dim, inputType.getRank());
+  if (!isValidDim(dim, inputType.getRank())) {
+    return rewriter.notifyMatchFailure(op, "invalid `dim` param detected");
+  }
+
+  bool sparseGrad = false;
+  if (!matchPattern(op.sparse_grad(), m_TorchConstantBool(&sparseGrad))) {
+    return rewriter.notifyMatchFailure(
+        op, "only constant boolean `sparse_grad` param supported");
+  }
+
+  auto options = getOptions();
+  auto indexShapeInfo =
+      mhlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits);
+  if (failed(indexShapeInfo)) {
+    return rewriter.notifyMatchFailure(
+        op, "failed to get dim sizes of `index` param");
+  }
+  auto intType = rewriter.getIntegerType(options.dimSizeIndexBits);
+  auto one = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getIntegerAttr(intType, 1));
+  auto toConcatIndexShapeValueVec = *indexShapeInfo;
+  toConcatIndexShapeValueVec.push_back(one);
+  auto toConcatIndexShape =
+      rewriter.create<tensor::FromElementsOp>(loc, toConcatIndexShapeValueVec);
+
+  auto indexShape = indexType.getShape();
+  SmallVector<int64_t> toConcatIndexShapeVec(indexShape.begin(),
+                                             indexShape.end());
+  toConcatIndexShapeVec.push_back(1);
+  RankedTensorType toConcatIndexType =
+      RankedTensorType::get(toConcatIndexShapeVec, indexElemType);
+
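+  // aten.gather picks elements one at a time: for a rank-3 input with
+  // dim == 1, output[i][j][k] = input[i][index[i][j][k]][k]. mhlo.gather
+  // instead consumes a tensor of full start coordinates, so build one
+  // coordinate tensor of shape [*indexShape, 1] per input dimension: the
+  // gathered dimension takes its coordinates from the reshaped `index`
+  // values, and every other dimension takes its own position via a
+  // dynamic_iota. Concatenating these rank-many tensors along the trailing
+  // axis yields the start indices consumed by mhlo.gather below.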
+  SmallVector<Value> toConcat;
+  for (int64_t i = 0; i < inputType.getRank(); ++i) {
+    if (i == dim) {
+      toConcat.push_back(rewriter.create<mhlo::DynamicReshapeOp>(
+          loc, toConcatIndexType, index, toConcatIndexShape));
+    } else {
+      toConcat.push_back(rewriter.create<mhlo::DynamicIotaOp>(
+          loc, toConcatIndexType, toConcatIndexShape,
+          rewriter.getI64IntegerAttr(i)));
+    }
+  }
+  auto gatherIndicies = rewriter.create<mhlo::ConcatenateOp>(
+      loc, toConcat, static_cast<uint64_t>(inputType.getRank()));
+  SmallVector<int64_t> sliceSizes(inputType.getRank(), 1);
+
+  int64_t indexVecDim = inputType.getRank();
+  SmallVector<int64_t> collapsedDims;
+  SmallVector<int64_t> startIndexMap;
+  for (int64_t i = 0; i < inputType.getRank(); ++i) {
+    collapsedDims.push_back(i);
+    startIndexMap.push_back(i);
+  }
+
+  auto dimsAttr = mhlo::GatherDimensionNumbersAttr::get(
+      rewriter.getContext(),
+      /*offsetDims=*/{},
+      /*collapsedSliceDims=*/collapsedDims,
+      /*startIndexMap=*/startIndexMap,
+      /*indexVecDim=*/indexVecDim);
+
+  rewriter.replaceOpWithNewOp<mhlo::GatherOp>(
+      op, input, gatherIndicies, dimsAttr,
+      rewriter.getI64TensorAttr(sliceSizes));
+  return success();
+}
+
 void mlir::torch::torch_to_mhlo::populateGatherOpPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target, const TorchToMhloOptions &options) {
@@ -176,5 +266,6 @@ void mlir::torch::torch_to_mhlo::populateGatherOpPatternsAndLegality(
   patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
   INSERT_ATENOP_PATTERN(AtenEmbeddingOp);
   INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
+  INSERT_ATENOP_PATTERN(AtenGatherOp);
 #undef INSERT_ATENOP_PATTERN
 }
diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py
index 68c18456f..4c21c74cd 100644
--- a/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -643,6 +643,48 @@ def GatherModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class GatherRandomIndexModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+        ([-1, -1, -1], torch.int64, True),
+    ])
+    def forward(self, tensor, indices):
+        return torch.gather(tensor, 1, indices)
+
+@register_test_case(module_factory=lambda: GatherRandomIndexModule())
+def GatherRandomIndexModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 4), tu.randint(2, 3, 4, high=3))
+
+
+# ==============================================================================
+
+
+class Gather2DInputModdule(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, tensor, indices):
+        return torch.gather(tensor, 1, indices)
+
+@register_test_case(module_factory=lambda: Gather2DInputModdule())
+def Gather2DInputModdule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 5), torch.tensor([[1, 2, 3], [4, 3, 2]]))
+
+
+# ==============================================================================
+
+
 class GatherStaticModule(torch.nn.Module):
     def __init__(self):