//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"

#include "../PassDetail.h"
#include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
// NOTE(review): the header name of the next include is missing in this copy
// of the file (an angle-bracket include appears to have been stripped);
// restore it before building.
#include 

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {
// Lowers `aten.constant_pad_nd` to a single pad op whose padding value is the
// op's `value` operand converted to the result element type.
//
// The pad list is consumed as (low, high) pairs, the first pair applying to
// the innermost dimension; leading dimensions not covered by the list get
// zero padding. Pad amounts that are not compile-time constants are
// converted to `index` values and recorded as dynamic in the static pad
// vectors used to infer the padded type.
//
// NOTE(review): template arguments (`<...>`) are not visible in this copy.
// The created ops are presumably an index cast, `tensor.pad` (its result type
// comes from `tensor::PadOp::inferResultType`) and a final cast to the
// converted result type; confirm.
class ConvertAtenConstantPadNdOp : public OpConversionPattern {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenConstantPadNdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    Location loc = op->getLoc();
    Value self = adaptor.getSelf();
    auto type = self.getType().cast();
    int64_t rank = type.getRank();

    // The pad amounts must come from a list-construction op so that the
    // individual elements can be inspected (presumably prim.ListConstruct;
    // the getDefiningOp template argument is not visible here).
    auto primList = op.getPad().getDefiningOp();
    if (!primList) {
      return rewriter.notifyMatchFailure(op, "unable to get pad values");
    }

    SmallVector padVals(primList.getOperands());

    // Each padded dimension needs exactly one (low, high) pair.
    uint64_t padRank = padVals.size() / 2;
    if (padRank * 2 != padVals.size())
      return rewriter.notifyMatchFailure(op, "pad range size is not even");
    // `rank` comes from a ranked type and is never negative; the first
    // condition is purely defensive.
    if (rank < 
0 || padRank > (uint64_t)rank)
      return rewriter.notifyMatchFailure(op, "padding exceeds tensor rank");

    // Initialize low/high paddings with the dims that should not be padded.
    int64_t noPad = rank - padRank;
    Attribute zero = rewriter.getIndexAttr(0);
    // staticLow/staticHigh feed type inference; lowPad/highPad are the actual
    // (attribute or SSA) operands of the pad op.
    SmallVector staticLow(noPad, 0);
    SmallVector staticHigh(noPad, 0);
    SmallVector lowPad(noPad, zero);
    SmallVector highPad(noPad, zero);

    auto tc = getTypeConverter();

    // Add the requested padding - note op.pad() is highest dim first ordered
    // pairs of low,high.
    // Iterating the pairs in reverse appends them from the outermost padded
    // dimension to the innermost one.
    for (uint64_t i = padRank; i > 0; --i) {
      int64_t lowi, highi;
      Value lowv = padVals[i * 2 - 2];
      Value highv = padVals[i * 2 - 1];
      if (!matchPattern(lowv, m_TorchConstantInt(&lowi))) {
        // Dynamic amount: convert the torch int to its builtin type, then to
        // `index`.
        // NOTE(review): materializeTargetConversion may return a null Value
        // on failure; the result is not checked here.
        Type cty = tc->convertType(lowv.getType());
        lowv = tc->materializeTargetConversion(rewriter, loc, cty, lowv);
        lowv = rewriter.create(loc, rewriter.getIndexType(),
                                                   lowv);
        lowPad.push_back(lowv);
        staticLow.push_back(ShapedType::kDynamic);
      } else {
        lowPad.push_back(rewriter.getIndexAttr(lowi));
        staticLow.push_back(lowi);
      }

      if (!matchPattern(highv, m_TorchConstantInt(&highi))) {
        // Same handling as the low side (same NOTE about a null result).
        Type cty = tc->convertType(highv.getType());
        highv = tc->materializeTargetConversion(rewriter, loc, cty, highv);
        highv = rewriter.create(
            loc, rewriter.getIndexType(), highv);
        highPad.push_back(highv);
        staticHigh.push_back(ShapedType::kDynamic);
      } else {
        highPad.push_back(rewriter.getIndexAttr(highi));
        staticHigh.push_back(highi);
      }
    }

    Type newResultType = getTypeConverter()->convertType(op.getType());
    Type elementType = newResultType.cast().getElementType();
    // The fill value is converted to the element type of the result.
    Value castedValue =
        convertScalarToDtype(rewriter, loc, adaptor.getValue(), elementType);
    Type padType = tensor::PadOp::inferResultType(
        self.getType().cast(), staticLow, staticHigh);
    Value paddedInput = rewriter.create(
        loc, padType, self, lowPad, highPad, castedValue);
    // The inferred pad type may be less static than the converted result
    // type, so the result is wrapped in a final op producing newResultType.
    rewriter.replaceOpWithNewOp(op, newResultType, paddedInput);
    return success();
  }
};
} // namespace

