//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeSupport.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Matchers.h" #include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; static int64_t productReduce(ArrayRef a) { return accumulate(a.begin(), a.end(), /*init=*/1, std::multiplies()); } template LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, SmallVector &resultShape, SmallVector &offsets, SmallVector &strides) { Location loc = op.getLoc(); auto input = adaptor.getSelf(); RankedTensorType inputType = input.getType().template cast(); Value zero = rewriter.create(loc, 0); Value one = rewriter.create(loc, 1); Value negone = rewriter.create(loc, -1); int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return op->emitError("unimplemented: dim is not constant"); int64_t inputRank = inputType.getRank(); dim = toPositiveDim(dim, inputRank); if (!isValidDim(dim, inputRank)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); SmallVector inputShape = getTensorSizes(rewriter, loc, input); Value dimSize = inputShape[dim]; Value torchTypeStart = op.getStart(); Value torchTypeEnd = op.getEnd(); Value builtinTypeStart = adaptor.getStart(); Value builtinTypeEnd = adaptor.getEnd(); if (torchTypeStart.getType().isa() || torchTypeEnd.getType().isa()) return rewriter.notifyMatchFailure(op, "unimplemented optional type arg"); Value stepIndex = castIntToIndex(rewriter, loc, adaptor.getStep()); Value start = toPositiveValidDim(rewriter, loc, torchTypeStart, builtinTypeStart, zero, dimSize); // We cannot use to positive valid dim as for negative strides we need to // clamp to `-1` so that the full tensor bounds are available: Value end = builtinTypeEnd; if (torchTypeEnd.getType().isa()) { end = dimSize; } else { end = castIntToIndex(rewriter, loc, end); Value endcmp = rewriter.create( loc, arith::CmpIPredicate::slt, end, zero); Value endadd = rewriter.create(loc, end, dimSize); end = rewriter.create(loc, endcmp, endadd, end); endcmp = rewriter.create(loc, arith::CmpIPredicate::slt, end, zero); end = rewriter.create(loc, endcmp, negone, end); endcmp = rewriter.create(loc, arith::CmpIPredicate::sgt, end, dimSize); end = rewriter.create(loc, endcmp, dimSize, end); } // Slice logic: resultSize = floordiv(end - start + step - 1, step) resultShape = getTensorSizes(rewriter, loc, input); Value len = rewriter.create(loc, end, start); // We check the difference between start and end to determine the total size: Value stepcmp = rewriter.create(loc, arith::CmpIPredicate::sge, stepIndex, zero); Value stepsign = rewriter.create(loc, stepcmp, one, negone); Value resultSize = rewriter.create(loc, len, stepIndex); resultSize = rewriter.create(loc, resultSize, stepsign); resultSize = rewriter.create(loc, resultSize, stepIndex); // Clamp the size to [0, ...]: Value szcmp = rewriter.create(loc, arith::CmpIPredicate::slt, resultSize, zero); resultSize = rewriter.create(loc, szcmp, zero, resultSize); resultShape[dim] = resultSize; strides.resize(inputType.getRank(), one); offsets.resize(inputType.getRank(), zero); offsets[dim] = start; strides[dim] = stepIndex; return success(); } // Example: // input = tensor([[[0., 1., 2., 3.], // [4., 5., 6., 7.]]]) // torch.ops.aten.reflection_pad1d(input, (3,1)); // padding_left = 3, // padding_right = 1 // output = tensor([[[3., 2., 1., 0., 1., 2., 3., 2.], // [7., 6., 5., 4., 5., 6., 7., 6.]]]) // Checks: 1) Each of padding_left and padding_right must be non-negative and // less than the size of the last dimension. // Implementation: a) Construct a result tensor of // shape of input tensor except for the last dimension. // The last dimension of the result tensor should be last // dimension of input tensor + left padding size + right // padding size. Initialize result tensor to all zeros // b) Setup affine map to take slice from input tensor of size // left padding starting from // second column onwards as first column is reflection // boundary // c) Reflect the affine map to have resultant slice reflected // d) Take the slice and write from begining in result tensor // e) write the original tensor next into result tensor // f) Setup affine map to take slice from input tensor of right // padding size ending // at second last column as last column is reflection // boundary for right padding // g) Reflect the affine map to have resultant slice reflected // h) Take the slice and write from left padding size + orignal // tensor last dim size // into result tensor // Uses the ideas/code used for AtenReflectionPad2dOp namespace { class ConvertAtenReflectionPad1dOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenReflectionPad1dOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); SmallVector padInts; if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts))) return rewriter.notifyMatchFailure( op, "only constant int padding range is supported"); MLIRContext *context = rewriter.getContext(); Location loc = op.getLoc(); // Lambda Unitility Functions // Create an Integer expression of x + y auto createIAdd = [&](Value x, Value y) { return rewriter.create(loc, x, y); }; // Create an integer expression of x - y auto createISub = [&](Value x, Value y) { return rewriter.create(loc, x, y); }; enum PadLocation { PAD_LEFT = 0, PAD_RIGHT = 1, PAD_CENTER = 2 }; Value input = adaptor.getSelf(); Type indexType = rewriter.getIndexType(); Value zero = getConstant(rewriter, loc, 0, indexType); Value one = getConstant(rewriter, loc, 1, indexType); auto inputType = llvm::cast(input.getType()); auto outputType = llvm::cast( getTypeConverter()->convertType(op->getResult(0).getType())); unsigned numDims = inputType.getRank(); assert(numDims >= 2 && "Not enough input dimensions"); int64_t lastDim = numDims - 1; SmallVector inputShape = getTensorSizes(rewriter, loc, input); Value lastDimSize = inputShape[lastDim]; // input [1,2,4], then lastDim = 2, // inputShape[2] will give 4 Value tileWidth[3], extractOffset[3], insertOffset[3]; tileWidth[PAD_LEFT] = getConstant(rewriter, loc, padInts[PAD_LEFT], indexType); tileWidth[PAD_RIGHT] = getConstant(rewriter, loc, padInts[PAD_RIGHT], indexType); tileWidth[PAD_CENTER] = lastDimSize; extractOffset[PAD_LEFT] = one; // The offset for the right hand padding "bar" is: // [right] lastDimSize - (tileWidth[PAD_RIGHT] + one) extractOffset[PAD_RIGHT] = createISub(lastDimSize, createIAdd(tileWidth[PAD_RIGHT], one)); extractOffset[PAD_CENTER] = zero; insertOffset[PAD_LEFT] = zero; insertOffset[PAD_RIGHT] = createIAdd(lastDimSize, tileWidth[PAD_LEFT]); insertOffset[PAD_CENTER] = tileWidth[PAD_LEFT]; SmallVector resultShape{inputShape}; // Result's last dimension will have size: // lastDimSize + left padding size + right padding size resultShape[lastDim] = createIAdd(resultShape[lastDim], createIAdd(tileWidth[PAD_LEFT], tileWidth[PAD_RIGHT])); Value resultTensor = createZeroInitTensor(rewriter, loc, resultShape, inputType.getElementType()); // Helper to reflect/reverse the i-th dimension of an affine map without // symbols. This only works if applied on a tensor for which the // corresponding dimension has a statically known size auto reflectDim = [](AffineMap map, unsigned numDims, int64_t i, int64_t size) { AffineExpr d = map.getResult(i); return map.replace(d, size - d - 1, numDims, 0); // left reflect for (3,1) on input shape (1,2,4). // size = 3, lastDim=2, numDims=3 }; SmallVector iteratorTypes{ numDims, utils::IteratorType::parallel}; auto idMap = AffineMap::getMultiDimIdentityMap(numDims, context); SmallVector allOneStrides(numDims, one); auto addTileToResult = [&](PadLocation padPosition) { // Create the tile by extracting a slice from the input tensor. SmallVector extractShape{inputShape}; extractShape[lastDim] = tileWidth[padPosition]; SmallVector extractOffsets(numDims, zero); extractOffsets[lastDim] = extractOffset[padPosition]; Value tile = rewriter.create( loc, input, extractOffsets, extractShape, allOneStrides); auto inputMap = AffineMap::getMultiDimIdentityMap(numDims, context); // Setup the affine map function to resverse the tile along the horizontal // for left and right slices if (padPosition < PAD_CENTER) { inputMap = reflectDim(inputMap, numDims, lastDim, padInts[padPosition]); // Take reflected slice as per inputMap tile = rewriter .create( loc, llvm::cast(tile.getType()), tile, tile, ArrayRef({inputMap, idMap}), iteratorTypes, [](OpBuilder &b, Location nestedLoc, ValueRange args) { b.create(nestedLoc, args[0]); }) .getResult(0); } // Insert the tile in the resultTensor SmallVector insertOffsets(numDims, zero); insertOffsets[lastDim] = insertOffset[padPosition]; resultTensor = rewriter.create( loc, tile, resultTensor, insertOffsets, extractShape, allOneStrides); }; if (padInts[PAD_LEFT] > 0) addTileToResult(PAD_LEFT); if (padInts[PAD_RIGHT] > 0) addTileToResult(PAD_RIGHT); addTileToResult(PAD_CENTER); rewriter.replaceOpWithNewOp(op, outputType, resultTensor); return success(); } }; } // namespace namespace { // Lower the aten.reflection.pad_2d operator into a sequence of // tensor.extract_slice, linalg.generic, and tensor_insert_slice // operations. // To understand the lowering, consider this pytorch example: // // >>> t = torch.tensor([[[1.0,2,3],[4,5,6], [7,8,9]]]) // >>> t // tensor([[[1., 2., 3.], // [4., 5., 6.], // [7., 8., 9.]]]) // >>> torch.ops.aten.reflection_pad2d(t, [1,2,1,2]) // tensor([[[5., 4., 5., 6., 5., 4.], // [2., 1., 2., 3., 2., 1.], // [5., 4., 5., 6., 5., 4.], // [8., 7., 8., 9., 8., 7.], // [5., 4., 5., 6., 5., 4.], // [2., 1., 2., 3., 2., 1.]]]) // // The result can be subdivided into "tiles" corresponding to either // the input tensor (in the center) or slices of the input tensor // whose width and height is determined by the padding sizes and which // are reflected through the side of the central input tensor that // they touch. // In the example above, the tiles are: // top left: [[5]] // top center: [[4,5,6]] // top right: [[5,4]] // center left [[2,1],[5,4],[8,7]] // center: copy of the input tensor // center right: [[2,1],[5,4],[8,7]] // bottom left: [[5,4],[2,1]] // center bottom: [[2,3,2]] // center right: [[2,1]] // // The lowering uses a tensor.extract_slice operation to create each tile, // a linalg.generic for the reflection, and a tensor.insert_slice to // insert the tile in the resulting tensor. class ConvertAtenReflectionPad2dOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenReflectionPad2dOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); SmallVector padInts; if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts))) return rewriter.notifyMatchFailure( op, "only support constant int pad ranges"); Location loc = op.getLoc(); // Some generic helper functions for creating arithmetic operations. auto createAdd = [&](Value x, Value y) { return rewriter.create(loc, x, y); }; auto createAdds = [&](std::initializer_list values) { assert(values.size() >= 2); return std::accumulate(values.begin() + 1, values.end(), data(values)[0], createAdd); }; auto createSub = [&](Value x, Value y) { return rewriter.create(loc, x, y); }; auto createSubs = [&](std::initializer_list values) { assert(values.size() >= 2); return std::accumulate(values.begin() + 1, values.end(), data(values)[0], createSub); }; // Enums for specifying the coordinates of a tile. An "h" prefix // is used to stand for "horizontal" and "v" for "vertical" // throughout. enum PadHLoc { LEFT = 0, RIGHT = 1, HCENTER = 2 }; enum PadVLoc { TOP = 0, BOTTOM = 1, VCENTER = 2 }; // Helper functions for obtaining information about the operator's // padding arguments. auto getHPadArgument = [&](PadHLoc l) { assert(l < HCENTER); return padInts[l]; }; auto getVPadArgument = [&](PadVLoc l) { assert(l < VCENTER); return padInts[2 + l]; }; auto shouldCreateTile = [&](PadVLoc v, PadHLoc h) { if (!(h == HCENTER || getHPadArgument(h) > 0)) return false; if (!(v == VCENTER || getVPadArgument(v) > 0)) return false; return true; }; Value input = adaptor.getSelf(); MLIRContext *context = rewriter.getContext(); auto inputType = llvm::cast(input.getType()); auto outputType = llvm::cast( getTypeConverter()->convertType(op->getResult(0).getType())); unsigned numDims = inputType.getRank(); assert(numDims >= 2 && "Not enough input dimensions"); SmallVector inputShape = getTensorSizes(rewriter, loc, input); int64_t hDim = numDims - 1; int64_t vDim = numDims - 2; Value hDimSize = inputShape[hDim]; Value vDimSize = inputShape[vDim]; assert(getHPadArgument(LEFT) < inputType.getShape()[hDim] && "Left padding too large"); assert(getHPadArgument(RIGHT) < inputType.getShape()[hDim] && "Right padding too large"); assert(getVPadArgument(TOP) < inputType.getShape()[vDim] && "Top padding too large"); assert(getVPadArgument(BOTTOM) < inputType.getShape()[vDim] && "Bottom padding too large"); Type indexType = rewriter.getIndexType(); Value zero = getConstant(rewriter, loc, 0, indexType); Value one = getConstant(rewriter, loc, 1, indexType); Value tileWidth[3]; tileWidth[HCENTER] = hDimSize; for (auto h : {LEFT, RIGHT}) tileWidth[h] = getConstant(rewriter, loc, getHPadArgument(h), indexType); Value tileHeight[3]; tileHeight[VCENTER] = vDimSize; for (auto v : {TOP, BOTTOM}) tileHeight[v] = getConstant(rewriter, loc, getVPadArgument(v), indexType); // Helper to reflect/reverse the i-th dimension of an affine map // without symbols. This only works if applied on a tensor // for which the corresponding dimension has a statically // known size which is good enough since we only apply // it to reflect the padding slices. auto reflectDim = [](AffineMap map, unsigned numDims, int64_t i, int64_t size) { AffineExpr d = map.getResult(i); return map.replace(d, size - d - 1, numDims, 0); }; // Create output shape and tensor SmallVector resultShape{inputShape}; resultShape[vDim] = createAdds({resultShape[vDim], tileHeight[TOP], tileHeight[BOTTOM]}); resultShape[hDim] = createAdds({resultShape[hDim], tileWidth[LEFT], tileWidth[RIGHT]}); Value resultTensor = createZeroInitTensor(rewriter, loc, resultShape, inputType.getElementType()); // Construction of the tiles // Example: central left tile // // Let m the width of the left padding as returned by getHPadargument(LEFT) // and n the size of the input tensor's "horizontal" dimension, i.e. // hDimSize. Assume that the subtensor of the input tensor in the relevant // (i.e. last two) dimensions is: // // x_1,1 x_1,2 ... x_1,m // x_2,1 x_2,2 ... x_2,m // . // . // . // x_n,1 x_n,2 ... x_n,m // // The padding tile consists of the columns 2, ..., m + 1 // of the input in reverse order. The first column gets // skipped because this is the column through which the // reflection happens. // // x_1,m x_1,m-1 ... x_1,2 // x_2,m x_1,m-1 ... x_2,2 // . // . // . // x_n,m x_n,m-1 ... x_n,2 // // The tile will be inserted to the left of the copy of the input tensor // in the output tensor, i.e. with horizontal offset 0. // The top padding determines the vertical offset. // Tiles on the diagonal (e.g. (TOP, LEFT)) are reflected through // two sides, i.e. their columns and rows must be reversed. // Setup information about the tiles // Compute the offsets for extracting the slice from the // input. We need to skip the row or column through which // the tile should be reflected, if any (none for the center tile). Value extractHOffset[3]; extractHOffset[LEFT] = one; extractHOffset[HCENTER] = zero; extractHOffset[RIGHT] = createSubs({hDimSize, tileWidth[RIGHT], one}); Value extractVOffset[3]; extractVOffset[TOP] = one; extractVOffset[VCENTER] = zero; extractVOffset[BOTTOM] = createSubs({vDimSize, tileHeight[BOTTOM], one}); // Compute the horizontal and vertical offsets for inserting // the tiles in the resultTensor. Value insertHOffset[3]; insertHOffset[LEFT] = zero; insertHOffset[HCENTER] = tileWidth[LEFT]; insertHOffset[RIGHT] = createAdd(hDimSize, tileWidth[LEFT]); Value insertVOffset[3]; insertVOffset[TOP] = zero; insertVOffset[VCENTER] = tileHeight[TOP]; insertVOffset[BOTTOM] = createAdd(vDimSize, tileHeight[TOP]); auto shouldHReflect = [](PadHLoc l) { return l == LEFT || l == RIGHT; }; auto shouldVReflect = [](PadVLoc l) { return l == TOP || l == BOTTOM; }; SmallVector iteratorTypes{ numDims, utils::IteratorType::parallel}; auto idMap = AffineMap::getMultiDimIdentityMap(numDims, context); SmallVector allOneStrides(numDims, one); auto createTile = [&](PadVLoc verticalPos, PadHLoc horizontalPos) { // Create the tile by extracting a slice from the input tenor. SmallVector extractShape{inputShape}; extractShape[hDim] = tileWidth[horizontalPos]; extractShape[vDim] = tileHeight[verticalPos]; SmallVector extractOffsets(numDims, zero); extractOffsets[hDim] = extractHOffset[horizontalPos]; extractOffsets[vDim] = extractVOffset[verticalPos]; Value tile = rewriter.create( loc, input, extractOffsets, extractShape, allOneStrides); // Reverse the tile along the horizontal, vertical, or both // dimensions. auto inputMap = AffineMap::getMultiDimIdentityMap(numDims, context); if (shouldHReflect(horizontalPos)) { inputMap = reflectDim(inputMap, numDims, hDim, getHPadArgument(horizontalPos)); } if (shouldVReflect(verticalPos)) { inputMap = reflectDim(inputMap, numDims, vDim, getVPadArgument(verticalPos)); } tile = rewriter .create( loc, llvm::cast(tile.getType()), tile, tile, ArrayRef({inputMap, idMap}), iteratorTypes, [](OpBuilder &b, Location nestedLoc, ValueRange args) { b.create(nestedLoc, args[0]); }) .getResult(0); // Insert the tile in the resultTensor. SmallVector insertOffsets(numDims, zero); insertOffsets[hDim] = insertHOffset[horizontalPos]; insertOffsets[vDim] = insertVOffset[verticalPos]; resultTensor = rewriter.create( loc, tile, resultTensor, insertOffsets, extractShape, allOneStrides); }; for (auto v : {TOP, BOTTOM, VCENTER}) for (auto h : {LEFT, RIGHT, HCENTER}) if (shouldCreateTile(v, h)) createTile(v, h); rewriter.replaceOpWithNewOp(op, outputType, resultTensor); return success(); } }; } // namespace namespace { class ConvertAtenFlattenUsingIntsOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenFlattenUsingIntsOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); int64_t startDim; if (!matchPattern(op.getStartDim(), m_TorchConstantInt(&startDim))) return rewriter.notifyMatchFailure(op, "start_dim must be constant"); int64_t endDim; if (!matchPattern(op.getEndDim(), m_TorchConstantInt(&endDim))) return rewriter.notifyMatchFailure(op, "end_dim must be constant"); auto type = adaptor.getSelf().getType().cast(); auto inputRank = type.getRank(); if (inputRank == 1) { // If input rank is equal to 1, then there's no scope for flattening the // input tensor. rewriter.replaceOp(op, adaptor.getSelf()); return success(); } auto resultType = getTypeConverter()->convertType(op.getType()).cast(); if (startDim < 0) startDim += inputRank; if (endDim < 0) endDim += inputRank; if (inputRank == 0) { SmallVector reassociation; if (!(startDim >= -1 && startDim <= 0 && endDim >= -1 && endDim <= 0)) return rewriter.notifyMatchFailure( op, "start_dim and end_dim must be in [-1, 0] when inputRank is 0"); rewriter.replaceOpWithNewOp( op, resultType, adaptor.getSelf(), reassociation); return success(); } if (startDim < 0 || startDim >= inputRank || endDim < 0 || endDim >= inputRank || startDim > endDim) return rewriter.notifyMatchFailure( op, "statically invalid flattening dim range"); SmallVector reassociation(resultType.getRank()); int j = 0; for (auto i : llvm::seq(0, inputRank)) { reassociation[j].push_back(i); if (i < startDim || i >= endDim) j++; } Value collapsedTensor = rewriter.create( op->getLoc(), adaptor.getSelf(), reassociation); rewriter.replaceOpWithNewOp(op, resultType, collapsedTensor); return success(); } }; } // namespace // Lower aten.unflatten.int into tensor.expand_shape namespace { class ConvertAtenUnflattenIntOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenUnflattenIntOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value self = op.getSelf(); BaseTensorType outputTensorType = op.getType().cast(); if (!outputTensorType.hasSizes()) return rewriter.notifyMatchFailure( op, "unimplemented: output must have known sizes"); std::optional maybeRank = getTensorRank(self); if (!maybeRank) return rewriter.notifyMatchFailure(op, "unimplemented: unranked tensor"); auto inputTensorType = self.getType().cast(); if (!inputTensorType || !inputTensorType.hasSizes()) { return rewriter.notifyMatchFailure(op, "Expected input type having sizes"); } int inputRank = inputTensorType.getSizes().size(); int outputRank = outputTensorType.getSizes().size(); int64_t dimInt; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dimInt))) return rewriter.notifyMatchFailure( op, "unimplemented: requires dim to be constants"); dimInt = toPositiveDim(dimInt, inputRank); if (!isValidDim(dimInt, inputRank)) return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); auto sizesOp = op.getSizes().getDefiningOp(); int numSizes = sizesOp.getNumOperands(); SmallVector reassociations(inputRank); if (inputRank > 0) { for (int i = 0; i < dimInt; ++i) reassociations[i].push_back(i); for (int i = 0; i < numSizes; ++i) reassociations[dimInt].push_back(i + dimInt); for (int i = dimInt + numSizes; i < outputRank; ++i) reassociations[i - numSizes + 1].push_back(i); } auto expandTy = getTypeConverter()->convertType(outputTensorType); auto expand = rewriter .create( loc, expandTy, adaptor.getSelf(), reassociations) .getResult(); rewriter.replaceOp(op, expand); return success(); } }; } // namespace namespace { /// The `ConvertAtenViewOp` conversion pattern converts `aten.View` op to /// one `linalg.TensorExpandShape` op for all expanded dimensions and one /// `linalg.TensorCollapseShape` op for all collapsed dimensions. Cases where /// there is neither an expand or collapse of dimensions (e.g. [2, 3] -> [3, 2]) /// is not handled. Additionally, certain dynamic dimension cases rely on naive /// assumptions or aren't supported. /// TODO: Handle all the other cases of `aten.View` op. class ConvertAtenViewOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; // If one of the two dims arrays has size 1, a mapping is created from the one // dimension of the size-1 array to all the dimensions of the other array. For // example for inputs: xDims = [6], yDims = [2, 3] the result in the indices // arrays will be: xIndices = [0], yIndices = [0, 1]. // // An error is returned if the dimension size of the size-1 array is not equal // to the product of all the dimension sizes in the other array, or if neither // of the arrays is size-1. static LogicalResult mapAllDimsToSingleDim(ArrayRef xDims, ArrayRef yDims, SmallVector &xIndices, SmallVector &yIndices) { if (xDims.empty() || yDims.empty()) return failure(); auto isValidReduction = [](int64_t expectedReductionProduct, ArrayRef arrayToReduce) -> bool { if (llvm::count(arrayToReduce, kUnknownSize) > 0 || expectedReductionProduct == kUnknownSize) return true; return productReduce(arrayToReduce) == expectedReductionProduct; }; if (xDims.size() == 1) { if (!isValidReduction(xDims[0], yDims)) return failure(); xIndices.assign({0}); yIndices.assign(llvm::to_vector(llvm::seq(0, yDims.size()))); return success(); } else if (yDims.size() == 1) { if (!isValidReduction(yDims[0], xDims)) return failure(); yIndices.assign({0}); xIndices.assign(llvm::to_vector(llvm::seq(0, xDims.size()))); return success(); } return failure(); } // Starting from the beginning of the dims arrays, this helper finds the // smallest set of consecutive dims in each array such that the product of the // dim sizes in the two subsets is equal. The indices arrays are populated // with the indices of the dims arrays that correspond to the subsets found. // // An error is returned if two subsets of dims with total number of elements // equal to each other is not found. static LogicalResult mapStaticallyKnownDims(ArrayRef xDims, ArrayRef yDims, SmallVector &xIndices, SmallVector &yIndices) { if (xDims.empty() || yDims.empty()) return failure(); int64_t xTotalSize = xDims[0]; int64_t yTotalSize = yDims[0]; SmallVector xIndicesResult({0}); SmallVector yIndicesResult({0}); size_t nextXIndex = 1; size_t nextYIndex = 1; while (xTotalSize != yTotalSize) { if (xTotalSize < yTotalSize) { if (nextXIndex == xDims.size() || xDims[nextXIndex] == kUnknownSize) return failure(); xTotalSize *= xDims[nextXIndex]; xIndicesResult.push_back(nextXIndex++); } else { if (nextYIndex == yDims.size() || yDims[nextYIndex] == kUnknownSize) return failure(); yTotalSize *= yDims[nextYIndex]; yIndicesResult.push_back(nextYIndex++); } } xIndices.assign(std::move(xIndicesResult)); yIndices.assign(std::move(yIndicesResult)); return success(); } // Calculates the size of a dynamic dimension if all other dimensions are // statically known, and rewrites that dynamic dimension with the static size. // // Note: this function assumes that all the dimensions in `inputShape` map to // all the dimensions in `outputShape`. static void calculateSingleDynamicSize(MutableArrayRef inputShape, MutableArrayRef outputShape) { if (inputShape.empty() || outputShape.empty()) return; int64_t inputDynamicDimCount = llvm::count(inputShape, kUnknownSize); int64_t outputDynamicDimCount = llvm::count(outputShape, kUnknownSize); if (inputDynamicDimCount + outputDynamicDimCount != 1) return; int64_t inputProduct = productReduce(inputShape); int64_t outputProduct = productReduce(outputShape); if (inputDynamicDimCount == 1) { inputProduct /= kUnknownSize; *llvm::find(inputShape, kUnknownSize) = outputProduct / inputProduct; } else { outputProduct /= kUnknownSize; *llvm::find(outputShape, kUnknownSize) = inputProduct / outputProduct; } } // Gets the shapes of the input and output tensors, making a best-effort // attempt to extract static shape information given the inputs to // `aten.view`. static std::pair, SmallVector> getInputAndOutputShape(Value inputTorchTensor, SmallVector outputSizeTorchInt) { SmallVector inputShape( inputTorchTensor.getType().cast().getSizes()); SmallVector outputShape(outputSizeTorchInt.size(), kUnknownSize); for (auto [outputDim, outputDimSize] : llvm::enumerate(outputSizeTorchInt)) { int64_t inputDim; int64_t outputDimSizeInt; // Match torch.aten.size.int(inputTensor, inputDim) with constant inputDim if (matchPattern(outputDimSize, m_TorchTensorSizeInt(inputTorchTensor, &inputDim))) { outputShape[outputDim] = inputShape[inputDim]; } else if (matchPattern(outputDimSize, m_TorchConstantInt(&outputDimSizeInt))) { if (outputDimSizeInt != -1) { outputShape[outputDim] = outputDimSizeInt; } } } calculateSingleDynamicSize(inputShape, outputShape); return std::make_pair(inputShape, outputShape); } LogicalResult matchAndRewrite(AtenViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op.getLoc(); Value input = adaptor.getSelf(); auto inputType = input.getType().cast(); int64_t inputRank = inputType.getRank(); const TypeConverter *typeConverter = getTypeConverter(); auto resultType = typeConverter->convertType(op.getType()).cast(); int64_t resultRank = resultType.getRank(); if (resultRank == 0) { rewriter .replaceOpWithNewOp( op, resultType, input, ArrayRef()) .getResult(); return success(); } if (inputRank == 0) { llvm::SmallVector outshape(resultRank, 1); auto expandTy = RankedTensorType::get(outshape, resultType.getElementType()); Value expand = rewriter.create( op.getLoc(), expandTy, input, ArrayRef()); rewriter.replaceOpWithNewOp(op, resultType, expand); return success(); } // Extract the desired output size as a list of integers. This list should // have been created using the operation `torch.prim.ListConstruct`. SmallVector outputSizeTorchInt; if (!getListConstructElements(op.getSize(), outputSizeTorchInt)) { return rewriter.notifyMatchFailure(op, "unimplemented: the target size is " "not constructed from ListConstruct"); } if (llvm::count_if(outputSizeTorchInt, [](Value size) -> bool { int64_t sizeInt; if (matchPattern(size, m_TorchConstantInt(&sizeInt))) return sizeInt == -1; return false; }) > 1) { return rewriter.notifyMatchFailure( op, "at most one element in size list is allowed to be -1"); } auto [inputShape, outputShape] = getInputAndOutputShape(op.getSelf(), outputSizeTorchInt); // Currently, we only handle the cases where each dimension is either // being expanded or collapsed. We do not handle cases where it's neither // collapsing nor expanding like view of [2,3] for 3x2 tensor. // TODO: For neither collapsing nor expanding, we could find a intermediate // shape to collapse and then expanded to the target shape. Like [2,3] => // [6] => [3, 2]. // Iterate through the view op size list to do the following: // Mark dims in unchangedDims for size list items where the output dim // size comes from a `torch.aten.size.int(inputTensor, inputDim)`. We // naively assume this means the corresponding dimension is not expanded or // collapsed. Note this may technically not always be true. // TODO: think of a way better way to at least detect when this assumption // is violated for the cases of dynamic dimensions. bool inputHasOneDynDim = llvm::count(inputShape, kUnknownSize) == 1; bool outputHasOneDynDim = llvm::count(outputShape, kUnknownSize) == 1; bool singleDynDimsAreEqual = inputHasOneDynDim && outputHasOneDynDim && productReduce(inputShape) == productReduce(outputShape); SmallVector> unchangedDims; for (auto [outputDim, outputDimSize] : llvm::enumerate(outputSizeTorchInt)) { int64_t inputDim; // Match torch.aten.size.int(inputTensor, inputDim) with constant inputDim if (matchPattern(outputDimSize, m_TorchTensorSizeInt(op.getSelf(), &inputDim))) { unchangedDims.push_back(std::make_pair(inputDim, outputDim)); } else if (singleDynDimsAreEqual && outputShape[outputDim] == kUnknownSize) { // If the input and output have a single dynamic dimension and the // product of the other dimensions is the same, then we know that the // dynamic dimension is unchanged. inputDim = std::distance(inputShape.begin(), llvm::find(inputShape, kUnknownSize)); unchangedDims.push_back(std::make_pair(inputDim, outputDim)); } } // Mark the end of the input/output shapes unchangedDims.push_back(std::make_pair(inputRank, resultRank)); // Association indices for expand/collapse ops. These two vectors // are populated such that two entries at the same index corresponds // to an expand or collapse. For example, // // inputAssociations: [[0, 1], [2]] // outputAssociations: [[0], [1, 2, 3]] // // indicates that the first two dims of the input tensor // are collapsed into the first dim of the output, and the // third dim of the input is expanded into the last three dims // of the output. SmallVector inputAssociations; SmallVector outputAssociations; // The for loop does the following: // 1. Attempt to match the indices from inputDim and outputDim to the next // boundary found from `torch.aten.size.int(inputTensor, inputDim)`, or // until (inputRank, resultRank) if there is no such op. Look at the first // dimension of the input and output and collapse the larger one by finding // a minimal set of opposing indices with the same number of elements. If // the number of dims to the next boundary is 1, then we assume all // remaining opposing dims must collapse into it. // 2. For handling of dynamic dimensions, we first assume they are only // split if we can easily compute the correct size. // e.g. [2, -1] -> [2, 3, 4] // This mainly happens at the edges of boundaries. Otherwise we try to match // the dynamic dimension with the one across from it and give up if we can't // reason about how the dimensions are associated. // e.g. [-1, -1] -> [2, 3, 4] // For more information, see description of helper functions used in the // `if-else` cases inside the while loop. int64_t inputDim = 0, outputDim = 0; SmallVector> checkDimPairs; for (auto [nextUnchangedInput, nextUnchangedOutput] : unchangedDims) { // Used for ensuring that we don't have an ambiguous expansion bool assumedDynamicDimNotSplit = false; while (inputDim < nextUnchangedInput && outputDim < nextUnchangedOutput) { auto inputShapeSlice = MutableArrayRef(inputShape) .slice(inputDim, nextUnchangedInput - inputDim); auto outputShapeSlice = MutableArrayRef(outputShape) .slice(outputDim, nextUnchangedOutput - outputDim); SmallVector inputSliceIndices; SmallVector outputSliceIndices; // TODO: this can be removed by replacing it with a checkDimEqualHelper // that takes into account the product of all the dimensions being // reduced if (assumedDynamicDimNotSplit && inputShapeSlice.size() == 1 && outputShapeSlice.size() != 1 && inputShapeSlice[0] == kUnknownSize) { return rewriter.notifyMatchFailure( op, "found ambiguous expand of dynamic input sizes " "(e.g. [-1, -1] -> [-1, -1, -1])"); } if (succeeded(mapAllDimsToSingleDim(inputShapeSlice, outputShapeSlice, inputSliceIndices, outputSliceIndices))) { calculateSingleDynamicSize(inputShapeSlice, outputShapeSlice); // Update shape to pass the tensor.expand_shape and // tensor.collapse_shape verifiers. If one of the dimensions of the // tensor being flattened is dynamic, the size of the flattened tensor // must also be dynamic. if (inputShapeSlice.size() == 1 && llvm::count(outputShapeSlice, kUnknownSize) > 0) { inputShapeSlice[0] = kUnknownSize; } else if (outputShapeSlice.size() == 1 && llvm::count(inputShapeSlice, kUnknownSize) > 0) { outputShapeSlice[0] = kUnknownSize; } } else if (succeeded(mapStaticallyKnownDims( inputShapeSlice, outputShapeSlice, inputSliceIndices, outputSliceIndices))) { /// `mapStaticallyKnownDims` maps the smallest number of /// input and output dimensions in the slice statically /// known to have the same number of elements. } else if (inputShapeSlice[0] == kUnknownSize) { // Defer the dynamic shape check to avoid DialectConversion assertion: checkDimPairs.push_back( std::pair(inputDim, outputDim)); inputShape[inputDim] = outputShape[outputDim]; inputSliceIndices.push_back(0); outputSliceIndices.push_back(0); assumedDynamicDimNotSplit = true; } else { return rewriter.notifyMatchFailure( op, "unimplemented: found unhandled case of expansion/collapse " "in `aten.view`"); } inputAssociations.emplace_back(); outputAssociations.emplace_back(); for (int64_t inputSliceIndex : inputSliceIndices) inputAssociations.back().push_back(inputSliceIndex + inputDim); for (int64_t outputSliceIndex : outputSliceIndices) outputAssociations.back().push_back(outputSliceIndex + outputDim); inputDim = inputAssociations.back().back() + 1; outputDim = outputAssociations.back().back() + 1; } // Handle any leading or trailing size-1 dimensions and append the // associations for the dims matching `aten.size.int`. if (nextUnchangedInput != inputRank) { assert(nextUnchangedOutput != resultRank && "`nextUnchangedInput` and `nextUnchangedOutput` should equal " "the respective input and output rank at the same time"); inputAssociations.emplace_back(); outputAssociations.emplace_back(); } while (inputDim <= nextUnchangedInput && inputDim < inputRank) { if (inputDim != nextUnchangedInput && inputShape[inputDim] != 1) { return rewriter.notifyMatchFailure( op, "unimplemented: only collapsing of static size-1 into " "unchanged dim supported"); } inputAssociations.back().push_back(inputDim++); } while (outputDim <= nextUnchangedOutput && outputDim < resultRank) { if (outputDim != nextUnchangedOutput && outputShape[outputDim] != 1) { return rewriter.notifyMatchFailure( op, "unimplemented: only expanding of static size-1 out of " "unchanged dim supported"); } outputAssociations.back().push_back(outputDim++); } } SmallVector inputSize = getTensorSizes(rewriter, loc, input); SmallVector outputSizeInt = getTypeConvertedValues( rewriter, loc, typeConverter, outputSizeTorchInt); if (resultRank != (int64_t)outputSizeInt.size()) { return rewriter.notifyMatchFailure( op, "desired size list length mismatches with the result type rank"); } for (auto [inputDim, outputDim] : checkDimPairs) { checkDimEqualHelper(rewriter, loc, inputSize[inputDim], outputSizeInt[outputDim]); } auto cast = [&](Location loc, Type t, Value v) -> Value { return rewriter.createOrFold(loc, t, v); }; // Check if the shapes already match up to dynamic sizes. If so, we can just // cast as the result type because the previous loop sets up the necessary // dim checks in case of dynamic sizes. if (llvm::all_of( inputAssociations, [](ReassociationIndices indices) { return indices.size() == 1; }) && llvm::all_of(outputAssociations, [](ReassociationIndices indices) { return indices.size() == 1; })) { auto castResult = cast(loc, resultType, input); rewriter.replaceOp(op, castResult); return success(); } // TODO: audit possibility of sparsity on these tensors Type adjustedResultType = RankedTensorType::get( makeShapeLLVMCompatible(outputShape), resultType.getElementType()); Type adjustedInputType = RankedTensorType::get( makeShapeLLVMCompatible(inputShape), resultType.getElementType()); Value castedInput = cast(loc, adjustedInputType, input); std::optional expandedInput; std::optional collapsedInput; if (llvm::any_of(inputAssociations, [](ReassociationIndices indices) { return indices.size() > 1; })) { SmallVector intermediateShape; for (auto i : llvm::seq(0, (int)outputAssociations.size())) { int sum = 1; for (auto j : llvm::seq(0, (int)outputAssociations[i].size())) { if (outputShape[outputAssociations[i][j]] < 0) { sum = kUnknownSize; break; } sum *= outputShape[outputAssociations[i][j]]; } intermediateShape.push_back(sum); } // TODO: audit possibility of sparsity on these tensor Type intermediateResultType = RankedTensorType::get(makeShapeLLVMCompatible(intermediateShape), resultType.getElementType()); expandedInput = rewriter .create(loc, intermediateResultType, castedInput, inputAssociations) .getResult(); } if (llvm::any_of(outputAssociations, [](ReassociationIndices indices) { return indices.size() > 1; })) { collapsedInput = rewriter .create( loc, adjustedResultType, expandedInput.has_value() ? expandedInput.value() : castedInput, outputAssociations) .getResult(); } Value result = collapsedInput.has_value() ? collapsedInput.value() : expandedInput.value(); auto castResult = cast(loc, resultType, result); rewriter.replaceOp(op, castResult); return success(); } }; } // namespace namespace { class ConvertAtenSqueezeOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenSqueezeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Value input = adaptor.getSelf(); auto inputType = input.getType().cast(); auto inputShape = inputType.getShape(); int64_t inputRank = inputType.getRank(); const TypeConverter *typeConverter = getTypeConverter(); auto resultType = typeConverter->convertType(op.getType()).cast(); auto resultShape = resultType.getShape(); int64_t resultRank = resultType.getRank(); if (inputRank == 0) { return rewriter.notifyMatchFailure( op, "zero input rank should have been handled by the folder"); } // No change in rank so we just cast to the output type: if (inputRank == resultRank) { rewriter.replaceOpWithNewOp(op, resultType, input); return success(); } // In case the operand tensor type is statically shaped with all dimensions // being unit extent, it will be collapsed to a 0-D tensor. if (resultRank == 0) { SmallVector reassociation; rewriter.replaceOpWithNewOp( op, resultType, input, reassociation); return success(); } SmallVector reassociation(resultRank); // First dimensions are guaranteed to match to eachother: int64_t i = 0; int64_t j = 0; for (i = 0; i < inputRank && j < resultRank; i++) { reassociation[j].push_back(i); j = inputShape[i] == resultShape[j] ? j + 1 : j; } // Squeeze in the remaining 1s: for (; i < inputRank; ++i) { if (inputShape[i] != 1) return rewriter.notifyMatchFailure(op, "non-unary dim cannot be squeezed"); reassociation.back().push_back(i); } // Make sure that result type rank is compatible with the squeezed size: if (j != resultRank) return rewriter.notifyMatchFailure( op, "expected output size mismatches with the result type rank"); rewriter.replaceOpWithNewOp(op, resultType, input, reassociation); return success(); } }; } // namespace namespace { class ConvertAtenSqueezeDimOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenSqueezeDimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Value input = adaptor.getSelf(); auto inputType = input.getType().cast(); int64_t inputRank = inputType.getRank(); if (inputRank == 0) { return rewriter.notifyMatchFailure( op, "zero input rank should have been handled by the folder"); } int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure(op, "dim must be constant"); dim = toPositiveDim(dim, inputRank); if (!isValidDim(dim, inputRank)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); // TODO: Handle the case where the dim(th) dimension is dynamic. if (inputType.isDynamicDim(dim)) { return rewriter.notifyMatchFailure( op, "unimplemented: dim(th) dimension is not expected to be dynamic"); } const TypeConverter *typeConverter = getTypeConverter(); auto resultType = typeConverter->convertType(op.getType()).cast(); int64_t resultRank = resultType.getRank(); // If the dim(th) dimension of operand tensor type is not statically unit, // `aten.squeeze` will behave as an identity operation. if (inputType.getDimSize(dim) != 1) { rewriter.replaceOpWithNewOp(op, resultType, input); return success(); } SmallVector reassociationMap(resultRank); bool alreadyCrossedSqueezedDim = false; for (int i = 0; i != resultRank; i++) { if (alreadyCrossedSqueezedDim) { reassociationMap[i].push_back(i + 1); } else { reassociationMap[i].push_back(i); if (dim != 0 && i != dim - 1) continue; alreadyCrossedSqueezedDim = true; if (dim == 0) reassociationMap[0].push_back(1); if (i == dim - 1) reassociationMap[i].push_back(dim); } } // Note: In case the operand tensor type is of unit rank and is statically // shaped with unit dimension, the `reassociationMap` will be empty and the // input will be collapsed to a 0-D tensor. rewriter.replaceOpWithNewOp(op, resultType, input, reassociationMap); return success(); } }; } // namespace namespace { class ConvertAtenUnsqueezeOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenUnsqueezeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure(op, "dim must be constant"); auto inputRank = adaptor.getSelf().getType().cast().getRank(); dim = toPositiveDim(dim, inputRank + 1); if (!isValidDim(dim, inputRank + 1)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); SmallVector reassociationMap(inputRank); // From the perspective of the reassociation map, the situation of // unsqueezing before or after the last dimension is symmetrical. // Normalize it to the "before" case. // The 0 case is special here, since there is no last dimension to insert // before -- we simply rely on the loop below iterating 0 times. if (dim == inputRank && inputRank != 0) dim = inputRank - 1; bool alreadyCrossedExpandedDim = false; for (int i = 0; i != inputRank; i++) { if (alreadyCrossedExpandedDim) { reassociationMap[i].push_back(i + 1); } else { reassociationMap[i].push_back(i); if (i == dim) { reassociationMap[i].push_back(i + 1); alreadyCrossedExpandedDim = true; } } } auto resultType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); rewriter.replaceOpWithNewOp( op, resultType, adaptor.getSelf(), reassociationMap); return success(); } }; } // namespace namespace { class ConvertAtenTransposeIntOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenTransposeIntOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); int64_t dim0; if (!matchPattern(op.getDim0(), m_TorchConstantInt(&dim0))) return rewriter.notifyMatchFailure(op, "dim0 must be constant"); int64_t dim1; if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1))) return rewriter.notifyMatchFailure(op, "dim1 must be constant"); auto inVector = adaptor.getSelf(); auto inType = inVector.getType().cast(); auto inputRank = inType.getRank(); auto outType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); auto elementType = inType.getElementType(); dim0 = toPositiveDim(dim0, inputRank); if (!isValidDim(dim0, inputRank)) return rewriter.notifyMatchFailure(op, "dim0 out of range"); dim1 = toPositiveDim(dim1, inputRank); if (!isValidDim(dim1, inputRank)) return rewriter.notifyMatchFailure(op, "dim1 out of range"); auto loc = op.getLoc(); SmallVector outputDims; for (auto i = 0; i < inputRank; i++) outputDims.push_back(getDimOp(rewriter, loc, adaptor.getSelf(), i)); std::swap(outputDims[dim0], outputDims[dim1]); Value outVector = rewriter.create( loc, getAsOpFoldResult(outputDims), elementType); SmallVector idExprs; SmallVector swapExprs; for (auto i = 0; i < inputRank; i++) idExprs.push_back(getAffineDimExpr(i, rewriter.getContext())); for (auto i = 0; i < inputRank; i++) { if (i == dim0) swapExprs.push_back(idExprs[dim1]); else if (i == dim1) swapExprs.push_back(idExprs[dim0]); else swapExprs.push_back(idExprs[i]); } SmallVector indexingMaps = { AffineMap::get(inputRank, 0, idExprs, op.getContext()), AffineMap::get(inputRank, 0, swapExprs, op.getContext())}; SmallVector iteratorTypes( inputRank, utils::IteratorType::parallel); auto transpose = rewriter .create( loc, outVector.getType(), inVector, outVector, indexingMaps, iteratorTypes, [](OpBuilder &b, Location loc, ValueRange args) { b.create(loc, args[0]); }) .getResult(0); rewriter.replaceOpWithNewOp(op, outType, transpose); return success(); } }; } // namespace namespace { class ConvertAtenPermuteOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenPermuteOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); SmallVector dimensions; if (!matchPattern(op.getDims(), m_TorchListOfConstantInts(dimensions))) return rewriter.notifyMatchFailure(op, "all dimensions must be constant"); Value inVector = adaptor.getSelf(); auto inType = inVector.getType().cast(); int64_t inputRank = inType.getRank(); auto outType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); Type elementType = inType.getElementType(); // Check if the dimensions are a valid constants. int64_t numDimensions = dimensions.size(); if (inputRank != numDimensions) return rewriter.notifyMatchFailure( op, "size of `dims` must be equal to the rank of the input"); for (unsigned i = 0; i < numDimensions; i++) { if (dimensions[i] < 0) dimensions[i] = toPositiveDim(dimensions[i], inputRank); if (!isValidDim(dimensions[i], inputRank)) return rewriter.notifyMatchFailure(op, "dimension out of range"); } Location loc = op.getLoc(); SmallVector outputDims; for (unsigned i = 0; i < inputRank; i++) outputDims.push_back(getDimOp(rewriter, loc, inVector, dimensions[i])); Value outVector = rewriter.create( loc, getAsOpFoldResult(outputDims), elementType); SmallVector idExprs; SmallVector swapExprs; for (unsigned i = 0; i < inputRank; i++) idExprs.push_back(getAffineDimExpr(i, rewriter.getContext())); for (unsigned i = 0; i < inputRank; i++) swapExprs.push_back(idExprs[dimensions[i]]); AffineMap inputMap = AffineMap::get(inputRank, /*symbolCount=*/0, idExprs, op->getContext()); AffineMap outputMap = AffineMap::get(inputRank, /*symbolCount=*/0, swapExprs, op->getContext()); SmallVector indexingMaps{inputMap, outputMap}; SmallVector iteratorTypes( inputRank, utils::IteratorType::parallel); auto transpose = rewriter .create( loc, outVector.getType(), inVector, outVector, indexingMaps, iteratorTypes, [](OpBuilder &b, Location loc, ValueRange args) { b.create(loc, args[0]); }) .getResult(0); rewriter.replaceOpWithNewOp(op, outType, transpose); return success(); } }; } // namespace namespace { class ConvertAtenSliceTensorOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenSliceTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op.getLoc(); const TypeConverter *typeConverter = getTypeConverter(); auto input = adaptor.getSelf(); RankedTensorType resultType = typeConverter->convertType(op->getResult(0).getType()) .cast(); SmallVector resultShape; SmallVector offsets; SmallVector strides; if (failed(prepareArgumentsForSlicingOp( op, adaptor, rewriter, resultShape, offsets, strides))) { return failure(); } Value result = rewriter.create( loc, input, offsets, resultShape, strides); rewriter.replaceOpWithNewOp(op, resultType, result); return success(); } }; } // namespace namespace { class ConvertAtenCatOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenCatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op.getLoc(); const TypeConverter *typeConverter = getTypeConverter(); // Collect all the tensors to be concatenated. auto tensorList = op.getTensors(); SmallVector tensorsTorchType; if (!getListConstructElements(tensorList, tensorsTorchType)) return op.emitError( "unimplemented: the tensor list is not from list construct"); auto tensors = getTypeConvertedValues(rewriter, loc, typeConverter, tensorsTorchType); RankedTensorType newResultType = typeConverter->convertType(op.getType()).cast(); int rank = newResultType.getRank(); Value dimValue = op.getDim(); int64_t dim; if (!matchPattern(dimValue, m_TorchConstantInt(&dim))) return op.emitError("unimplemented: dim is not constant"); dim = toPositiveDim(dim, rank); if (!isValidDim(dim, rank)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); auto outElemType = newResultType.getElementType(); for (size_t i = 0; i < tensors.size(); ++i) { auto inputType = cast(tensors[i].getType()); if (inputType.getElementType() != outElemType) { tensors[i] = torch_to_linalg::convertTensorToElementType( rewriter, loc, tensors[i], outElemType); } } llvm::SmallVector filteredTensors; for (auto tensor : tensors) { auto inputType = cast(tensor.getType()); if (inputType.getDimSize(dim) != 0) { filteredTensors.push_back(tensor); } } rewriter.replaceOpWithNewOp(op, newResultType, dim, filteredTensors); return success(); } }; } // namespace namespace { class ConvertAtenBroadcastToOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenBroadcastToOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Value self = adaptor.getSelf(); SmallVector inShape; if (!getListConstructElements(adaptor.getSize(), inShape)) { return rewriter.notifyMatchFailure( op, "unimplemented: the size list is not from list construct"); } // For dynamic input dimension we need to use the `broadcastToShape` // which in this case is `inShapeConverted` because this shape will yield // us the dimension size of the output. SmallVector useBroadcastToShape; int64_t inputRank = self.getType().cast().getRank(); for (size_t i = inShape.size() - inputRank, e = inShape.size(); i < e; ++i) { int64_t dim; if (matchPattern(inShape[i], m_TorchConstantInt(&dim))) { if (dim < 0) { useBroadcastToShape.push_back(false); } else { useBroadcastToShape.push_back(true); } } else { // Note: Dynamic -1 (inferred) broadcast shapes are unimplemented. useBroadcastToShape.push_back(true); } } SmallVector inShapeConverted = getTypeConvertedValues( rewriter, op.getLoc(), getTypeConverter(), inShape); auto newResultType = getTypeConverter()->convertType(op.getType()).cast(); Value result; if (failed(torch_to_linalg::broadcastToGivenShape( op, rewriter, self, inShapeConverted, newResultType, result, useBroadcastToShape))) { return rewriter.notifyMatchFailure( op, "unable to perform broadcast operation"); } rewriter.replaceOp(op, result); return success(); } }; } // namespace namespace { class ConvertAtenContiguousOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenContiguousOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Type resultType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, resultType, adaptor.getSelf()); return success(); } }; } // namespace namespace { class ConvertAtenCopyOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenCopyOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op.getLoc(); Value self = adaptor.getSelf(); Value src = adaptor.getSrc(); RankedTensorType selfType = self.getType().cast(); // The non_blocking should be a constant `False`. bool nonBlocking; if (!matchPattern(op.getNonBlocking(), m_TorchConstantBool(&nonBlocking))) { return rewriter.notifyMatchFailure( op, "unimplemented: non_blocking must be a constant"); } else if (nonBlocking) { return rewriter.notifyMatchFailure( op, "unimplemented: non_blocking is expected to be false"); } // The size of the src tensor can be different from the self but should be // broadcastable. Therefore, broadcasting the src tensor to match the size // of the self tensor. SmallVector selfSizes = getTensorSizes(rewriter, loc, self); for (unsigned i = 0; i < selfSizes.size(); i++) selfSizes[i] = castIndexToInt64(rewriter, loc, selfSizes[i]); Value broadcastedSrc; if (failed(torch_to_linalg::broadcastToGivenShape( op, rewriter, src, selfSizes, selfType, broadcastedSrc))) { return rewriter.notifyMatchFailure( op, "unable to perform broadcast operation"); } AffineMap id = AffineMap::getMultiDimIdentityMap(selfType.getRank(), rewriter.getContext()); SmallVector iteratorTypes( selfType.getRank(), utils::IteratorType::parallel); Value result = rewriter .create( loc, /*resultType=*/selfType, /*inputs=*/broadcastedSrc, /*outputs=*/self, /*indexingMaps=*/llvm::ArrayRef({id, id}), /*iteratorTypes=*/iteratorTypes, [](OpBuilder &b, Location loc, ValueRange args) { Value result = args[0]; if (args[0].getType() != args[1].getType()) { result = convertScalarToDtype(b, loc, args[0], args[1].getType()); } b.create(loc, result); }) ->getResult(0); Type resultType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, resultType, result); return success(); } }; } // namespace namespace { class ConvertAtenSliceScatterOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenSliceScatterOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op.getLoc(); const TypeConverter *typeConverter = getTypeConverter(); auto input = adaptor.getSelf(); RankedTensorType resultType = typeConverter->convertType(op->getResult(0).getType()) .cast(); SmallVector resultShape; SmallVector offsets; SmallVector strides; if (failed(prepareArgumentsForSlicingOp( op, adaptor, rewriter, resultShape, offsets, strides))) { return failure(); } Value src = adaptor.getSrc(); auto srcType = src.getType().cast(); int64_t srcRank = srcType.getRank(); SmallVector srcAbstractSizes(srcRank, kUnknownSize); // TODO: audit possibility of sparsity on these tensor auto abstractSrcType = RankedTensorType::get( makeShapeLLVMCompatible(srcAbstractSizes), srcType.getElementType()); Value abstractSrc = rewriter.create(loc, abstractSrcType, src); Value result = rewriter.create( loc, abstractSrc, input, offsets, resultShape, strides); rewriter.replaceOpWithNewOp(op, resultType, result); return success(); } }; } // namespace namespace { class ConvertAtenViewAsComplexOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenViewAsComplexOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op.getLoc(); const TypeConverter *typeConverter = getTypeConverter(); MLIRContext *context = rewriter.getContext(); auto input = adaptor.getSelf(); RankedTensorType resultType = typeConverter->convertType(op.getType()).cast(); auto elementType = resultType.getElementType(); SmallVector resultShape; for (int64_t i = 0; i < resultType.getRank(); i++) { auto currentDimSize = rewriter.create(loc, input, i); resultShape.push_back(currentDimSize); } Value outTensor = rewriter.create( loc, getAsOpFoldResult(resultShape), elementType); SmallVector outputExpr; for (unsigned i = 0; i < resultType.getRank(); i++) { outputExpr.push_back(getAffineDimExpr(i, context)); } Value constantZero = getConstant(rewriter, loc, 0, mlir::IndexType::get(context)); Value constantOne = getConstant(rewriter, loc, 1, mlir::IndexType::get(context)); AffineMap outputMap = AffineMap::get(resultType.getRank(), 0, outputExpr, op->getContext()); SmallVector indexingMaps{outputMap}; SmallVector iteratorTypes( resultType.getRank(), utils::IteratorType::parallel); auto complexVar = rewriter .create( loc, outTensor.getType(), ValueRange{}, outTensor, indexingMaps, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { SmallVector indicesZero; SmallVector indicesOne; for (int i = 0; i < resultType.getRank(); i++) { indicesZero.push_back(b.create(loc, i)); indicesOne.push_back(b.create(loc, i)); } indicesZero.push_back(constantZero); indicesOne.push_back(constantOne); Value realVal = b.create(loc, input, indicesZero); Value imagVal = b.create(loc, input, indicesOne); Value complexVal = b.create( loc, elementType, realVal, imagVal); b.create(loc, complexVal); }) .getResult(0); rewriter.replaceOpWithNewOp(op, resultType, complexVar); return success(); } }; } // namespace namespace { class ConvertAtenViewAsRealOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenViewAsRealOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op.getLoc(); const TypeConverter *typeConverter = getTypeConverter(); MLIRContext *context = rewriter.getContext(); auto input = adaptor.getSelf(); RankedTensorType resultType = typeConverter->convertType(op.getType()).cast(); RankedTensorType inputType = input.getType().cast(); auto inputElementType = getElementTypeOrSelf(input.getType()); if (!inputElementType.isa()) { return op.emitError("only ComplexType is allowed as input type"); } Type elementType = resultType.getElementType(); // returned real tensor has a size increase, where the last dim has size 2 SmallVector resultShape = tensor::getMixedSizes(rewriter, loc, input); resultShape.push_back( rewriter.createOrFold(loc, 2)); Value outTensor = rewriter.create(loc, resultShape, elementType); SmallVector inputExpr; for (unsigned i = 0; i < resultType.getRank() - 1; i++) { inputExpr.push_back(getAffineDimExpr(i, context)); } AffineMap inputMap = AffineMap::get(resultType.getRank(), 0, inputExpr, op->getContext()); inputExpr.push_back(getAffineDimExpr(resultType.getRank() - 1, context)); AffineMap outputMap = AffineMap::get(resultType.getRank(), 0, inputExpr, op->getContext()); SmallVector indexingMaps{inputMap, outputMap}; SmallVector iteratorTypes( resultType.getRank(), utils::IteratorType::parallel); Value constantZero = getConstant(rewriter, loc, 0, mlir::IndexType::get(context)); auto realVar = rewriter .create( loc, outTensor.getType(), input, outTensor, indexingMaps, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { Value realVal = b.create(loc, elementType, args[0]); Value imagVal = b.create(loc, elementType, args[0]); Value lastIndex = b.create(loc, inputType.getRank()); Value cmpResult = b.create( loc, arith::CmpIPredicate::eq, lastIndex, constantZero); Value yieldValue = b.create( loc, cmpResult, realVal, imagVal); b.create(loc, yieldValue); }) .getResult(0); rewriter.replaceOpWithNewOp(op, resultType, realVar); return success(); } }; } // namespace namespace { class ConvertAtenDiagonalOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenDiagonalOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); int64_t offset; if (!matchPattern(op.getOffset(), m_TorchConstantInt(&offset))) return rewriter.notifyMatchFailure(op, "offset must be constant"); int64_t dim1; if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1))) return rewriter.notifyMatchFailure(op, "dim1 must be constant"); int64_t dim2; if (!matchPattern(op.getDim2(), m_TorchConstantInt(&dim2))) return rewriter.notifyMatchFailure(op, "dim2 must be constant"); Value inputMatrix = adaptor.getSelf(); RankedTensorType inputType = inputMatrix.getType().cast(); int64_t inputRank = inputType.getRank(); if (inputRank < 2) return rewriter.notifyMatchFailure( op, "input must have at least two dimensions"); int64_t outputRank = inputRank - 1; dim1 = toPositiveDim(dim1, inputRank); if (!isValidDim(dim1, inputRank)) return rewriter.notifyMatchFailure(op, "dim1 out of range"); dim2 = toPositiveDim(dim2, inputRank); if (!isValidDim(dim2, inputRank)) return rewriter.notifyMatchFailure(op, "dim2 out of range"); if (dim1 == dim2) return rewriter.notifyMatchFailure( op, "diagonal dimensions cannot be identical"); Type elementType = inputType.getElementType(); RankedTensorType outputType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); Location loc = op.getLoc(); Value dim1Size, dim2Size; dim1Size = getDimOp(rewriter, loc, inputMatrix, dim1); dim2Size = getDimOp(rewriter, loc, inputMatrix, dim2); // compute the length of the diagonal with possible offset // if the offset is very large or very small, diagSize=0 and an empty tensor // is returned Value indexZero = rewriter.create(loc, 0); Value indexMinusOne = rewriter.create(loc, -1); Value indexOffset = rewriter.create(loc, offset); Value offsetIsNegative = rewriter.create( loc, arith::CmpIPredicate::sle, indexOffset, indexZero); Value sizeForNegativeOffset = rewriter.create( loc, rewriter.create( loc, rewriter.create(loc, dim1Size, indexOffset), dim2Size), indexZero); Value sizeForPositiveOffset = rewriter.create( loc, rewriter.create( loc, rewriter.create(loc, dim2Size, indexOffset), dim1Size), indexZero); Value diagSize = rewriter.create( loc, offsetIsNegative, sizeForNegativeOffset, sizeForPositiveOffset); // depending on its sign, the offset affects only the row or column indices // of the diagonal Value diagStart1 = rewriter.create( loc, offsetIsNegative, rewriter.create(loc, indexOffset, indexMinusOne), indexZero); Value diagStart2 = rewriter.create(loc, offsetIsNegative, indexZero, indexOffset); SmallVector outputDims; for (auto i = 0; i < inputRank; i++) { if (!(i == dim1 || i == dim2)) outputDims.push_back(getDimOp(rewriter, loc, inputMatrix, i)); } outputDims.push_back(diagSize); Value outputMatrix = rewriter.create( loc, getAsOpFoldResult(outputDims), elementType); SmallVector indexingMaps = { AffineMap::getMultiDimIdentityMap(outputRank, rewriter.getContext())}; SmallVector iteratorTypes( outputRank, utils::IteratorType::parallel); auto diagonal = rewriter .create( loc, outputMatrix.getType(), ValueRange{}, outputMatrix, indexingMaps, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { SmallVector diagIndices; Value indexOnDiag = b.create(loc, outputRank - 1); Value dim1Index = b.create(loc, indexOnDiag, diagStart1); Value dim2Index = b.create(loc, indexOnDiag, diagStart2); // specify at which input indices the diagonal values are // extracted for (int indIn = 0, indOut = 0; indIn < inputRank; indIn++) { if (indIn == dim1) diagIndices.push_back(dim1Index); else if (indIn == dim2) diagIndices.push_back(dim2Index); else { diagIndices.push_back( b.create(loc, indOut)); indOut++; } } Value diagElt = b.create( loc, elementType, inputMatrix, diagIndices); b.create(loc, diagElt); }) .getResult(0); rewriter.replaceOpWithNewOp(op, outputType, diagonal); return success(); } }; } // namespace namespace { class ConvertAtenDiagEmbedOp : public OpConversionPattern { static SmallVector getDiagEmbedResultShape(OpBuilder &b, Location loc, Value tensor, int64_t offset, int64_t dim1, int64_t dim2) { auto inputType = tensor.getType().cast(); auto inputRank = inputType.getRank(); // output tensor always has 1 extra dimension auto resultRank = inputRank + 1; // regardless of offset sign, output tensor is same Value constOffset = b.create(loc, offset); Value absOffset = b.create(loc, constOffset); // diagonal size is determined by last input dimension auto lastInputDim = getDimOp(b, loc, tensor, inputRank - 1); Value diagDim = b.create(loc, lastInputDim, absOffset); // output shape has same dimensions as input // except for the diagonal dimensions int input_dim_idx = 0; SmallVector resultShape; for (unsigned int i = 0; i < resultRank; i++) { if (i == dim1 || i == dim2) resultShape.push_back(diagDim); else resultShape.push_back(getDimOp(b, loc, tensor, input_dim_idx++)); } return resultShape; } public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenDiagEmbedOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); Value input = adaptor.getSelf(); auto inputType = input.getType().cast(); auto inputRank = inputType.getRank(); auto resultRank = inputRank + 1; int64_t offset; if (!matchPattern(op.getOffset(), m_TorchConstantInt(&offset))) return rewriter.notifyMatchFailure(op, "offset is not constant"); int64_t dim1; if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1))) return rewriter.notifyMatchFailure(op, "dim1 is not constant"); dim1 = toPositiveDim(dim1, resultRank); if (!isValidDim(dim1, resultRank)) return rewriter.notifyMatchFailure( op, "dim1 can only be in closed range [" + std::to_string(-resultRank) + "," + std::to_string(resultRank - 1) + "]"); int64_t dim2; if (!matchPattern(op.getDim2(), m_TorchConstantInt(&dim2))) return rewriter.notifyMatchFailure(op, "dim2 is not constant"); dim2 = toPositiveDim(dim2, resultRank); if (!isValidDim(dim2, resultRank)) return rewriter.notifyMatchFailure( op, "dim2 can only be in closed range [" + std::to_string(-resultRank) + "," + std::to_string(resultRank - 1) + "]"); if (dim1 == dim2) return rewriter.notifyMatchFailure(op, "dim1 and dim2 can not be equal"); // add linalg.fill Type resultElemType = inputType.getElementType(); auto resultShape = getDiagEmbedResultShape(rewriter, loc, input, offset, dim1, dim2); Value zeroTensor = createZeroInitTensor(rewriter, loc, resultShape, resultElemType); // add linalg.generic with diagonal access pattern affine indexing maps SmallVector indexingMaps = { rewriter.getMultiDimIdentityMap(resultRank), }; SmallVector iteratorTypes( resultRank, utils::IteratorType::parallel); Value resultTensor = rewriter .create( loc, zeroTensor.getType(), ValueRange{}, zeroTensor, /*indexingMaps=*/indexingMaps, /*iteratorTypes=*/iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { Value dim1Index = b.create(loc, dim1); Value dim2Index = b.create(loc, dim2); // to pick right element from input, first add all dimensions // except last one, then last will be either dim1 or dim2 // depending upon lower or upper diagonal defined by offset // sign SmallVector inputIndices; for (unsigned int i = 0; i < resultRank; i++) { if (i != dim1 && i != dim2) { inputIndices.push_back(b.create(loc, i)); } } // adjust output diagonal indices and last input Index based // on offset Value dim1IdxAdjusted; Value dim2IdxAdjusted; if (offset < 0) { Value absOffset = b.create(loc, -offset); dim1IdxAdjusted = dim1Index; dim2IdxAdjusted = b.create(loc, dim2Index, absOffset); inputIndices.push_back( b.create(loc, dim2)); } else { Value constOffset = b.create(loc, offset); dim1IdxAdjusted = b.create(loc, dim1Index, constOffset); dim2IdxAdjusted = dim2Index; inputIndices.push_back( b.create(loc, dim1)); } Value isDiagonal = b.create(loc, arith::CmpIPredicate::eq, dim1IdxAdjusted, dim2IdxAdjusted); Value inputElem = b.create( loc, resultElemType, input, inputIndices); Value result = rewriter.create( loc, isDiagonal, inputElem, args[0]); b.create(loc, result); }) .getResult(0); RankedTensorType resultType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); rewriter.replaceOpWithNewOp(op, resultType, resultTensor); return success(); } }; } // namespace void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { MLIRContext *context = patterns.getContext(); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); }