//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "../PassDetail.h" #include "PopulatePatterns.h" #include "Utils.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Matchers.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; static void createLinalgPayloadCalculationForGatherOps( OpBuilder &b, Location loc, Value input, int64_t inputRank, Value index, int64_t dim, int64_t outputRank) { SmallVector indices; for (int i = 0; i < inputRank; i++) { if (i == dim) { indices.push_back(castIntToIndex(b, loc, index)); } else { // `outputRank` might be larger than `inputRank`. The `linalg::IndexOp` // takes in the dimension of the output. Add `inputDimOffset` to // related to the correct dimension of the output for dimension larger // than the given `dim`. int64_t inputDimOffset = i < dim ? 0 : outputRank - inputRank; indices.push_back(b.create(loc, i + inputDimOffset)); } } // Assert index < input.sizes[dim] Value indexLTInputDim = b.create( loc, arith::CmpIPredicate::slt, castIntToIndex(b, loc, index), getDimOp(b, loc, input, dim)); b.create( loc, indexLTInputDim, b.getStringAttr("index must be smaller than dim size")); // Assert index >= 0 Value cst0 = b.create(loc, b.getZeroAttr(index.getType())); Value indexGEThanZero = b.create(loc, arith::CmpIPredicate::sge, index, cst0); b.create(loc, indexGEThanZero, b.getStringAttr("index must be larger or equal to 0")); Value extract = b.create(loc, input, indices); b.create(loc, extract); } namespace { class ConvertAtenGatherOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenGatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op->getLoc(); Value dimValue = op.dim(); int64_t dim; if (!matchPattern(dimValue, m_TorchConstantInt(&dim))) return op.emitError("unimplemented: dim is not constant"); Value indices = adaptor.index(); Value self = adaptor.self(); RankedTensorType newResultTy = getTypeConverter()->convertType(op.getType()).cast(); int64_t rank = newResultTy.getRank(); SmallVector sizes = getTensorSizes(rewriter, loc, indices); Value result = createZeroInitTensor(rewriter, loc, sizes, newResultTy.getElementType()); SmallVector affineMaps(2, rewriter.getMultiDimIdentityMap(rank)); SmallVector iteratorTypes(rank, getParallelIteratorTypeName()); auto genericOp = rewriter .create( loc, result.getType(), indices, result, affineMaps, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { auto index = args[0]; createLinalgPayloadCalculationForGatherOps( b, loc, self, rank, index, dim, rank); }) .getResult(0); rewriter.replaceOpWithNewOp(op, newResultTy, genericOp); return success(); } }; } // namespace namespace { class ConvertAtenEmbeddingOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenEmbeddingOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op->getLoc(); Value weight = adaptor.weight(); Value indices = adaptor.indices(); RankedTensorType newResultType = typeConverter->convertType(op.getType()).cast(); auto weightTy = weight.getType().cast(); if (weightTy.getRank() != 2) return rewriter.notifyMatchFailure(op, "weight must be rank 2"); Value embeddingDim = getDimOp(rewriter, loc, weight, 1); Type elemTy = weightTy.getElementType(); SmallVector sizes = getTensorSizes(rewriter, loc, indices); sizes.push_back(embeddingDim); int64_t resultRank = sizes.size(); auto indicesTy = indices.getType().cast(); int64_t indicesRank = indicesTy.getRank(); SmallVector indicesExprs; for (int i = 0; i < indicesRank; i++) indicesExprs.push_back(rewriter.getAffineDimExpr(i)); auto indicesAffineMap = AffineMap::get( /*dimCount=*/resultRank, /*symbolCount=*/0, indicesExprs, op->getContext()); SmallVector indexingMaps = { indicesAffineMap, rewriter.getMultiDimIdentityMap(resultRank), }; SmallVector iteratorTypes(sizes.size(), getParallelIteratorTypeName()); Value initTensor = rewriter.create(loc, sizes, elemTy); Value embeddingResult = rewriter .create( loc, initTensor.getType(), indices, initTensor, /*indexingMaps=*/indexingMaps, /*iteratorTypes=*/iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { Value index = args[0]; createLinalgPayloadCalculationForGatherOps( b, loc, weight, weightTy.getRank(), index, /*dim=*/0, resultRank); }) .getResult(0); rewriter.replaceOpWithNewOp(op, newResultType, embeddingResult); return success(); } }; } // namespace namespace { // AtenEmbeddingPaddingIdxOp // SUM mode == integer 0 // Sums bags of embeddings together from a weight tensor based on an index and // offset Vector. Example arguments weight = [[1, 3, 5, 3], // [3, 4, 2, 1], // [2, 2, 3, 2], // [0, 4, 2, 1]] // // indices = [0, 2, 3, 1, 2, 3, 2, 1, 0, 1] // offsets = [0, 3, 5] // // output_tensor = initZeroTensor(offsets_length, embedding_size) // // for i in range(offsets_length): <- dim0 // for j in range(indices_length): <- dim1 // for k in range(embedding_size): <- dim2 // if(offsets[i] <= j and j < offsets[i+1]): // output_tensor[i][k] = output_tensor[i][k] + // weight[indices[j]][k] // else: // break // // Indexing maps for linalg::Generic ops // // // indices_indexing_map = (d0, d1, d2) -> (d1) // offset_indexing_map = (d0, d1, d2) -> (d0) // output_indexing_map = (d0, d1, d2) -> (d0, d2) // // TODO: Find an optimal lowering. // current lowering is not optimal for bags of large embeddings. // Since it traverses the output tensor multiple times. // // class ConvertAtenEmbeddingBagPaddingIdxOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenEmbeddingBagPaddingIdxOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op->getLoc(); auto context = op->getContext(); Value weight = adaptor.weight(); Value indices = adaptor.indices(); Value offsets = adaptor.offsets(); Value scaleGradByFreq = op.scale_grad_by_freq(); Value mode = op.mode(); Value sparse = op.sparse(); Value includeLastOffset = op.include_last_offset(); bool scaleGradByFreqBool; if (!matchPattern(scaleGradByFreq, m_TorchConstantBool(&scaleGradByFreqBool))) { return rewriter.notifyMatchFailure( op, "scale_grad_by_freq is expected to be a constant boolean value."); } if (scaleGradByFreqBool) { return rewriter.notifyMatchFailure( op, "Unimplemented: scale_grad_by_freq=True."); } int64_t modeInt; if (!matchPattern(mode, m_TorchConstantInt(&modeInt))) { return rewriter.notifyMatchFailure( op, "mode is expected to be a constant integer value."); } if (modeInt != torch_upstream::EmbeddingBagMode::MODE_SUM) { return rewriter.notifyMatchFailure( op, "Unimplemented: Mean and Max mode are not supported yet for EmbeddingBag."); } bool isSparse; if (!matchPattern(sparse, m_TorchConstantBool(&isSparse))) { return rewriter.notifyMatchFailure( op, "sparse is expected to be a constant boolean value."); } if (isSparse) { return rewriter.notifyMatchFailure( op, "Unimplemented: Sparse mode is not supported yet for EmbeddingBag."); } bool discardLastOffset; if (!matchPattern(includeLastOffset, m_TorchConstantBool(&discardLastOffset))) { return rewriter.notifyMatchFailure( op, "include_last_offset is expected to be a constant boolean value."); } auto weightTy = weight.getType().cast(); if (weightTy.getRank() != 2) return rewriter.notifyMatchFailure(op, "weight must be rank 2"); auto indicesTy = indices.getType().cast(); if (indicesTy.getRank() != 1) return rewriter.notifyMatchFailure(op, "indices must be a vector"); auto offsetsTy = offsets.getType().cast(); if (offsetsTy.getRank() != 1) return rewriter.notifyMatchFailure(op, "offsets much be a vector"); Type weightElemTy = weightTy.getElementType(); int64_t iterationMapDimension = weightTy.getRank() + indicesTy.getRank(); SmallVector indicesExpr; indicesExpr.push_back(mlir::getAffineDimExpr(1, context)); auto indicesIndexingMap = AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0, indicesExpr, context); SmallVector offsetsExpr; offsetsExpr.push_back(mlir::getAffineDimExpr(0, context)); auto offsetIndexingMap = AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0, offsetsExpr, context); SmallVector outputExpr; outputExpr.push_back(mlir::getAffineDimExpr(0, context)); outputExpr.push_back(mlir::getAffineDimExpr(2, context)); auto outputIndexingMap = AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0, outputExpr, context); SmallVector indexingMaps = { indicesIndexingMap, offsetIndexingMap, outputIndexingMap, }; // Reduce along the indices dim SmallVector iteratorTypes({getParallelIteratorTypeName(), getReductionIteratorTypeName(), getParallelIteratorTypeName()}); Value embeddingDim = getDimOp(rewriter, loc, weight, 1); Value initTensor; Value offsetsLength; Value indicesLength; if (!discardLastOffset) { SmallVector sizes{getDimOp(rewriter, loc, offsets, 0), embeddingDim}; initTensor = createZeroInitTensor(rewriter, loc, sizes, weightElemTy); offsetsLength = getDimOp(rewriter, loc, offsets, 0); indicesLength = getDimOp(rewriter, loc, indices, 0); } else { return rewriter.notifyMatchFailure( op, "Unimplemented: include last offset is not yet " "supported for EmbeddingBag."); } Value embeddingBagResult = rewriter .create( loc, initTensor.getType(), ValueRange{indices, offsets}, initTensor, /*indexingMaps=*/indexingMaps, /*iteratorTypes=*/iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { Value indexInIndices = args[0]; Value offsetsI = args[1]; Value initTensorElem = args[2]; Value indexI = b.create(loc, /*value=*/0); Value indexIToInt = castIndexToInt64(b, loc, indexI); Value one = getConstant( b, loc, 1, mlir::IntegerType::get(getContext(), 64, IntegerType::Signless)); Value offsetIndexPlusOneInt = b.create(loc, indexIToInt, one); Value offsetIndexPlusOne = castIntToIndex(b, loc, offsetIndexPlusOneInt); Value checkLast = b.create( loc, arith::CmpIPredicate::eq, castIndexToInt64(b, loc, offsetsLength), offsetIndexPlusOneInt); Value nextOffset = b.create( loc, checkLast, castIndexToInt64(b, loc, indicesLength), b.create(loc, offsets, offsetIndexPlusOne)); Value indicesIndex = castIndexToInt64( b, loc, b.create(loc, /*value=*/1)); Value offsetLessThanIndicesIndex = b.create( loc, arith::CmpIPredicate::slt, offsetsI, indicesIndex); Value offsetEqualToIndicesIndex = b.create( loc, arith::CmpIPredicate::eq, offsetsI, indicesIndex); Value offsetLessThanOrEqualToIndicesIndex = b.create(loc, offsetLessThanIndicesIndex, offsetEqualToIndicesIndex); Value indicesIndexLessThanNextOffset = b.create(loc, arith::CmpIPredicate::slt, indicesIndex, nextOffset); Value indicesIndexWithinBounds = b.create( loc, offsetLessThanOrEqualToIndicesIndex, indicesIndexLessThanNextOffset); SmallVector indexIntoWeight; indexIntoWeight.push_back( castIntToIndex(b, loc, indexInIndices)); indexIntoWeight.push_back( b.create(loc, /*value=*/2)); Value weightElem = b.create( loc, weight, indexIntoWeight); Value addResult = b.create(loc, weightElem, initTensorElem); Value select = b.create(loc, indicesIndexWithinBounds, addResult, initTensorElem); b.create(loc, select); }) .getResult(0); // cast outputType. auto restulType0 = typeConverter->convertType(op->getResult(0).getType()); Value castedEmbeddingBagResult = rewriter.create(loc, restulType0, embeddingBagResult); // offset2 tensor, this should be an empty tensor for the sum mode SmallVector offsetResultSize; Type offsetElemTy = offsetsTy.getElementType(); Value zeroDim = rewriter.create(loc, /*value=*/0); offsetResultSize.push_back(zeroDim); Value offsetResult = rewriter.create( loc, offsetResultSize, offsetElemTy); auto resultType1 = typeConverter->convertType(op->getResult(1).getType()); Value castedOffsetResult = rewriter.create(loc, resultType1, offsetResult); SmallVector offsetSize = getTensorSizes(rewriter, loc, offsets); // bagsize, vector of size offset with zeros, I think this is always just // a vector of zeros in the sum mode Value bagSize = createZeroInitTensor(rewriter, loc, offsetSize, offsetElemTy); auto resultType2 = typeConverter->convertType(op->getResult(2).getType()); Value castedBagSizeResult = rewriter.create(loc, resultType2, bagSize); // max indices, vector of size offset with zeros, this is also always a // vector of zeros in the sum mode. Its mainly used in the max mode. Value indicesOut = createZeroInitTensor(rewriter, loc, offsetSize, offsetElemTy); auto resultType3 = typeConverter->convertType(op->getResult(3).getType()); Value castedMaxIndices = rewriter.create(loc, resultType3, indicesOut); rewriter.replaceOp(op, {castedEmbeddingBagResult, castedOffsetResult, castedBagSizeResult, castedMaxIndices}); return success(); } }; } // namespace namespace { // Let's say we have an input tensor: initialized with some random values of // size [4, 5, 6]. An index tensor (always 1-d): [0, 2] of size [2], and an // integer argument dim = 1. The size of the output tensor will be [4, 2, 6]. // The approach is as follows: // // for i in range(input.size[0]) // for j in range(index.size[0]) // for k in range(input.size[2]) // indexValue = index[j] // output[i,j,k] = input[i,indexValue,k] class ConvertAtenIndexSelectOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenIndexSelectOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op.getLoc(); Value input = adaptor.self(); Value indices = adaptor.index(); RankedTensorType inputType = input.getType().cast(); RankedTensorType resultType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); Type elementType = resultType.getElementType(); unsigned inputRank = inputType.getRank(); int64_t dimInt; if (!matchPattern(op.dim(), m_TorchConstantInt(&dimInt))) return op->emitError("unimplemented: dim is not constant"); SmallVector resultShape = getTensorSizes(rewriter, loc, input); resultShape[dimInt] = getTensorSizes(rewriter, loc, indices)[0]; Value initTensor = rewriter.create(loc, resultShape, elementType); SmallVector resultExpr; AffineExpr indicesExpr = rewriter.getAffineDimExpr(dimInt); SmallVector iteratorTypes; for (unsigned i = 0; i < inputRank; i++) { resultExpr.push_back(rewriter.getAffineDimExpr(i)); iteratorTypes.push_back(getParallelIteratorTypeName()); } auto indexingMaps = AffineMap::inferFromExprList({indicesExpr, resultExpr}); Value finalRes = rewriter .create( loc, initTensor.getType(), ValueRange{indices}, initTensor, /*indexingMaps=*/indexingMaps, /*iteratorTypes=*/iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { Value index = rewriter.create( loc, rewriter.getIndexType(), args[0]); SmallVector indexTarget; for (unsigned i = 0; i < inputRank; i++) indexTarget.push_back(b.create(loc, i)); indexTarget[dimInt] = index; Value extractedElement = b.create(loc, input, indexTarget); b.create(loc, extractedElement); }) .getResult(0); rewriter.replaceOpWithNewOp(op, resultType, finalRes); return success(); } }; } // namespace // IndexTensor for multiple input tensors broadcasts their shapes to a common // shape and then replaces the indexed dims with the indices given by the // indexing tensors: // x[i_1, i_2, ..., i_M] = result // result[...] = x[i_1[...], i_2[...], ..., i_M[...]] // // where the result shape is computed as follows: // 1. broadcast i_1, i_2, ..., i_M to a common shape // 2. if i_1, i_2, ..., i_M is not contiguous, transpose the broadcasted // shape to the beginning of the result shape, while removing the // unchanged dims (marked by None) // 3. Otherwise replace the indexed dims with the broadcasted shape // // e.g. x: [2, 3] // x[[4], [6, 1]] -> x[6, 4] namespace { class ConvertAtenIndexTensorOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenIndexTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op.getLoc(); Value input = adaptor.self(); Value indices = op.indices(); SmallVector indicesTuple; if (!getListConstructElements(indices, indicesTuple)) { return rewriter.notifyMatchFailure( op, "unimplemented: the indices list is not from a list construct"); } SmallVector indicesVal = getTypeConvertedValues(rewriter, loc, getTypeConverter(), indicesTuple); // Identify the indices with non-None index tensors and determine if they // are contiguous within the input list. SmallVector indexTensorDims; SmallVector indexTensors; bool contiguous = true; for (auto i : llvm::seq(0, (int)indicesVal.size())) { Value index = indicesVal[i]; if (!index || failed(checkNotNone(rewriter, op, index))) continue; if (!indexTensorDims.empty() && indexTensorDims.back() != i - 1) contiguous = false; indexTensorDims.push_back(i); indexTensors.push_back(index); } if (indexTensors.empty()) { return rewriter.notifyMatchFailure( op, "aten.index.Tensor: index tensor must not be None"); } RankedTensorType inputType = input.getType().cast(); RankedTensorType resultType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); Type elementType = resultType.getElementType(); int inputRank = inputType.getRank(); int resultRank = resultType.getRank(); int firstIndexDim = indexTensorDims[0]; int replacedIndexCount = indexTensorDims.size(); int64_t startIndex = contiguous ? firstIndexDim : 0; // Currently we only support statically sized index tensors or dynamic size // index tensors without overlapping dynamic dims when there is more than // one index tensor. // TODO: Add support for dynamic size index tensors with overlapping // dynamic dims. SmallVector broadcastedIndexShape; if (indexTensors.size() > 1) { int maxRank = -1; for (auto indexTensor : indexTensors) { RankedTensorType indexTensorType = indexTensor.getType().cast(); maxRank = std::max(maxRank, (int)indexTensorType.getRank()); } // Because we are assuming static shapes, we can get the shape of the // broadcasted index tensors from the shape refinement pass auto refinedResultShape = resultType.getShape(); for (auto i : llvm::seq(startIndex, startIndex + maxRank)) { auto resultDimSize = refinedResultShape[i]; if (ShapedType::isDynamic(resultDimSize)) { SmallVector dynamicDims; int64_t staticDimSize = -1; for (auto indexTensor : indexTensors) { RankedTensorType indexTensorType = indexTensor.getType().cast(); int64_t indexTensorRank = indexTensorType.getRank(); if ((maxRank - indexTensorRank) > (i - startIndex)) continue; int64_t dim = i - startIndex - maxRank + indexTensorRank; if (ShapedType::isDynamic(indexTensorType.getShape()[dim])) dynamicDims.push_back(getDimOp(rewriter, loc, indexTensor, dim)); else staticDimSize = std::max(staticDimSize, indexTensorType.getShape()[dim]); } if (dynamicDims.size() >= 2) return rewriter.notifyMatchFailure( op, "unimplemented: index tensors with overlapping dynamic dims"); if (staticDimSize > 1) { Value cstStaticDimSize = getConstant(rewriter, loc, staticDimSize, rewriter.getIndexType()); auto equalToRunning = rewriter.create( loc, arith::CmpIPredicate::eq, cstStaticDimSize, dynamicDims[0]); rewriter.create(loc, equalToRunning, "mismatched size for broadcast"); } broadcastedIndexShape.push_back(dynamicDims[0]); } else { broadcastedIndexShape.push_back(getConstant( rewriter, loc, resultDimSize, rewriter.getIndexType())); } } } else { // For a single indexing tensor we can simply use its (dynamic) sizes broadcastedIndexShape = getTensorSizes(rewriter, loc, indexTensors.front()); } // This result shape calculation assumes that there is only one // index tensor, or all of the index tensors are statically shaped. int broadcastRank = broadcastedIndexShape.size(); SmallVector resultShape; if (contiguous) { for (auto i : llvm::seq(0, firstIndexDim)) { resultShape.push_back(getDimOp(rewriter, loc, input, i)); } resultShape.append(broadcastedIndexShape); for (auto i : llvm::seq((int)resultShape.size(), resultRank)) { resultShape.push_back(getDimOp(rewriter, loc, input, i - broadcastRank + replacedIndexCount)); } } else { resultShape.append(broadcastedIndexShape); int j = 0; for (auto i : llvm::seq(0, inputRank)) { if (j < replacedIndexCount && i == indexTensorDims[j]) { j++; continue; } resultShape.push_back(getDimOp(rewriter, loc, input, i)); } } // Initialize the indexing maps for the generic op. Because we are assuming // static shapes for the indexing tensors when there are more than 1, we can // safely map all size 1 dims to 0 in the corresponding affine maps. // TODO: For dynamic shapes, we have to either broadcast the index tensors // to a common shape or introduce some form of control flow. Value initTensor = rewriter.create(loc, resultShape, elementType); SmallVector indexingMaps; SmallVector iteratorTypes; for (auto indexTensor : indexTensors) { RankedTensorType indexTensorType = indexTensor.getType().cast(); auto indexTensorShape = indexTensorType.getShape(); int rank = indexTensorShape.size(); SmallVector indicesExpr; for (auto dim : llvm::seq(0, rank)) { if (indexTensorShape[dim] == 1) { indicesExpr.push_back(rewriter.getAffineConstantExpr(0)); continue; } indicesExpr.push_back( rewriter.getAffineDimExpr(startIndex + broadcastRank - rank + dim)); } indexingMaps.push_back( AffineMap::get(resultRank, 0, indicesExpr, op->getContext())); } SmallVector resultExpr; for (auto i : llvm::seq(0, resultRank)) { resultExpr.push_back(rewriter.getAffineDimExpr(i)); iteratorTypes.push_back(getParallelIteratorTypeName()); } indexingMaps.push_back( AffineMap::get(resultRank, 0, resultExpr, op->getContext())); Value finalRes = rewriter .create( loc, initTensor.getType(), indexTensors, initTensor, indexingMaps, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { SmallVector extractionIndices; if (contiguous) { for (auto i : llvm::seq(0, firstIndexDim)) { extractionIndices.push_back( b.create(loc, i)); } for (auto i : llvm::seq(0, (int)indexTensorDims.size())) { extractionIndices.push_back( castIntToIndex(b, loc, args[i])); } for (auto i : llvm::seq((int)extractionIndices.size(), inputRank)) { extractionIndices.push_back(b.create( loc, i + broadcastRank - replacedIndexCount)); } } else { int indexCount = 0, unchanged = 0; for (auto i : llvm::seq(0, inputRank)) { if (indexCount < replacedIndexCount && i == indexTensorDims[indexCount]) { extractionIndices.push_back( castIntToIndex(b, loc, args[indexCount++])); continue; } extractionIndices.push_back(b.create( loc, broadcastRank + unchanged)); unchanged++; } } Value extractedElement = b.create( loc, input, extractionIndices); b.create(loc, extractedElement); }) .getResult(0); rewriter.replaceOpWithNewOp(op, resultType, finalRes); return success(); } }; } // namespace void mlir::torch::torch_to_linalg:: populateIndirectDataMovementPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { MLIRContext *context = patterns.getContext(); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); }