namespace {
// Lower aten.replication_pad2d operator into a sequence of
tensor.extract_slice and tensor.concat operations. class ConvertAtenReplicationPad2dOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenReplicationPad2dOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op->getLoc(); Value input = adaptor.getSelf(); auto inputType = llvm::cast(input.getType()); int64_t inputRank = inputType.getRank(); unsigned numDims = inputType.getRank(); assert(numDims >= 2 && "Not enough input dimensions"); SmallVector padInts; if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts))) return rewriter.notifyMatchFailure( op, "only support constant int pad ranges"); uint64_t padRank = padInts.size() / 2; if (padRank * 2 != padInts.size()) return rewriter.notifyMatchFailure(op, "pad range size is not even"); if (inputRank < 0 || padRank > (uint64_t)inputRank) return rewriter.notifyMatchFailure(op, "padding exceeds tensor rank"); SmallVector inputShape = getTensorSizes(rewriter, loc, input); int64_t hDim = numDims - 1; int64_t vDim = numDims - 2; Value hDimSize = inputShape[hDim]; Value vDimSize = inputShape[vDim]; enum tileHLoc { LEFT = 0, HCENTER = 1, RIGHT = 2 }; enum tileVLoc { TOP = 0, VCENTER = 2, BOTTOM = 1, }; // vTile denotes the vertical size of the tile // hTile denotes the horizontal size of the tile // The padding results are composed of following tiles: // vTile[TOP]hTile[LEFT], vTile[TOP]hTile[HCENTER], vTile[TOP]hTile[RIGHT] // vTile[VCENTER]hTile[LEFT], vTile[VCENTER]hTile[HCENTER], // vTile[VCENTER]hTile[RIGHT] vTile[BOTTOM]hTile[LEFT], // vTile[BOTTOM]hTile[HCENTER], vTile[BOTTOM]hTile[RIGHT] // vTile[VCENTER]hTile[HCENTER] is the original input tensor Type indexType = rewriter.getIndexType(); Value vTile[3]; Value hTile[3]; vTile[VCENTER] = vDimSize; hTile[HCENTER] = hDimSize; vTile[TOP] = getConstant(rewriter, loc, padInts[2], 
indexType); vTile[BOTTOM] = getConstant(rewriter, loc, padInts[3], indexType); hTile[LEFT] = getConstant(rewriter, loc, padInts[0], indexType); hTile[RIGHT] = getConstant(rewriter, loc, padInts[1], indexType); bool hasLeftPadding = false; bool hasRightPadding = false; bool hasTopPadding = false; bool hasBottomPadding = false; for (auto i : {TOP, VCENTER, BOTTOM}) { for (auto j : {LEFT, HCENTER, RIGHT}) { auto constVtile{ mlir::dyn_cast(vTile[i].getDefiningOp()) .getValue() .dyn_cast_or_null()}; auto constHtile{ mlir::dyn_cast(hTile[j].getDefiningOp()) .getValue() .dyn_cast_or_null()}; auto vSize = constVtile.getInt(); auto hSize = constHtile.getInt(); if ((i == TOP) && (vSize > 0)) hasTopPadding = true; if ((i == BOTTOM) && (vSize > 0)) hasBottomPadding = true; if ((j == LEFT) && (hSize > 0)) hasLeftPadding = true; if ((j == RIGHT) && (hSize > 0)) hasRightPadding = true; } } auto createSub = [&](Value x, Value y) { return rewriter.create(loc, x, y); }; // Extract left and right pad tiles. 
Value zero = getConstant(rewriter, loc, 0, indexType); Value one = getConstant(rewriter, loc, 1, indexType); Value hDimSizeMinusOne = createSub(hDimSize, one); Value vDimSizeMinusOne = createSub(vDimSize, one); SmallVector allOneStrides(numDims, one); SmallVector extractOffsetsLT(numDims, zero); extractOffsetsLT[hDim] = zero; extractOffsetsLT[vDim] = zero; SmallVector extractShapeLR(numDims, one); extractShapeLR[hDim] = one; extractShapeLR[vDim] = vDimSize; SmallVector extractOffsetsRight(numDims, zero); extractOffsetsRight[hDim] = hDimSizeMinusOne; extractOffsetsRight[vDim] = zero; SmallVector extractOffsetsBottom(numDims, zero); extractOffsetsBottom[hDim] = zero; extractOffsetsBottom[vDim] = vDimSizeMinusOne; SmallVector extractShapeTB(numDims, one); extractShapeTB[hDim] = hDimSize; extractShapeTB[vDim] = one; SmallVector tensorsLeft; SmallVector tensorsRight; SmallVector tensorsCenter; Value centerTile; SmallVector tensorsRes; if (hasLeftPadding) { Value vCenterLeftSlice = rewriter.create( loc, input, extractOffsetsLT, extractShapeLR, allOneStrides); Value vLeftSlice = vCenterLeftSlice; if (hasTopPadding) { Value topLeftValue = rewriter.create( loc, input, ValueRange{zero, zero, zero, zero}); // pad vCenterLeftSlice on the top SmallVector lowPadding(4, 0); SmallVector highPadding(4, 0); lowPadding[2] = padInts[2]; vLeftSlice = torch_to_linalg::getPaddedTensor( op, rewriter, vLeftSlice, lowPadding, highPadding, topLeftValue); } if (hasBottomPadding) { Value bottomLeftValue = rewriter.create( loc, input, ValueRange{zero, zero, vDimSizeMinusOne, zero}); // pad vLeftSlice at the bottom SmallVector lowPadding(4, 0); SmallVector highPadding(4, 0); highPadding[2] = padInts[3]; vLeftSlice = torch_to_linalg::getPaddedTensor( op, rewriter, vLeftSlice, lowPadding, highPadding, bottomLeftValue); } for (auto i = 0; i < padInts[0]; ++i) { tensorsLeft.push_back(vLeftSlice); } Value leftPadTile = rewriter.create(loc, 3, tensorsLeft); tensorsRes.push_back(leftPadTile); } if 
(hasTopPadding) { Value topHcenterSlice = rewriter.create( loc, input, extractOffsetsLT, extractShapeTB, allOneStrides); for (auto i = 0; i < padInts[2]; ++i) { tensorsCenter.push_back(topHcenterSlice); } } tensorsCenter.push_back(input); if (hasBottomPadding) { Value bottomHcenterSlice = rewriter.create( loc, input, extractOffsetsBottom, extractShapeTB, allOneStrides); for (auto i = 0; i < padInts[3]; ++i) { tensorsCenter.push_back(bottomHcenterSlice); } } centerTile = rewriter.create(loc, 2, tensorsCenter); tensorsRes.push_back(centerTile); if (hasRightPadding) { Value vCenterRightSlice = rewriter.create( loc, input, extractOffsetsRight, extractShapeLR, allOneStrides); Value vRightSlice = vCenterRightSlice; if (hasTopPadding) { Value topRightValue = rewriter.create( loc, input, ValueRange{zero, zero, zero, hDimSizeMinusOne}); // pad vCenterRightSlice on the top SmallVector lowPadding(4, 0); SmallVector highPadding(4, 0); lowPadding[2] = padInts[2]; vRightSlice = torch_to_linalg::getPaddedTensor( op, rewriter, vRightSlice, lowPadding, highPadding, topRightValue); } if (hasBottomPadding) { Value bottomRightValue = rewriter.create( loc, input, ValueRange{zero, zero, vDimSizeMinusOne, hDimSizeMinusOne}); // Pad vCenterRightSlice or vRightTopPaddedSlice at the bottom. SmallVector lowPadding(4, 0); SmallVector highPadding(4, 0); highPadding[2] = padInts[3]; vRightSlice = torch_to_linalg::getPaddedTensor( op, rewriter, vRightSlice, lowPadding, highPadding, bottomRightValue); } for (auto i = 0; i < padInts[1]; ++i) { tensorsRight.push_back(vRightSlice); } Value rightPadTile = rewriter.create(loc, 3, tensorsRight); tensorsRes.push_back(rightPadTile); } Value resTensor = rewriter.create(loc, 3, tensorsRes); Type newResultType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, newResultType, resTensor); return success(); } }; } // namespace namespace { // Converts constant tensor allocation like ops. 
// Lowers ops that allocate a tensor of a given `size` filled with the
// constant `fillVal`. Both `OpTy` and `fillVal` are template parameters; the
// template parameter list is not visible in this copy of the file.
template
class ConvertConstantTensorAllocOp : public OpConversionPattern {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();

    // TODO: Add support for layout, pin_memory features.
    // Only `none` layout is supported.
    // At this point all tensors should have value semantics, and hence the
    // `layout` check can be ignored.

    // The pin_memory should be either `False` or `none`.
    bool pinMemory;
    if (!op.getPinMemory().getType().template isa() &&
        (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) ||
         pinMemory)) {
      return rewriter.notifyMatchFailure(
          op, "unimplemented: pin_memory must be either None or false");
    }

    Location loc = op.getLoc();
    const TypeConverter *typeConverter = this->getTypeConverter();
    // The size list must be a list construct; each element is type-converted
    // and then cast to `index` for use as a tensor dimension.
    SmallVector resultSizeTorchInt, resultSize, resultSizeIndex;
    if (!getListConstructElements(op.getSize(), resultSizeTorchInt)) {
      return rewriter.notifyMatchFailure(
          op, "unimplemented: size must be constructed using ListConstruct");
    }
    resultSize = getTypeConvertedValues(rewriter, loc, typeConverter,
                                        resultSizeTorchInt);
    for (auto size : resultSize)
      resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size));

    auto resultType =
        typeConverter->convertType(op.getType())
            .template cast();
    // dtype None: use the element type of the converted result type.
    // Otherwise the constant scalar-type integer is mapped to a builtin type.
    Type resultElementType;
    if (op.getDtype().getType().template isa()) {
      resultElementType = resultType.getElementType();
    } else {
      int64_t dtypeInt;
      if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt)))
        return rewriter.notifyMatchFailure(
            op, "unimplemented: dtype must be a constant integer or none");
      FailureOr maybeResultElementType =
          torch_to_linalg::getBackendTypeForScalarType(
              op->getContext(), (torch_upstream::ScalarType)dtypeInt);
      if (failed(maybeResultElementType)) {
        return rewriter.notifyMatchFailure(
            op, "unable to convert `dtypeInt` to builtin type");
      }
      resultElementType = *maybeResultElementType;
    }

    // Create an uninitialized tensor of `resultSize` shape and fill it with
    // value `fillVal`.
    Value constVal = getConstant(rewriter, loc, fillVal, resultElementType);
    Value outputTensor = createInitTensor(rewriter, loc, resultSizeIndex,
                                          resultElementType, constVal);
    rewriter.replaceOpWithNewOp(op, resultType, outputTensor);
    return success();
  }
};
} // namespace

