//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "../PassDetail.h" #include "PopulatePatterns.h" #include "Utils.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Matchers.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; static void createLinalgPayloadCalculationForGatherOps( OpBuilder &b, Location loc, Value input, int64_t inputRank, Value index, int64_t dim, int64_t outputRank) { SmallVector indices; for (int i = 0; i < inputRank; i++) { if (i == dim) { indices.push_back(castIntToIndex(b, loc, index)); } else { // `outputRank` might be larger than `inputRank`. The `linalg::IndexOp` // takes in the dimension of the output. Add `inputDimOffset` to // related to the correct dimension of the output for dimension larger // than the given `dim`. int64_t inputDimOffset = i < dim ? 0 : outputRank - inputRank; indices.push_back(b.create(loc, i + inputDimOffset)); } } // Assert index < input.sizes[dim] Value indexLTInputDim = b.create( loc, arith::CmpIPredicate::slt, castIntToIndex(b, loc, index), getDimOp(b, loc, input, dim)); b.create( loc, indexLTInputDim, b.getStringAttr("index must be smaller than dim size")); // Assert index >= 0 Value cst0 = b.create(loc, b.getZeroAttr(index.getType())); Value indexGEThanZero = b.create(loc, arith::CmpIPredicate::sge, index, cst0); b.create(loc, indexGEThanZero, b.getStringAttr("index must be larger or equal to 0")); Value extract = b.create(loc, input, indices); b.create(loc, extract); } namespace { class ConvertAtenGatherOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenGatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op->getLoc(); Value dimValue = op.dim(); int64_t dim; if (!matchPattern(dimValue, m_TorchConstantInt(&dim))) return op.emitError("unimplemented: dim is not constant"); Value indices = adaptor.index(); Value self = adaptor.self(); RankedTensorType newResultTy = getTypeConverter()->convertType(op.getType()).cast(); int64_t rank = newResultTy.getRank(); SmallVector sizes = getTensorSizes(rewriter, loc, indices); Value result = createZeroInitTensor(rewriter, loc, sizes, newResultTy.getElementType()); SmallVector affineMaps(2, rewriter.getMultiDimIdentityMap(rank)); SmallVector iteratorTypes(rank, getParallelIteratorTypeName()); auto genericOp = rewriter .create( loc, result.getType(), indices, result, affineMaps, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { auto index = args[0]; createLinalgPayloadCalculationForGatherOps( b, loc, self, rank, index, dim, rank); }) .getResult(0); rewriter.replaceOpWithNewOp(op, newResultTy, genericOp); return success(); } }; } // namespace namespace { class ConvertAtenEmbeddingOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenEmbeddingOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op->getLoc(); Value weight = adaptor.weight(); Value indices = adaptor.indices(); RankedTensorType newResultType = typeConverter->convertType(op.getType()).cast(); auto weightTy = weight.getType().cast(); if (weightTy.getRank() != 2) return rewriter.notifyMatchFailure(op, "weight must be rank 2"); Value embeddingDim = getDimOp(rewriter, loc, weight, 1); Type elemTy = weightTy.getElementType(); SmallVector sizes = getTensorSizes(rewriter, loc, indices); sizes.push_back(embeddingDim); int64_t resultRank = sizes.size(); auto indicesTy = indices.getType().cast(); int64_t indicesRank = indicesTy.getRank(); SmallVector indicesExprs; for (int i = 0; i < indicesRank; i++) indicesExprs.push_back(rewriter.getAffineDimExpr(i)); auto indicesAffineMap = AffineMap::get( /*dimCount=*/resultRank, /*symbolCount=*/0, indicesExprs, op->getContext()); SmallVector indexingMaps = { indicesAffineMap, rewriter.getMultiDimIdentityMap(resultRank), }; SmallVector iteratorTypes(sizes.size(), getParallelIteratorTypeName()); Value initTensor = rewriter.create(loc, sizes, elemTy); Value embeddingResult = rewriter .create( loc, initTensor.getType(), indices, initTensor, /*indexingMaps=*/indexingMaps, /*iteratorTypes=*/iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { Value index = args[0]; createLinalgPayloadCalculationForGatherOps( b, loc, weight, weightTy.getRank(), index, /*dim=*/0, resultRank); }) .getResult(0); rewriter.replaceOpWithNewOp(op, newResultType, embeddingResult); return success(); } }; } // namespace namespace { // Let's say we have an input tensor: initialized with some random values of // size [4, 5, 6]. An index tensor (always 1-d): [0, 2] of size [2], and an // integer argument dim = 1. The size of the output tensor will be [4, 2, 6]. // The approach is as follows: // // for i in range(input.size[0]) // for j in range(index.size[0]) // for k in range(input.size[2]) // indexValue = index[j] // output[i,j,k] = input[i,indexValue,k] class ConvertAtenIndexSelectOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenIndexSelectOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op.getLoc(); Value input = adaptor.self(); Value indices = adaptor.index(); RankedTensorType inputType = input.getType().cast(); RankedTensorType resultType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); Type elementType = resultType.getElementType(); unsigned inputRank = inputType.getRank(); int64_t dimInt; if (!matchPattern(op.dim(), m_TorchConstantInt(&dimInt))) return op->emitError("unimplemented: dim is not constant"); SmallVector resultShape = getTensorSizes(rewriter, loc, input); resultShape[dimInt] = getTensorSizes(rewriter, loc, indices)[0]; Value initTensor = rewriter.create(loc, resultShape, elementType); SmallVector resultExpr; AffineExpr indicesExpr = rewriter.getAffineDimExpr(dimInt); SmallVector iteratorTypes; for (unsigned i = 0; i < inputRank; i++) { resultExpr.push_back(rewriter.getAffineDimExpr(i)); iteratorTypes.push_back(getParallelIteratorTypeName()); } auto indexingMaps = AffineMap::inferFromExprList({indicesExpr, resultExpr}); Value finalRes = rewriter .create( loc, initTensor.getType(), ValueRange{indices}, initTensor, /*indexingMaps=*/indexingMaps, /*iteratorTypes=*/iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { Value index = rewriter.create( loc, rewriter.getIndexType(), args[0]); SmallVector indexTarget; for (unsigned i = 0; i < inputRank; i++) indexTarget.push_back(b.create(loc, i)); indexTarget[dimInt] = index; Value extractedElement = b.create(loc, input, indexTarget); b.create(loc, extractedElement); }) .getResult(0); rewriter.replaceOpWithNewOp(op, resultType, finalRes); return success(); } }; } // namespace namespace { class ConvertAtenIndexTensorOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenIndexTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op.getLoc(); Value input = adaptor.self(); Value indices = op.indices(); SmallVector indicesTuple; if (!getListConstructElements(indices, indicesTuple)) { return rewriter.notifyMatchFailure( op, "unimplemented: the indices list is not from a list construct"); } SmallVector indicesVal = getTypeConvertedValues(rewriter, loc, getTypeConverter(), indicesTuple); RankedTensorType inputType = input.getType().cast(); RankedTensorType resultType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); Type elementType = resultType.getElementType(); unsigned inputRank = inputType.getRank(); unsigned numIndexTensors = indicesTuple.size(); SmallVector inputShape = getTensorSizes(rewriter, loc, input); // Case 1 : When numIndexTensors == 1 and `input` is a 1-d tensor. // TODO: generalize the implementation for other cases. if (numIndexTensors == 1 && inputRank == 1) { if (failed(checkNotNone(rewriter, op, indicesVal[0]))) return rewriter.notifyMatchFailure(op, "unimplemented None type arg"); unsigned resultRank = indicesVal[0].getType().cast().getRank(); SmallVector resultShape; SmallVector indicesExpr, resultExpr; SmallVector iteratorTypes; for (unsigned i = 0; i < resultRank; i++) resultShape.push_back(getDimOp(rewriter, loc, indicesVal[0], i)); Value initTensor = rewriter.create(loc, resultShape, elementType); for (unsigned i = 0; i < resultRank; i++) { indicesExpr.push_back(rewriter.getAffineDimExpr(i)); resultExpr.push_back(rewriter.getAffineDimExpr(i)); iteratorTypes.push_back(getParallelIteratorTypeName()); } auto indexingMaps = AffineMap::inferFromExprList({indicesExpr, resultExpr}); Value finalRes = rewriter .create( loc, initTensor.getType(), ValueRange{indicesVal[0]}, initTensor, /*indexingMaps=*/indexingMaps, /*iteratorTypes=*/iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { Value indexTarget = castIntToIndex(b, loc, args[0]); Value extractedElement = b.create(loc, input, indexTarget); b.create(loc, extractedElement); }) .getResult(0); rewriter.replaceOpWithNewOp(op, resultType, finalRes); return success(); } else return rewriter.notifyMatchFailure( op, "unimplemented: support for this set of inputs not present"); } }; } // namespace void mlir::torch::torch_to_linalg:: populateIndirectDataMovementPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { MLIRContext *context = patterns.getContext(); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); }