//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeSupport.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Matchers.h" #include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; static int64_t productReduce(ArrayRef a) { return accumulate(a.begin(), a.end(), /*init=*/1, std::multiplies()); } template LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, SmallVector &resultShape, SmallVector &offsets, SmallVector &strides) { Location loc = op.getLoc(); auto input = adaptor.getSelf(); RankedTensorType inputType = input.getType().template cast(); Value zero = rewriter.create(loc, 0); Value one = rewriter.create(loc, 1); int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return op->emitError("unimplemented: dim is not constant"); int64_t inputRank = inputType.getRank(); dim = toPositiveDim(dim, inputRank); if (!isValidDim(dim, inputRank)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); SmallVector inputShape = getTensorSizes(rewriter, loc, input); Value dimSize = inputShape[dim]; Value torchTypeStart = op.getStart(); Value torchTypeEnd = op.getEnd(); Value builtinTypeStart = adaptor.getStart(); Value builtinTypeEnd = adaptor.getEnd(); if (torchTypeStart.getType().isa() || torchTypeEnd.getType().isa()) return rewriter.notifyMatchFailure(op, "unimplemented optional type arg"); int64_t step; if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) { if (!op.getStep().getType().template isa()) return op->emitError("unimplemented: step is not constant"); step = 1; } Value start = toPositiveValidDim(rewriter, loc, torchTypeStart, builtinTypeStart, zero, dimSize); Value end = toPositiveValidDim(rewriter, loc, torchTypeEnd, builtinTypeEnd, dimSize, dimSize); // end >= start ? end : start Value endSgeStart = rewriter.create( loc, arith::CmpIPredicate::sge, end, start); end = rewriter.create(loc, endSgeStart, end, start); Value stepIndex = rewriter.create(loc, step); // Slice logic: resultSize = floordiv(end - start + step - 1, step) resultShape = getTensorSizes(rewriter, loc, input); Value len = rewriter.create(loc, end, start); Value resultSize = rewriter.create(loc, len, stepIndex); resultSize = rewriter.create(loc, resultSize, one); resultSize = rewriter.create(loc, resultSize, stepIndex); resultShape[dim] = resultSize; strides.resize(inputType.getRank(), one); offsets.resize(inputType.getRank(), zero); offsets[dim] = start; strides[dim] = rewriter.create(loc, strides[dim], stepIndex); return success(); } // Example: // input = tensor([[[0., 1., 2., 3.], // [4., 5., 6., 7.]]]) // torch.ops.aten.reflection_pad1d(input, (3,1)) ; padding_left = 3, padding_right = 1 // tensor([[[3., 2., 1., 0., 1., 2., 3., 2.], // [7., 6., 5., 4., 5., 6., 7., 6.]]]) // Checks: 1) Each of padding_left and padding_right must be non-negative less than size of last dimension // Implementation: a) Construct a result tensor of shape of input tensor except for the last dimension. // The last dimension of the result tensor should be last dimension of input tensor + // left padding size + right padding size. INitialize result tensor to all zeros // b) Setup affine map to take slice from input tensor of size left padding starting from // second column onwards as first column is reflection boundary // c) Reflect the affine map to have resultant slice reflected // d) Take the slice and write from begining in result tensor // e) write the original tensor next into result tensor // f) Setup affine map to take slice from input tensor of right padding size ending // at second last column as last column is reflection boundary for right padding // g) Reflect the affine map to have resultant slice reflected // h) Take the slice and write from left padding size + orignal tensor last dim size // into result tensor // Uses the ideas/code used for AtenReflectionPad2dOp namespace { class ConvertAtenReflectionPad1dOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenReflectionPad1dOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); SmallVector padInts; if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts))) return rewriter.notifyMatchFailure( op, "only constant int padding range is supported"); MLIRContext *context = rewriter.getContext(); Location loc = op.getLoc(); // Lambda Unitility Functions // Create an Integer expression of x + y auto createIAdd = [&](Value x, Value y) { return rewriter.create(loc, x, y); }; // Create an integer expression of x - y auto createISub = [&](Value x, Value y) { return rewriter.create(loc, x, y); }; enum PadLocation {PAD_LEFT = 0, PAD_RIGHT = 1, PAD_CENTER=2}; Value input = adaptor.getSelf(); Type indexType = rewriter.getIndexType(); Value zero = getConstant(rewriter, loc, 0, indexType); Value one = getConstant(rewriter, loc, 1, indexType); auto inputType = llvm::cast(input.getType()); auto outputType = llvm::cast(getTypeConverter()->convertType(op->getResult(0).getType())); unsigned numDims = inputType.getRank(); assert(numDims >= 2 && "Not enough input dimensions"); int64_t lastDim = numDims - 1; SmallVector inputShape = getTensorSizes(rewriter, loc, input); Value lastDimSize = inputShape[lastDim]; // input [1,2,4], then lastDim = 2, inputShape[2] will give 4 Value tileWidth[3], extractOffset[3], insertOffset[3]; tileWidth[PAD_LEFT] = getConstant(rewriter, loc, padInts[PAD_LEFT], indexType); tileWidth[PAD_RIGHT] = getConstant(rewriter, loc, padInts[PAD_RIGHT], indexType); tileWidth[PAD_CENTER] = lastDimSize; extractOffset[PAD_LEFT] = one; // for (1,2,4) input, padding (3,1) lastDimSize=4, 4 - 1 - 1 = 2 [3,5, 6,7], so start offset to 6, which is right // lasDimSize - (tileWidth[PAD_RIGHT] + one) extractOffset[PAD_RIGHT] = createISub(lastDimSize, createIAdd(tileWidth[PAD_RIGHT], one)); extractOffset[PAD_CENTER] = zero; insertOffset[PAD_LEFT] = zero; insertOffset[PAD_RIGHT] = createIAdd(lastDimSize, tileWidth[PAD_LEFT]); insertOffset[PAD_CENTER] = tileWidth[PAD_LEFT]; SmallVector resultShape{inputShape}; // Result's last dimension will have shape lastDimSize + left padding size + right padding size resultShape[lastDim] = createIAdd(resultShape[lastDim], createIAdd(tileWidth[PAD_LEFT], tileWidth[PAD_RIGHT])); Value resultTensor = createZeroInitTensor(rewriter, loc, resultShape, inputType.getElementType()); // Helper to reflect/reverse the i-th dimension of an affine map without symbols. This only works if applied on a tensor // for which the corresponding dimension has a statically known size auto reflectDim = [](AffineMap map, unsigned numDims, int64_t i, int64_t size) { AffineExpr d = map.getResult(i); return map.replace(d, size - d - 1, numDims, 0); // left reflect for (3,1) on input shape (1,2,4). size = 3, lastDim=2, numDims=3 }; SmallVector iteratorTypes{numDims, utils::IteratorType::parallel}; auto idMap = AffineMap::getMultiDimIdentityMap(numDims, context); SmallVector allOneStrides(numDims, one); auto addTileToResult = [&](PadLocation padPosition) { // Create the tile by extracting a slice from the input tensor. SmallVector extractShape{inputShape}; extractShape[lastDim] = tileWidth[padPosition]; SmallVector extractOffsets(numDims, zero); extractOffsets[lastDim] = extractOffset[padPosition]; Value tile = rewriter.create( loc, input, extractOffsets, extractShape, allOneStrides); auto inputMap = AffineMap::getMultiDimIdentityMap(numDims, context); // Setup the affine map function to resverse the tile along the horizontal for left and right slices if(padPosition < PAD_CENTER) { inputMap = reflectDim(inputMap, numDims, lastDim, padInts[padPosition]); // Take reflected slice as per inputMap tile = rewriter.create(loc, llvm::cast(tile.getType()), tile, tile, ArrayRef({inputMap, idMap}), iteratorTypes, [](OpBuilder &b, Location nestedLoc, ValueRange args) { b.create(nestedLoc, args[0]); }).getResult(0); } // Insert the tile in the resultTensor SmallVector insertOffsets(numDims, zero); insertOffsets[lastDim] = insertOffset[padPosition]; resultTensor = rewriter.create(loc, tile, resultTensor, insertOffsets, extractShape, allOneStrides); }; if(padInts[PAD_LEFT] > 0) addTileToResult(PAD_LEFT); if(padInts[PAD_RIGHT] > 0) addTileToResult(PAD_RIGHT); addTileToResult(PAD_CENTER); rewriter.replaceOpWithNewOp(op, outputType, resultTensor); return success(); } }; } namespace { class ConvertAtenFlattenUsingIntsOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenFlattenUsingIntsOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); int64_t startDim; if (!matchPattern(op.getStartDim(), m_TorchConstantInt(&startDim))) return rewriter.notifyMatchFailure(op, "start_dim must be constant"); int64_t endDim; if (!matchPattern(op.getEndDim(), m_TorchConstantInt(&endDim))) return rewriter.notifyMatchFailure(op, "end_dim must be constant"); auto type = adaptor.getSelf().getType().cast(); auto inputRank = type.getRank(); if (inputRank == 1) { // If input rank is equal to 1, then there's no scope for flattening the // input tensor. rewriter.replaceOp(op, adaptor.getSelf()); return success(); } auto resultType = getTypeConverter()->convertType(op.getType()).cast(); if (startDim < 0) startDim += inputRank; if (endDim < 0) endDim += inputRank; if (inputRank == 0) { SmallVector reassociation; if (!(startDim >= -1 && startDim <= 0 && endDim >= -1 && endDim <= 0)) return rewriter.notifyMatchFailure( op, "start_dim and end_dim must be in [-1, 0] when inputRank is 0"); rewriter.replaceOpWithNewOp( op, resultType, adaptor.getSelf(), reassociation); return success(); } if (startDim < 0 || startDim >= inputRank || endDim < 0 || endDim >= inputRank || startDim > endDim) return rewriter.notifyMatchFailure( op, "statically invalid flattening dim range"); SmallVector reassociation(resultType.getRank()); int j = 0; for (auto i : llvm::seq(0, inputRank)) { reassociation[j].push_back(i); if (i < startDim || i >= endDim) j++; } Value collapsedTensor = rewriter.create( op->getLoc(), adaptor.getSelf(), reassociation); rewriter.replaceOpWithNewOp(op, resultType, collapsedTensor); return success(); } }; } // namespace namespace { /// The `ConvertAtenViewOp` conversion pattern converts `aten.View` op to /// one `linalg.TensorExpandShape` op for all expanded dimensions and one /// `linalg.TensorCollapseShape` op for all collapsed dimensions. Cases where /// there is neither an expand or collapse of dimensions (e.g. [2, 3] -> [3, 2]) /// is not handled. Additionally, certain dynamic dimension cases rely on naive /// assumptions or aren't supported. /// TODO: Handle all the other cases of `aten.View` op. class ConvertAtenViewOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; // If one of the two dims arrays has size 1, a mapping is created from the one // dimension of the size-1 array to all the dimensions of the other array. For // example for inputs: xDims = [6], yDims = [2, 3] the result in the indices // arrays will be: xIndices = [0], yIndices = [0, 1]. // // An error is returned if the dimension size of the size-1 array is not equal // to the product of all the dimension sizes in the other array, or if neither // of the arrays is size-1. static LogicalResult mapAllDimsToSingleDim(ArrayRef xDims, ArrayRef yDims, SmallVector &xIndices, SmallVector &yIndices) { if (xDims.empty() || yDims.empty()) return failure(); auto isValidReduction = [](int64_t expectedReductionProduct, ArrayRef arrayToReduce) -> bool { if (llvm::count(arrayToReduce, kUnknownSize) > 0 || expectedReductionProduct == kUnknownSize) return true; return productReduce(arrayToReduce) == expectedReductionProduct; }; if (xDims.size() == 1) { if (!isValidReduction(xDims[0], yDims)) return failure(); xIndices.assign({0}); yIndices.assign(llvm::to_vector(llvm::seq(0, yDims.size()))); return success(); } else if (yDims.size() == 1) { if (!isValidReduction(yDims[0], xDims)) return failure(); yIndices.assign({0}); xIndices.assign(llvm::to_vector(llvm::seq(0, xDims.size()))); return success(); } return failure(); } // Starting from the beginning of the dims arrays, this helper finds the // smallest set of consecutive dims in each array such that the product of the // dim sizes in the two subsets is equal. The indices arrays are populated // with the indices of the dims arrays that correspond to the subsets found. // // An error is returned if two subsets of dims with total number of elements // equal to each other is not found. static LogicalResult mapStaticallyKnownDims(ArrayRef xDims, ArrayRef yDims, SmallVector &xIndices, SmallVector &yIndices) { if (xDims.empty() || yDims.empty()) return failure(); int64_t xTotalSize = xDims[0]; int64_t yTotalSize = yDims[0]; SmallVector xIndicesResult({0}); SmallVector yIndicesResult({0}); size_t nextXIndex = 1; size_t nextYIndex = 1; while (xTotalSize != yTotalSize) { if (xTotalSize < yTotalSize) { if (nextXIndex == xDims.size() || xDims[nextXIndex] == kUnknownSize) return failure(); xTotalSize *= xDims[nextXIndex]; xIndicesResult.push_back(nextXIndex++); } else { if (nextYIndex == yDims.size() || yDims[nextYIndex] == kUnknownSize) return failure(); yTotalSize *= yDims[nextYIndex]; yIndicesResult.push_back(nextYIndex++); } } xIndices.assign(std::move(xIndicesResult)); yIndices.assign(std::move(yIndicesResult)); return success(); } // Calculates the size of a dynamic dimension if all other dimensions are // statically known, and rewrites that dynamic dimension with the static size. // // Note: this function assumes that all the dimensions in `inputShape` map to // all the dimensions in `outputShape`. static void calculateSingleDynamicSize(MutableArrayRef inputShape, MutableArrayRef outputShape) { if (inputShape.empty() || outputShape.empty()) return; int64_t inputDynamicDimCount = llvm::count(inputShape, kUnknownSize); int64_t outputDynamicDimCount = llvm::count(outputShape, kUnknownSize); if (inputDynamicDimCount + outputDynamicDimCount != 1) return; int64_t inputProduct = productReduce(inputShape); int64_t outputProduct = productReduce(outputShape); if (inputDynamicDimCount == 1) { inputProduct /= kUnknownSize; *llvm::find(inputShape, kUnknownSize) = outputProduct / inputProduct; } else { outputProduct /= kUnknownSize; *llvm::find(outputShape, kUnknownSize) = inputProduct / outputProduct; } } // Gets the shapes of the input and output tensors, making a best-effort // attempt to extract static shape information given the inputs to // `aten.view`. static std::pair, SmallVector> getInputAndOutputShape(Value inputTorchTensor, SmallVector outputSizeTorchInt) { SmallVector inputShape( inputTorchTensor.getType().cast().getSizes()); SmallVector outputShape(outputSizeTorchInt.size(), kUnknownSize); for (auto [outputDim, outputDimSize] : llvm::enumerate(outputSizeTorchInt)) { int64_t inputDim; int64_t outputDimSizeInt; // Match torch.aten.size.int(inputTensor, inputDim) with constant inputDim if (matchPattern(outputDimSize, m_TorchTensorSizeInt(inputTorchTensor, &inputDim))) { outputShape[outputDim] = inputShape[inputDim]; } else if (matchPattern(outputDimSize, m_TorchConstantInt(&outputDimSizeInt))) { if (outputDimSizeInt != -1) { outputShape[outputDim] = outputDimSizeInt; } } } calculateSingleDynamicSize(inputShape, outputShape); return std::make_pair(inputShape, outputShape); } LogicalResult matchAndRewrite(AtenViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op.getLoc(); Value input = adaptor.getSelf(); auto inputType = input.getType().cast(); SmallVector inputSize = getTensorSizes(rewriter, loc, input); int64_t inputRank = inputType.getRank(); const TypeConverter *typeConverter = getTypeConverter(); auto resultType = typeConverter->convertType(op.getType()).cast(); int64_t resultRank = resultType.getRank(); if (resultRank == 0) return rewriter.notifyMatchFailure(op, "result shape of rank 0 is invalid"); // TODO: add support for case inputRank 0 expanded to size 1 if (inputRank == 0) return rewriter.notifyMatchFailure( op, "unimplemented: input rank 0 is not supported"); // Extract the desired output size as a list of integers. This list should // have been created using the operation `torch.prim.ListConstruct`. SmallVector outputSizeTorchInt; if (!getListConstructElements(op.getSize(), outputSizeTorchInt)) { return rewriter.notifyMatchFailure(op, "unimplemented: the target size is " "not constructed from ListConstruct"); } if (llvm::count_if(outputSizeTorchInt, [](Value size) -> bool { int64_t sizeInt; if (matchPattern(size, m_TorchConstantInt(&sizeInt))) return sizeInt == -1; return false; }) > 1) { return rewriter.notifyMatchFailure( op, "at most one element in size list is allowed to be -1"); } SmallVector outputSizeInt = getTypeConvertedValues( rewriter, loc, typeConverter, outputSizeTorchInt); if (resultRank != (int64_t)outputSizeInt.size()) { return rewriter.notifyMatchFailure( op, "desired size list length mismatches with the result type rank"); } auto [inputShape, outputShape] = getInputAndOutputShape(op.getSelf(), outputSizeTorchInt); // Currently, we only handle the cases where each dimension is either // being expanded or collapsed. We do not handle cases where it's neither // collapsing nor expanding like view of [2,3] for 3x2 tensor. // TODO: For neither collapsing nor expanding, we could find a intermediate // shape to collapse and then expanded to the target shape. Like [2,3] => // [6] => [3, 2]. // Iterate through the view op size list to do the following: // Mark dims in unchangedDims for size list items where the output dim // size comes from a `torch.aten.size.int(inputTensor, inputDim)`. We // naively assume this means the corresponding dimension is not expanded or // collapsed. Note this may technically not always be true. // TODO: think of a way better way to at least detect when this assumption // is violated for the cases of dynamic dimensions. bool inputHasOneDynDim = llvm::count(inputShape, kUnknownSize) == 1; bool outputHasOneDynDim = llvm::count(outputShape, kUnknownSize) == 1; bool singleDynDimsAreEqual = inputHasOneDynDim && outputHasOneDynDim && productReduce(inputShape) == productReduce(outputShape); SmallVector> unchangedDims; for (auto [outputDim, outputDimSize] : llvm::enumerate(outputSizeTorchInt)) { int64_t inputDim; // Match torch.aten.size.int(inputTensor, inputDim) with constant inputDim if (matchPattern(outputDimSize, m_TorchTensorSizeInt(op.getSelf(), &inputDim))) { unchangedDims.push_back(std::make_pair(inputDim, outputDim)); } else if (singleDynDimsAreEqual && outputShape[outputDim] == kUnknownSize) { // If the input and output have a single dynamic dimension and the // product of the other dimensions is the same, then we know that the // dynamic dimension is unchanged. inputDim = std::distance(inputShape.begin(), llvm::find(inputShape, kUnknownSize)); unchangedDims.push_back(std::make_pair(inputDim, outputDim)); } } // Mark the end of the input/output shapes unchangedDims.push_back(std::make_pair(inputRank, resultRank)); // Association indices for expand/collapse ops. These two vectors // are populated such that two entries at the same index corresponds // to an expand or collapse. For example, // // inputAssociations: [[0, 1], [2]] // outputAssociations: [[0], [1, 2, 3]] // // indicates that the first two dims of the input tensor // are collapsed into the first dim of the output, and the // third dim of the input is expanded into the last three dims // of the output. SmallVector inputAssociations; SmallVector outputAssociations; // The for loop does the following: // 1. Attempt to match the indices from inputDim and outputDim to the next // boundary found from `torch.aten.size.int(inputTensor, inputDim)`, or // until (inputRank, resultRank) if there is no such op. Look at the first // dimension of the input and output and collapse the larger one by finding // a minimal set of opposing indices with the same number of elements. If // the number of dims to the next boundary is 1, then we assume all // remaining opposing dims must collapse into it. // 2. For handling of dynamic dimensions, we first assume they are only // split if we can easily compute the correct size. // e.g. [2, -1] -> [2, 3, 4] // This mainly happens at the edges of boundaries. Otherwise we try to match // the dynamic dimension with the one across from it and give up if we can't // reason about how the dimensions are associated. // e.g. [-1, -1] -> [2, 3, 4] // For more information, see description of helper functions used in the // `if-else` cases inside the while loop. int64_t inputDim = 0, outputDim = 0; for (auto [nextUnchangedInput, nextUnchangedOutput] : unchangedDims) { // Used for ensuring that we don't have an ambiguous expansion bool assumedDynamicDimNotSplit = false; while (inputDim < nextUnchangedInput && outputDim < nextUnchangedOutput) { auto inputShapeSlice = MutableArrayRef(inputShape) .slice(inputDim, nextUnchangedInput - inputDim); auto outputShapeSlice = MutableArrayRef(outputShape) .slice(outputDim, nextUnchangedOutput - outputDim); SmallVector inputSliceIndices; SmallVector outputSliceIndices; // TODO: this can be removed by replacing it with a checkDimEqualHelper // that takes into account the product of all the dimensions being // reduced if (assumedDynamicDimNotSplit && inputShapeSlice.size() == 1 && outputShapeSlice.size() != 1 && inputShapeSlice[0] == kUnknownSize) { return rewriter.notifyMatchFailure( op, "found ambiguous expand of dynamic input sizes " "(e.g. [-1, -1] -> [-1, -1, -1])"); } if (succeeded(mapAllDimsToSingleDim(inputShapeSlice, outputShapeSlice, inputSliceIndices, outputSliceIndices))) { calculateSingleDynamicSize(inputShapeSlice, outputShapeSlice); // Update shape to pass the tensor.expand_shape and // tensor.collapse_shape verifiers. If one of the dimensions of the // tensor being flattened is dynamic, the size of the flattened tensor // must also be dynamic. if (inputShapeSlice.size() == 1 && llvm::count(outputShapeSlice, kUnknownSize) > 0) { inputShapeSlice[0] = kUnknownSize; } else if (outputShapeSlice.size() == 1 && llvm::count(inputShapeSlice, kUnknownSize) > 0) { outputShapeSlice[0] = kUnknownSize; } } else if (succeeded(mapStaticallyKnownDims( inputShapeSlice, outputShapeSlice, inputSliceIndices, outputSliceIndices))) { /// `mapStaticallyKnownDims` maps the smallest number of /// input and output dimensions in the slice statically /// known to have the same number of elements. } else if (inputShapeSlice[0] == kUnknownSize) { // If the input is dynamic, assume it is not split checkDimEqualHelper(rewriter, loc, inputSize[inputDim], outputSizeInt[outputDim]); // If output dimension is not dynamic, improve static information of // input inputShape[inputDim] = outputShape[outputDim]; inputSliceIndices.push_back(0); outputSliceIndices.push_back(0); assumedDynamicDimNotSplit = true; } else { return rewriter.notifyMatchFailure( op, "unimplemented: found unhandled case of expansion/collapse " "in `aten.view`"); } inputAssociations.emplace_back(); outputAssociations.emplace_back(); for (int64_t inputSliceIndex : inputSliceIndices) inputAssociations.back().push_back(inputSliceIndex + inputDim); for (int64_t outputSliceIndex : outputSliceIndices) outputAssociations.back().push_back(outputSliceIndex + outputDim); inputDim = inputAssociations.back().back() + 1; outputDim = outputAssociations.back().back() + 1; } // Handle any leading or trailing size-1 dimensions and append the // associations for the dims matching `aten.size.int`. if (nextUnchangedInput != inputRank) { assert(nextUnchangedOutput != resultRank && "`nextUnchangedInput` and `nextUnchangedOutput` should equal " "the respective input and output rank at the same time"); inputAssociations.emplace_back(); outputAssociations.emplace_back(); } while (inputDim <= nextUnchangedInput && inputDim < inputRank) { if (inputDim != nextUnchangedInput && inputShape[inputDim] != 1) { return rewriter.notifyMatchFailure( op, "unimplemented: only collapsing of static size-1 into " "unchanged dim supported"); } inputAssociations.back().push_back(inputDim++); } while (outputDim <= nextUnchangedOutput && outputDim < resultRank) { if (outputDim != nextUnchangedOutput && outputShape[outputDim] != 1) { return rewriter.notifyMatchFailure( op, "unimplemented: only expanding of static size-1 out of " "unchanged dim supported"); } outputAssociations.back().push_back(outputDim++); } } auto cast = [&](Location loc, Type t, Value v) -> Value { return rewriter.createOrFold(loc, t, v); }; // Check if the shapes already match up to dynamic sizes. If so, we can just // cast as the result type because the previous loop sets up the necessary // dim checks in case of dynamic sizes. if (llvm::all_of( inputAssociations, [](ReassociationIndices indices) { return indices.size() == 1; }) && llvm::all_of(outputAssociations, [](ReassociationIndices indices) { return indices.size() == 1; })) { auto castResult = cast(loc, resultType, input); rewriter.replaceOp(op, castResult); return success(); } Type adjustedResultType = RankedTensorType::get( makeShapeLLVMCompatible(outputShape), resultType.getElementType()); Type adjustedInputType = RankedTensorType::get( makeShapeLLVMCompatible(inputShape), resultType.getElementType()); Value castedInput = cast(loc, adjustedInputType, input); std::optional expandedInput; std::optional collapsedInput; if (llvm::any_of(inputAssociations, [](ReassociationIndices indices) { return indices.size() > 1; })) { SmallVector intermediateShape; for (auto i : llvm::seq(0, (int)outputAssociations.size())) { int sum = 1; for (auto j : llvm::seq(0, (int)outputAssociations[i].size())) { if (outputShape[outputAssociations[i][j]] < 0) { sum = kUnknownSize; break; } sum *= outputShape[outputAssociations[i][j]]; } intermediateShape.push_back(sum); } Type intermediateResultType = RankedTensorType::get(makeShapeLLVMCompatible(intermediateShape), resultType.getElementType()); expandedInput = rewriter .create(loc, intermediateResultType, castedInput, inputAssociations) .getResult(); } if (llvm::any_of(outputAssociations, [](ReassociationIndices indices) { return indices.size() > 1; })) { collapsedInput = rewriter .create( loc, adjustedResultType, expandedInput.has_value() ? expandedInput.value() : castedInput, outputAssociations) .getResult(); } Value result = collapsedInput.has_value() ? collapsedInput.value() : expandedInput.value(); auto castResult = cast(loc, resultType, result); rewriter.replaceOp(op, castResult); return success(); } }; } // namespace namespace { class ConvertAtenSqueezeOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenSqueezeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op.getLoc(); Value input = adaptor.getSelf(); auto inputType = input.getType().cast(); int64_t inputRank = inputType.getRank(); const TypeConverter *typeConverter = getTypeConverter(); auto resultType = typeConverter->convertType(op.getType()).cast(); int64_t resultRank = resultType.getRank(); if (inputRank == 0) { return rewriter.notifyMatchFailure( op, "zero input rank should have been handled by the folder"); } // In case the operand tensor type is statically shaped with all dimensions // being unit extent, it will be collapsed to a 0-D tensor. if (resultRank == 0) { SmallVector reassociation; rewriter.replaceOpWithNewOp( op, resultType, input, reassociation); return success(); } // All the static size-1 dimensions at the beginning(going from higher to // lower dimensions) will be collapsed into the first dynamic or first non // size-1 static dimension. All the other static size-1 dimensions will be // collapsed into its previous dynamic or non size-1 static dimension. SmallVector reassociation(resultRank); bool isSqueezed = false; int64_t headOnesCount = 0; while (headOnesCount < inputRank && inputType.getDimSize(headOnesCount) == 1) { isSqueezed = true; reassociation[0].push_back(headOnesCount++); } Value one = rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); int64_t j = -1; bool elideDynamicBroadcastDimCheck = isAssumingStrictSymbolicShapes(rewriter); for (auto i : llvm::seq(headOnesCount, inputRank)) { if (inputType.isDynamicDim(i)) { if (!elideDynamicBroadcastDimCheck) { // Make sure that size-1 dynamic dimension does not exist. Value dimSize = getDimOp(rewriter, loc, input, i); Value dimSizeNotOne = rewriter.create( loc, arith::CmpIPredicate::ne, dimSize, one); rewriter.create( loc, dimSizeNotOne, rewriter.getStringAttr( "unimplemented: size 1 dynamic dimension is not supported")); } ++j; } else if (inputType.getDimSize(i) != 1) { ++j; } else { // `isSqueezed` checks if the operand tensor type contains at least one // unit dimension. isSqueezed = true; } if (j == resultRank) break; reassociation[j].push_back(i); } // Make sure that result type rank is compatible with the squeezed size. if (j != resultRank - 1) return rewriter.notifyMatchFailure( op, "expected output size mismatches with the result type rank"); if (isSqueezed) { rewriter.replaceOpWithNewOp( op, resultType, input, reassociation); } else { // If the operand tensor type does not have any unit dimension, // `aten.squeeze` will behave as an identity operation. rewriter.replaceOpWithNewOp(op, resultType, input); } return success(); } }; } // namespace namespace { class ConvertAtenSqueezeDimOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenSqueezeDimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Value input = adaptor.getSelf(); auto inputType = input.getType().cast(); int64_t inputRank = inputType.getRank(); if (inputRank == 0) { return rewriter.notifyMatchFailure( op, "zero input rank should have been handled by the folder"); } int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure(op, "dim must be constant"); dim = toPositiveDim(dim, inputRank); if (!isValidDim(dim, inputRank)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); // TODO: Handle the case where the dim(th) dimension is dynamic. if (inputType.isDynamicDim(dim)) { return rewriter.notifyMatchFailure( op, "unimplemented: dim(th) dimension is not expected to be dynamic"); } const TypeConverter *typeConverter = getTypeConverter(); auto resultType = typeConverter->convertType(op.getType()).cast(); int64_t resultRank = resultType.getRank(); // If the dim(th) dimension of operand tensor type is not statically unit, // `aten.squeeze` will behave as an identity operation. if (inputType.getDimSize(dim) != 1) { rewriter.replaceOpWithNewOp(op, resultType, input); return success(); } SmallVector reassociationMap(resultRank); bool alreadyCrossedSqueezedDim = false; for (int i = 0; i != resultRank; i++) { if (alreadyCrossedSqueezedDim) { reassociationMap[i].push_back(i + 1); } else { reassociationMap[i].push_back(i); if (dim != 0 && i != dim - 1) continue; alreadyCrossedSqueezedDim = true; if (dim == 0) reassociationMap[0].push_back(1); if (i == dim - 1) reassociationMap[i].push_back(dim); } } // Note: In case the operand tensor type is of unit rank and is statically // shaped with unit dimension, the `reassociationMap` will be empty and the // input will be collapsed to a 0-D tensor. rewriter.replaceOpWithNewOp(op, resultType, input, reassociationMap); return success(); } }; } // namespace namespace { class ConvertAtenUnsqueezeOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenUnsqueezeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure(op, "dim must be constant"); auto inputRank = adaptor.getSelf().getType().cast().getRank(); dim = toPositiveDim(dim, inputRank + 1); if (!isValidDim(dim, inputRank + 1)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); SmallVector reassociationMap(inputRank); // From the perspective of the reassociation map, the situation of // unsqueezing before or after the last dimension is symmetrical. // Normalize it to the "before" case. // The 0 case is special here, since there is no last dimension to insert // before -- we simply rely on the loop below iterating 0 times. if (dim == inputRank && inputRank != 0) dim = inputRank - 1; bool alreadyCrossedExpandedDim = false; for (int i = 0; i != inputRank; i++) { if (alreadyCrossedExpandedDim) { reassociationMap[i].push_back(i + 1); } else { reassociationMap[i].push_back(i); if (i == dim) { reassociationMap[i].push_back(i + 1); alreadyCrossedExpandedDim = true; } } } auto resultType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); rewriter.replaceOpWithNewOp( op, resultType, adaptor.getSelf(), reassociationMap); return success(); } }; } // namespace namespace { class ConvertAtenTransposeIntOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenTransposeIntOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); int64_t dim0; if (!matchPattern(op.getDim0(), m_TorchConstantInt(&dim0))) return rewriter.notifyMatchFailure(op, "dim0 must be constant"); int64_t dim1; if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1))) return rewriter.notifyMatchFailure(op, "dim1 must be constant"); auto inVector = adaptor.getSelf(); auto inType = inVector.getType().cast(); auto inputRank = inType.getRank(); auto outType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); auto elementType = inType.getElementType(); dim0 = toPositiveDim(dim0, inputRank); if (!isValidDim(dim0, inputRank)) return rewriter.notifyMatchFailure(op, "dim0 out of range"); dim1 = toPositiveDim(dim1, inputRank); if (!isValidDim(dim1, inputRank)) return rewriter.notifyMatchFailure(op, "dim1 out of range"); auto loc = op.getLoc(); SmallVector outputDims; for (auto i = 0; i < inputRank; i++) outputDims.push_back(getDimOp(rewriter, loc, adaptor.getSelf(), i)); std::swap(outputDims[dim0], outputDims[dim1]); Value outVector = rewriter.create( loc, getAsOpFoldResult(outputDims), elementType); SmallVector idExprs; SmallVector swapExprs; for (auto i = 0; i < inputRank; i++) idExprs.push_back(getAffineDimExpr(i, rewriter.getContext())); for (auto i = 0; i < inputRank; i++) { if (i == dim0) swapExprs.push_back(idExprs[dim1]); else if (i == dim1) swapExprs.push_back(idExprs[dim0]); else swapExprs.push_back(idExprs[i]); } SmallVector indexingMaps = { AffineMap::get(inputRank, 0, idExprs, op.getContext()), AffineMap::get(inputRank, 0, swapExprs, op.getContext())}; SmallVector iteratorTypes( inputRank, utils::IteratorType::parallel); auto transpose = rewriter .create( loc, outVector.getType(), inVector, outVector, indexingMaps, iteratorTypes, [](OpBuilder &b, Location loc, ValueRange args) { b.create(loc, args[0]); }) .getResult(0); rewriter.replaceOpWithNewOp(op, outType, transpose); return success(); } }; } // namespace namespace { class ConvertAtenPermuteOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenPermuteOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); SmallVector dimensions; if (!matchPattern(op.getDims(), m_TorchListOfConstantInts(dimensions))) return rewriter.notifyMatchFailure(op, "all dimensions must be constant"); Value inVector = adaptor.getSelf(); auto inType = inVector.getType().cast(); int64_t inputRank = inType.getRank(); auto outType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); Type elementType = inType.getElementType(); // Check if the dimensions are a valid constants. int64_t numDimensions = dimensions.size(); if (inputRank != numDimensions) return rewriter.notifyMatchFailure( op, "size of `dims` must be equal to the rank of the input"); for (unsigned i = 0; i < numDimensions; i++) { if (dimensions[i] < 0) dimensions[i] = toPositiveDim(dimensions[i], inputRank); if (!isValidDim(dimensions[i], inputRank)) return rewriter.notifyMatchFailure(op, "dimension out of range"); } Location loc = op.getLoc(); SmallVector outputDims; for (unsigned i = 0; i < inputRank; i++) outputDims.push_back(getDimOp(rewriter, loc, inVector, dimensions[i])); Value outVector = rewriter.create( loc, getAsOpFoldResult(outputDims), elementType); SmallVector idExprs; SmallVector swapExprs; for (unsigned i = 0; i < inputRank; i++) idExprs.push_back(getAffineDimExpr(i, rewriter.getContext())); for (unsigned i = 0; i < inputRank; i++) swapExprs.push_back(idExprs[dimensions[i]]); AffineMap inputMap = AffineMap::get(inputRank, /*symbolCount=*/0, idExprs, op->getContext()); AffineMap outputMap = AffineMap::get(inputRank, /*symbolCount=*/0, swapExprs, op->getContext()); SmallVector indexingMaps{inputMap, outputMap}; SmallVector iteratorTypes( inputRank, utils::IteratorType::parallel); auto transpose = rewriter .create( loc, outVector.getType(), inVector, outVector, indexingMaps, iteratorTypes, [](OpBuilder &b, Location loc, ValueRange args) { b.create(loc, args[0]); }) .getResult(0); rewriter.replaceOpWithNewOp(op, outType, transpose); return success(); } }; } // namespace namespace { class ConvertAtenSliceTensorOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenSliceTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op.getLoc(); const TypeConverter *typeConverter = getTypeConverter(); auto input = adaptor.getSelf(); RankedTensorType resultType = typeConverter->convertType(op->getResult(0).getType()) .cast(); SmallVector resultShape; SmallVector offsets; SmallVector strides; if (failed(prepareArgumentsForSlicingOp( op, adaptor, rewriter, resultShape, offsets, strides))) { return failure(); } Value result = rewriter.create( loc, input, offsets, resultShape, strides); rewriter.replaceOpWithNewOp(op, resultType, result); return success(); } }; } // namespace namespace { class ConvertAtenCatOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenCatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op.getLoc(); const TypeConverter *typeConverter = getTypeConverter(); // Collect all the tensors to be concatenated. auto tensorList = op.getTensors(); SmallVector tensorsTorchType; if (!getListConstructElements(tensorList, tensorsTorchType)) return op.emitError( "unimplemented: the tensor list is not from list construct"); auto tensors = getTypeConvertedValues(rewriter, loc, typeConverter, tensorsTorchType); RankedTensorType newResultType = typeConverter->convertType(op.getType()).cast(); auto outElemType = newResultType.getElementType(); for (size_t i = 0; i < tensors.size(); ++i) { auto inputType = cast(tensors[i].getType()); if (inputType.getElementType() != outElemType) { tensors[i] = torch_to_linalg::convertTensorToElementType( rewriter, loc, tensors[i], outElemType); } } int rank = newResultType.getRank(); Value dimValue = op.getDim(); int64_t dim; if (!matchPattern(dimValue, m_TorchConstantInt(&dim))) return op.emitError("unimplemented: dim is not constant"); dim = toPositiveDim(dim, rank); if (!isValidDim(dim, rank)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); rewriter.replaceOpWithNewOp(op, newResultType, dim, tensors); return success(); } }; } // namespace namespace { class ConvertAtenBroadcastToOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenBroadcastToOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Value self = adaptor.getSelf(); SmallVector inShape; if (!getListConstructElements(adaptor.getSize(), inShape)) { return rewriter.notifyMatchFailure( op, "unimplemented: the size list is not from list construct"); } // For dynamic input dimension we need to use the `broadcastToShape` // which in this case is `inShapeConverted` because this shape will yield // us the dimension size of the output. SmallVector useBroadcastToShape; int64_t inputRank = self.getType().cast().getRank(); for (size_t i = inShape.size() - inputRank, e = inShape.size(); i < e; ++i) { int64_t dim; if (matchPattern(inShape[i], m_TorchConstantInt(&dim))) { if (dim < 0) { useBroadcastToShape.push_back(false); } else { useBroadcastToShape.push_back(true); } } else { // Note: Dynamic -1 (inferred) broadcast shapes are unimplemented. useBroadcastToShape.push_back(true); } } SmallVector inShapeConverted = getTypeConvertedValues( rewriter, op.getLoc(), getTypeConverter(), inShape); auto newResultType = getTypeConverter()->convertType(op.getType()).cast(); Value result; if (failed(torch_to_linalg::broadcastToGivenShape( op, rewriter, self, inShapeConverted, newResultType, result, useBroadcastToShape))) { return rewriter.notifyMatchFailure( op, "unable to perform broadcast operation"); } rewriter.replaceOp(op, result); return success(); } }; } // namespace namespace { class ConvertAtenContiguousOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenContiguousOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Type resultType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, resultType, adaptor.getSelf()); return success(); } }; } // namespace namespace { class ConvertAtenCopyOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenCopyOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op.getLoc(); Value self = adaptor.getSelf(); Value src = adaptor.getSrc(); RankedTensorType selfType = self.getType().cast(); // The non_blocking should be a constant `False`. bool nonBlocking; if (!matchPattern(op.getNonBlocking(), m_TorchConstantBool(&nonBlocking))) { return rewriter.notifyMatchFailure( op, "unimplemented: non_blocking must be a constant"); } else if (nonBlocking) { return rewriter.notifyMatchFailure( op, "unimplemented: non_blocking is expected to be false"); } // The size of the src tensor can be different from the self but should be // broadcastable. Therefore, broadcasting the src tensor to match the size // of the self tensor. SmallVector selfSizes = getTensorSizes(rewriter, loc, self); for (unsigned i = 0; i < selfSizes.size(); i++) selfSizes[i] = castIndexToInt64(rewriter, loc, selfSizes[i]); Value broadcastedSrc; if (failed(torch_to_linalg::broadcastToGivenShape( op, rewriter, src, selfSizes, selfType, broadcastedSrc))) { return rewriter.notifyMatchFailure( op, "unable to perform broadcast operation"); } AffineMap id = AffineMap::getMultiDimIdentityMap(selfType.getRank(), rewriter.getContext()); SmallVector iteratorTypes( selfType.getRank(), utils::IteratorType::parallel); Value result = rewriter .create( loc, /*resultType=*/selfType, /*inputs=*/broadcastedSrc, /*outputs=*/self, /*indexingMaps=*/llvm::ArrayRef({id, id}), /*iteratorTypes=*/iteratorTypes, [](OpBuilder &b, Location loc, ValueRange args) { Value result = args[0]; if (args[0].getType() != args[1].getType()) { result = convertScalarToDtype(b, loc, args[0], args[1].getType()); } b.create(loc, result); }) ->getResult(0); Type resultType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, resultType, result); return success(); } }; } // namespace namespace { class ConvertAtenSliceScatterOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenSliceScatterOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op.getLoc(); const TypeConverter *typeConverter = getTypeConverter(); auto input = adaptor.getSelf(); RankedTensorType resultType = typeConverter->convertType(op->getResult(0).getType()) .cast(); SmallVector resultShape; SmallVector offsets; SmallVector strides; if (failed(prepareArgumentsForSlicingOp( op, adaptor, rewriter, resultShape, offsets, strides))) { return failure(); } Value src = adaptor.getSrc(); auto srcType = src.getType().cast(); int64_t srcRank = srcType.getRank(); SmallVector srcAbstractSizes(srcRank, kUnknownSize); auto abstractSrcType = RankedTensorType::get( makeShapeLLVMCompatible(srcAbstractSizes), srcType.getElementType()); Value abstractSrc = rewriter.create(loc, abstractSrcType, src); Value result = rewriter.create( loc, abstractSrc, input, offsets, resultShape, strides); rewriter.replaceOpWithNewOp(op, resultType, result); return success(); } }; } // namespace namespace { class ConvertAtenViewAsComplexOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenViewAsComplexOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op.getLoc(); const TypeConverter *typeConverter = getTypeConverter(); MLIRContext *context = rewriter.getContext(); auto input = adaptor.getSelf(); RankedTensorType resultType = typeConverter->convertType(op.getType()).cast(); auto elementType = resultType.getElementType(); SmallVector resultShape; for (int64_t i = 0; i < resultType.getRank(); i++) { auto currentDimSize = rewriter.create(loc, input, i); resultShape.push_back(currentDimSize); } Value outTensor = rewriter.create( loc, getAsOpFoldResult(resultShape), elementType); SmallVector outputExpr; for (unsigned i = 0; i < resultType.getRank(); i++) { outputExpr.push_back(getAffineDimExpr(i, context)); } Value constantZero = getConstant(rewriter, loc, 0, mlir::IndexType::get(context)); Value constantOne = getConstant(rewriter, loc, 1, mlir::IndexType::get(context)); AffineMap outputMap = AffineMap::get(resultType.getRank(), 0, outputExpr, op->getContext()); SmallVector indexingMaps{outputMap}; SmallVector iteratorTypes( resultType.getRank(), utils::IteratorType::parallel); auto complexVar = rewriter .create( loc, outTensor.getType(), ValueRange{}, outTensor, indexingMaps, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { SmallVector indicesZero; SmallVector indicesOne; for (int i = 0; i < resultType.getRank(); i++) { indicesZero.push_back(b.create(loc, i)); indicesOne.push_back(b.create(loc, i)); } indicesZero.push_back(constantZero); indicesOne.push_back(constantOne); Value realVal = b.create(loc, input, indicesZero); Value imagVal = b.create(loc, input, indicesOne); Value complexVal = b.create( loc, elementType, realVal, imagVal); b.create(loc, complexVal); }) .getResult(0); rewriter.replaceOpWithNewOp(op, resultType, complexVar); return success(); } }; } // namespace namespace { class ConvertAtenViewAsRealOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenViewAsRealOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op.getLoc(); const TypeConverter *typeConverter = getTypeConverter(); MLIRContext *context = rewriter.getContext(); auto input = adaptor.getSelf(); RankedTensorType resultType = typeConverter->convertType(op.getType()).cast(); RankedTensorType inputType = input.getType().cast(); auto inputElementType = getElementTypeOrSelf(input.getType()); if (!inputElementType.isa()) { return op.emitError("only ComplexType is allowed as input type"); } Type elementType = resultType.getElementType(); // returned real tensor has a size increase, where the last dim has size 2 SmallVector resultShape = tensor::getMixedSizes(rewriter, loc, input); resultShape.push_back( rewriter.createOrFold(loc, 2)); Value outTensor = rewriter.create(loc, resultShape, elementType); SmallVector inputExpr; for (unsigned i = 0; i < resultType.getRank() - 1; i++) { inputExpr.push_back(getAffineDimExpr(i, context)); } AffineMap inputMap = AffineMap::get(resultType.getRank(), 0, inputExpr, op->getContext()); inputExpr.push_back(getAffineDimExpr(resultType.getRank() - 1, context)); AffineMap outputMap = AffineMap::get(resultType.getRank(), 0, inputExpr, op->getContext()); SmallVector indexingMaps{inputMap, outputMap}; SmallVector iteratorTypes( resultType.getRank(), utils::IteratorType::parallel); Value constantZero = getConstant(rewriter, loc, 0, mlir::IndexType::get(context)); auto realVar = rewriter .create( loc, outTensor.getType(), input, outTensor, indexingMaps, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { Value realVal = b.create(loc, elementType, args[0]); Value imagVal = b.create(loc, elementType, args[0]); Value lastIndex = b.create(loc, inputType.getRank()); Value cmpResult = b.create( loc, arith::CmpIPredicate::eq, lastIndex, constantZero); Value yieldValue = b.create( loc, cmpResult, realVal, imagVal); b.create(loc, yieldValue); }) .getResult(0); rewriter.replaceOpWithNewOp(op, resultType, realVar); return success(); } }; } // namespace void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { MLIRContext *context = patterns.getContext(); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); }