namespace {
// Converts `aten.empty.memory_format` to an uninitialized tensor of the
// requested size and dtype.
// NOTE(review): this comment used to name `linalg.init_tensor`, which no
// longer exists upstream. The op created below (its template argument is not
// visible in this copy) is presumably `tensor.empty`; confirm.
class ConvertAtenEmptyMemoryFormatOp
    : public OpConversionPattern {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenEmptyMemoryFormatOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();

    // TODO: Add support for pin_memory and memory_format features.
    // At this point all tensors should have value semantics, and hence the
    // `layout` check can be ignored.

    // The pin_memory should be either `False` or `none`.
    bool pinMemory;
    if (!op.getPinMemory().getType().template isa() &&
        (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) ||
         pinMemory))
      return rewriter.notifyMatchFailure(
          op, "unimplemented: pin_memory must be either None or false");

    // Only `none`, `contiguous` and `preserve` memory_format is supported.
    if (!op.getMemoryFormat().getType().isa()) {
      int64_t memoryFormat;
      if (!matchPattern(op.getMemoryFormat(),
                        m_TorchConstantInt(&memoryFormat)))
        return rewriter.notifyMatchFailure(
            op, "unimplemented: the memory format should be specified in "
                "an integer constant");
      if (memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
          memoryFormat != torch_upstream::MemoryFormat::Preserve)
        return rewriter.notifyMatchFailure(
            op, "unimplemented: only none, contiguous and preserve "
                "memory_format is supported");
    }

    // TODO: Add support for device arg other than cpu.
    // A non-None device must be a constant equal to "cpu".
    if (!op.getDevice().getType().isa()) {
      std::string device;
      if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device)))
        return rewriter.notifyMatchFailure(
            op, "unimplemented: device must be a constant str");
      else if (device != "cpu")
        return rewriter.notifyMatchFailure(
            op, "unimplemented: device is expected to be cpu");
    }

    // TODO: Add support for non-strided layout.
    // torch.layout is by default strided i.e. 0.
    if (!op.getLayout().getType().isa()) {
      int64_t tensorLayout;
      if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout)))
        return rewriter.notifyMatchFailure(
            op, "unimplemented: layout must be a constant");
      else if (tensorLayout != torch_upstream::Layout::Strided)
        return rewriter.notifyMatchFailure(
            op, "unimplemented: layout is expected to be strided");
    }

    Location loc = op.getLoc();
    const TypeConverter *typeConverter = this->getTypeConverter();
    SmallVector resultSizeTorchInt, resultSize, resultSizeIndex;
    if (!getListConstructElements(op.getSize(), resultSizeTorchInt)) {
      return rewriter.notifyMatchFailure(
          op, "unimplemented: size must be constructed using ListConstruct");
    }
    resultSize = getTypeConvertedValues(rewriter, loc, typeConverter,
                                        resultSizeTorchInt);
    for (auto size : resultSize)
      resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size));

    auto resultType =
        typeConverter->convertType(op.getType()).cast();
    // dtype None: use the default dtype for a Torch float scalar (unlike the
    // constant-alloc pattern above, which uses the converted result type).
    Type resultElementType;
    if (op.getDtype().getType().isa()) {
      resultElementType = getDefaultDtypeForTorchScalar(
          Torch::FloatType::get(op->getContext()));
    } else {
      int64_t dtypeInt;
      if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt)))
        return rewriter.notifyMatchFailure(
            op, "unimplemented: dtype must be a constant integer or none");
      FailureOr maybeResultElementType =
          torch_to_linalg::getBackendTypeForScalarType(
              op->getContext(), (torch_upstream::ScalarType)dtypeInt);
      if (failed(maybeResultElementType)) {
        return rewriter.notifyMatchFailure(
            op, "unable to convert `dtypeInt` to builtin type");
      }
      resultElementType =
          *maybeResultElementType;
    }

    // Create an uninitialized tensor of `resultSize` shape.
    Value initTensor = rewriter.create(
        loc, getAsOpFoldResult(resultSizeIndex), resultElementType);
    rewriter.replaceOpWithNewOp(op, resultType, initTensor);
    return success();
  }
};
} // namespace

namespace {
// Let's say the result of the `aten.arange.start_step` is `output` which is a
// 1-d output tensor. The approach used for generating the output tensor is as
// follows:
//    for i in range(ceil((end-start)/step))
//          output[i] = start + (i * step)
class ConvertAtenArangeStartStepOp
    : public OpConversionPattern {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenArangeStartStepOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();

    // TODO: Add support for pin_memory features.
    // At this point all tensors should have value semantics, and hence the
    // `layout` check can be ignored.

    // The pin_memory should be either `False` or `none`.
    bool pinMemory;
    if (!op.getPinMemory().getType().isa() &&
        (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) ||
         pinMemory)) {
      return rewriter.notifyMatchFailure(
          op, "unimplemented: pin_memory must be either None or false");
    }

    Location loc = op.getLoc();
    const TypeConverter *typeConverter = this->getTypeConverter();
    RankedTensorType resultType =
        typeConverter->convertType(op->getResult(0).getType())
            .cast();
    // start/end/step are all converted to the result element type up front.
    Type dtype = resultType.getElementType();
    Value start =
        convertScalarToDtype(rewriter, loc, adaptor.getStart(), dtype);
    Value end = convertScalarToDtype(rewriter, loc, adaptor.getEnd(), dtype);
    Value step = convertScalarToDtype(rewriter, loc, adaptor.getStep(), dtype);

    // The result will always be a 1-d tensor.
    // The size of the result is calculated as follows:
    //          ceil((end - start)/step)
    // NOTE(review): the first branch presumably handles integer dtypes. There
    // the division must round up (e.g. arange(0, 5, 2) has 3 elements, while
    // a truncating 5 / 2 gives 2). The op name is not visible in this copy;
    // confirm it is a ceiling signed division rather than a truncating one.
    // The other branch divides, applies ceil, and converts the result to i64.
    Value resultShape;
    if (dtype.isa()) {
      Value subOut = rewriter.create(loc, end, start);
      resultShape = rewriter.create(loc, subOut, step);
    } else {
      Value subOut = rewriter.create(loc, end, start);
      Value divOut = rewriter.create(loc, subOut, step);
      Value ceilOut = rewriter.create(loc, divOut);
      resultShape =
          rewriter.create(loc, rewriter.getI64Type(), ceilOut);
    }
    resultShape = castIntToIndex(rewriter, loc, resultShape);

    Value resultTensor = rewriter.create(
        loc, getAsOpFoldResult(resultShape), dtype);

    // A 1-d parallel generic op with an identity map writes every element.
    auto iteratorType = utils::IteratorType::parallel;
    AffineMap indexingMap =
        AffineMap::getMultiDimIdentityMap(1, op->getContext());

    Value finalRes =
        rewriter
            .create(
                loc, /*resultTensorTypes=*/resultTensor.getType(),
                /*inputs=*/ValueRange({}),
                /*outputs=*/resultTensor,
                /*indexingMaps=*/indexingMap,
                /*iteratorTypes=*/iteratorType,
                [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
                  // Presumably the loop index along dimension 0 (i), cast
                  // to i64 and then to the result dtype.
                  Value index = b.create(loc, 0);
                  index = castIndexToInt64(b, loc, index);
                  index = convertScalarToDtype(b, loc, index, dtype);
                  // output[i] = start + step * i
                  Value mulOut, result;
                  if (dtype.isa()) {
                    mulOut = b.create(loc, step, index);
                    result = b.create(loc, start, mulOut);
                  } else {
                    mulOut = b.create(loc, step, index);
                    result = b.create(loc, start, mulOut);
                  }
                  b.create(loc, result);
                })
            .getResult(0);

    rewriter.replaceOpWithNewOp(op, resultType, finalRes);
    return success();
  }
};
} // namespace

// Registers the tensor-constructor lowering patterns defined in this file and
// marks their source ops illegal for the conversion target. The op and
// pattern names inside `<...>` are not visible in this copy.
void mlir::torch::torch_to_linalg::
    populateTensorConstructorsPatternsAndLegality(TypeConverter &typeConverter,
                                                  RewritePatternSet &patterns,
                                                  ConversionTarget &target) {
  MLIRContext *context = patterns.getContext();
  target.addIllegalOp();
  patterns.add(typeConverter, context);
  target.addIllegalOp();
  patterns.add(typeConverter, context);
  target.addIllegalOp();
  patterns.add>(typeConverter,
                                                             context);
  patterns.add>(typeConverter,
                                                            context);
  target.addIllegalOp();
  patterns.add(typeConverter, context);
  patterns.add(typeConverter, context);
  target.addIllegalOp();
}