//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::TorchConversion;
using namespace mlir::torch::TMTensor;

// -----------------------------------------------------------------------------
// Patterns (as this grows, it should be organized into multiple files)
// -----------------------------------------------------------------------------
// This is going to eventually be O(#aten ops), which is in the 100s.
//
// Most of these patterns consist of:
// 1. Checking that the operand/result types and other static properties are
//    good-enough to create a valid linalg op (such as operands being of
//    ranks/dtypes acceptable to the linalg op).
// 2. Creating dynamic error guards, usually checking a predicate on the
//    compatibility of operand shapes.
// 3. Creating init tensors for the computation op. Usually this involves
//    reifying IR for a shape transfer function based on the operand shapes.
// 4. Creating a named linalg op to replace the original op.
//
// TODO: Use linalg OpDSL to autogenerate at least 1)/2)/3) such
// that these patterns become mostly mechanical associations of
// "aten.foo -> linalg.foo".
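
// Most patterns in this file bottom out in either a `tm_tensor.scatter` or a
// `tm_tensor.scan`, whose region specifies how an update element is combined
// with the element it lands on. A rough sketch of the kind of IR produced
// (SSA names, shapes, and element types below are illustrative, not taken
// from an actual test):
//
//   %result = tm_tensor.scatter unique_indices(false)
//               ins(%updates, %indices : tensor<3xi64>, tensor<3x1xi32>)
//               outs(%original : tensor<8xi64>) {
//             ^bb0(%update: i64, %current: i64):
//               %sum = arith.addi %current, %update : i64
//               tm_tensor.yield %sum : i64
//             } -> tensor<8xi64>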

static Value createTMTensorScatterOp(
    OpBuilder &b, Location loc, Value updates, Value indices, Value original,
    bool uniqueIndices,
    function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuild) {
  auto originalTensorType = original.getType().cast<RankedTensorType>();
  Type originalElementType = originalTensorType.getElementType();
  auto scatterOp = b.create<TMTensor::ScatterOp>(
      loc, originalTensorType, ValueRange{updates, indices},
      ValueRange{original}, uniqueIndices);

  Region &scatterOpRegion = scatterOp.getRegion();
  auto &scatterOpBlock = scatterOpRegion.emplaceBlock();
  scatterOpBlock.addArguments({originalElementType, originalElementType},
                              {loc, loc});
  OpBuilder regionBuilder(scatterOpRegion);
  auto blockArgs = scatterOpBlock.getArguments();
  Value updatesElement = blockArgs[0];
  Value originalElement = blockArgs[1];
  bodyBuild(regionBuilder, loc, updatesElement, originalElement);
  return scatterOp->getResult(0);
}

static Value createTMTensorScanOp(
    OpBuilder &b, Location loc, Value input, Value output, Value accumulator,
    int64_t dim, bool inclusive,
    function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuild) {
  auto inputType = input.getType().cast<RankedTensorType>();
  auto accType = accumulator.getType().cast<RankedTensorType>();
  Type elementType = inputType.getElementType();
  auto scanOp = b.create<TMTensor::ScanOp>(
      loc, TypeRange{inputType, accType}, input,
      ValueRange{output, accumulator}, b.getI64IntegerAttr(dim),
      b.getBoolAttr(inclusive));

  Region &scanOpRegion = scanOp.getRegion();
  auto &scanOpBlock = scanOpRegion.emplaceBlock();
  scanOpBlock.addArguments({elementType, elementType}, {loc, loc});
  OpBuilder regionBuilder(scanOpRegion);
  auto blockArgs = scanOpBlock.getArguments();
  Value inputElement = blockArgs[0];
  Value accElement = blockArgs[1];
  bodyBuild(regionBuilder, loc, inputElement, accElement);
  return scanOp->getResult(0);
}

namespace {
// aten::bincount op counts the frequency of each value in a 1-d input tensor
// of non-negative ints.
class ConvertAtenBincountOp : public OpConversionPattern<AtenBincountOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenBincountOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    Location loc = op.getLoc();
    MLIRContext *context = op->getContext();
    TypeConverter *typeConverter = getTypeConverter();
    Value input = adaptor.getSelf();
    Value torchTypeInput = op.getSelf();
    Value minlength = adaptor.getMinlength();
    Value weights = adaptor.getWeights();

    // TODO: Add a check to verify that the input tensor elements are all
    // non-negative.
    // Check whether the input is a 1-d tensor of integer type or not.
    RankedTensorType inputType = input.getType().cast<RankedTensorType>();
    if (inputType.getRank() != 1 ||
        !inputType.getElementType().isa<mlir::IntegerType>())
      return rewriter.notifyMatchFailure(
          op,
          "Input tensor has to be a one-dimensional tensor of integer type.");

    // Check whether the input tensor element type is i64 or not.
    IntegerType inputIntegerType =
        inputType.getElementType().cast<IntegerType>();
    if (inputIntegerType.getWidth() != 64)
      return rewriter.notifyMatchFailure(
          op, "Unimplemented: integer widths other than 64 are not supported.");

    // TODO: Incorporate the weight argument.
    if (!weights.getType().isa<Torch::NoneType>())
      return rewriter.notifyMatchFailure(
          op, "Unimplemented: the weights operand is not incorporated.");
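
    // As a worked example (values are illustrative): for input = [1, 3, 1]
    // and minlength = 0, max(input) = 3, so the result has size
    // max(3 + 1, 0) = 4, and scattering +1 for every occurrence of each
    // value yields [0, 2, 0, 1].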

    // Finding the maximum value in the input tensor.
    SmallVector<int64_t> maxTensorSizes;
    ValueTensorType maxTensorType = ValueTensorType::get(
        context, llvm::ArrayRef(maxTensorSizes),
        torchTypeInput.getType().cast<ValueTensorType>().getDtype());
    Value maxTensor =
        rewriter.create<AtenMaxOp>(loc, maxTensorType, torchTypeInput);
    maxTensor = typeConverter->materializeTargetConversion(
        rewriter, loc, typeConverter->convertType(maxTensor.getType()),
        maxTensor);

    // `maxTensor` is a 0-d tensor; extract its only element and store it in
    // `maxInput`.
    Value maxInput = rewriter.create<tensor::ExtractOp>(loc, maxTensor);

    // Creating a tm_tensor.scatter op with the following mapping:
    // 1.) `input` tensor maps to the indices in scatter op. `input` is
    // expanded from 1-d to 2-d, and its element type is set to i32 as
    // required for the scatter op.
    // 2.) `updates` is a 1-d dummy tensor with a size equivalent to that of
    // `input`.
    // 3.) `bincount`, a 1-d tensor, maps to the original in scatter op, with
    // a size equal to max(max(input) + 1, minlength).
    SmallVector<int64_t> expandedInputSizes{
        makeShapeTorchCompatible(inputType.getShape())[0], 1};
    ValueTensorType expandInputType = ValueTensorType::get(
        context, llvm::ArrayRef(expandedInputSizes),
        torchTypeInput.getType().cast<ValueTensorType>().getDtype());
    Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(1));
    Value expandedInputTensor = rewriter.create<AtenUnsqueezeOp>(
        loc, expandInputType, torchTypeInput, torchCstOne);

    // Converting the input element type to i32.
    Value indices = convertTensorToDtype(
        rewriter, loc, expandedInputTensor,
        mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed));
    indices = typeConverter->materializeTargetConversion(
        rewriter, loc, typeConverter->convertType(indices.getType()), indices);

    auto resultType = typeConverter->convertType(op->getResult(0).getType())
                          .cast<RankedTensorType>();
    Type resultElemType = resultType.getElementType();

    SmallVector<Value> inputSizeDynamic =
        getTensorSizesUntilDim(rewriter, loc, input, 0);
    Value updatesTensor = rewriter.create<tensor::EmptyOp>(
        loc, getAsOpFoldResult(inputSizeDynamic), resultElemType);

    Value constantZero = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(resultElemType));
    Value constantOne = rewriter.create<arith::ConstantIntOp>(
        loc, 1, resultElemType.getIntOrFloatBitWidth());

    // Bincount size = max(max(input) + 1, minlength)
    Value maxInputPlusOne =
        rewriter.create<arith::AddIOp>(loc, maxInput, constantOne);
    Value bincountSize =
        rewriter.create<arith::MaxSIOp>(loc, maxInputPlusOne, minlength);
    bincountSize = castIntToIndex(rewriter, loc, bincountSize);
    Value bincountTensor = createInitTensor(rewriter, loc, {bincountSize},
                                            resultElemType, constantZero);

    Value scatterOp = createTMTensorScatterOp(
        rewriter, loc, updatesTensor, indices, bincountTensor,
        /*uniqueIndices=*/false,
        [&](OpBuilder &b, Location loc, Value _, Value bincountElem) {
          Value add = b.create<arith::AddIOp>(loc, bincountElem, constantOne);
          b.create<TMTensor::YieldOp>(loc, add);
        });
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, scatterOp);
    return success();
  }
};
} // namespace

namespace {
class ConvertAten_IndexPutImplOp
    : public OpConversionPattern<Aten_IndexPutImplOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(Aten_IndexPutImplOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    Location loc = op.getLoc();
    MLIRContext *context = op->getContext();
    Value input = adaptor.getSelf();
    Value values = adaptor.getValues();
    RankedTensorType inputType = input.getType().cast<RankedTensorType>();
    RankedTensorType valuesType = values.getType().cast<RankedTensorType>();
    auto resultType = typeConverter->convertType(op->getResult(0).getType())
                          .cast<RankedTensorType>();
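
    // For reference, `aten::_index_put_impl` with a single 1-d index tensor
    // behaves like the following pseudocode (illustrative only):
    //
    //   for i in range(len(index)):
    //     if accumulate: input[index[i]] += values[i]
    //     else:          input[index[i]] = values[i]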

    // The unsafe flag should be either `False` or `none`.
    if (!op.getUnsafe().getType().isa<Torch::NoneType>()) {
      bool unsafe;
      if (!matchPattern(op.getUnsafe(), m_TorchConstantBool(&unsafe)))
        return rewriter.notifyMatchFailure(
            op, "unimplemented: unsafe must be a constant");
      else if (unsafe)
        return rewriter.notifyMatchFailure(
            op, "unimplemented: unsafe is expected to be false");
    }

    // The accumulate flag should be a torch constant of boolean type.
    bool accumulate;
    if (!matchPattern(op.getAccumulate(), m_TorchConstantBool(&accumulate)))
      return rewriter.notifyMatchFailure(
          op, "Expected accumulate to be constant bool.");

    // The element type of the `input` and `values` should be the same.
    if (inputType.getElementType() != valuesType.getElementType())
      return rewriter.notifyMatchFailure(
          op, "Input element type should be same as the values element type.");

    SmallVector<Value> indicesList;
    getListConstructElements(adaptor.getIndices(), indicesList);

    // The size of the list of the index tensors should not be greater than
    // the input rank.
    if ((int64_t)indicesList.size() > inputType.getRank())
      return rewriter.notifyMatchFailure(
          op, "Indices list size should not be greater than the input rank.");

    // TODO: Add support for cases with indices list size not equal to 1.
    if (indicesList.size() != 1)
      return rewriter.notifyMatchFailure(
          op, "Unimplemented: Indices list size != 1");
    Value indexTensor = indicesList[0];

    if (indexTensor.getType().isa<Torch::NoneType>())
      return rewriter.notifyMatchFailure(op, "Index tensor must not be None.");

    // Creating a tm_tensor.scatter op with the following mapping:
    // 1.) The index tensor from `indicesList` maps to the indices in scatter
    // op. The index tensor is expanded from 1-d to 2-d, and its element type
    // is set to i32 as required for the scatter op.
    // 2.) `values` is mapped to `updates` in scatter op.
    // 3.) `input` is mapped to `original` in scatter op.
    std::optional<unsigned> indexTensorRank = getTensorRank(indexTensor);
    if (!indexTensorRank || *indexTensorRank != 1)
      return rewriter.notifyMatchFailure(
          op, "unimplemented: index tensor with rank != 1 is not supported");
    auto indexTensorType = indexTensor.getType().cast<BaseTensorType>();
    int64_t indexTensorSize = indexTensorType.getSizes()[0];
    SmallVector<int64_t> expandedIndexTensorSizes{indexTensorSize, 1};
    ValueTensorType expandedIndexTensorType = ValueTensorType::get(
        context, llvm::ArrayRef(expandedIndexTensorSizes),
        indexTensorType.getDtype());
    Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(1));
    Value expandedIndexTensor = rewriter.create<AtenUnsqueezeOp>(
        loc, expandedIndexTensorType, indexTensor, torchCstOne);
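
    // The unsqueeze above turns an index tensor of shape [N] into [N, 1]:
    // the scatter op consumes indices as a 2-d (batch, index_depth) tensor
    // in which each row addresses one coordinate of `original`. E.g.
    // (illustrative values) indices [4, 7] become [[4], [7]].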

    // `TMTensor::ScatterOp` expects indices of element type i32.
    Value indices = convertTensorToDtype(
        rewriter, loc, expandedIndexTensor,
        mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed));
    indices = typeConverter->materializeTargetConversion(
        rewriter, loc, typeConverter->convertType(indices.getType()), indices);

    bool invalidInputTypeFound = false;
    Value scatterOp = createTMTensorScatterOp(
        rewriter, loc, values, indices, input,
        /*uniqueIndices=*/false,
        [&](OpBuilder &b, Location loc, Value valuesElement,
            Value inputElement) {
          Value yieldValue = valuesElement;
          if (accumulate) {
            if (inputElement.getType().isa<mlir::IntegerType>()) {
              yieldValue =
                  b.create<arith::AddIOp>(loc, inputElement, valuesElement);
            } else if (inputElement.getType().isa<mlir::FloatType>()) {
              yieldValue =
                  b.create<arith::AddFOp>(loc, inputElement, valuesElement);
            } else {
              invalidInputTypeFound = true;
              return;
            }
          }
          b.create<TMTensor::YieldOp>(loc, yieldValue);
        });

    if (invalidInputTypeFound) {
      return rewriter.notifyMatchFailure(
          op,
          "unimplemented: input tensor must be of integer type or float type");
    }

    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, scatterOp);
    return success();
  }
};
} // namespace

namespace {
// The original implementation of the op is as follows:
//
// Indices and GradOutput Layout: [N, C, H, W] or [C, H, W]
// Input Layout: [N, C, Hin, Win] or [C, Hin, Win]
//
// for i in range(N):
//   for j in range(C):
//     for k in range(H):
//       for l in range(W):
//         index = indices[i, j, k, l]
//         result[i, j, index/Win, index%Win] += gradOutput[i, j, k, l]
//
// OR
//
// for i in range(C):
//   for j in range(H):
//     for k in range(W):
//       index = indices[i, j, k]
//       result[i, index/Win, index%Win] += gradOutput[i, j, k]
//
class ConvertAtenMaxPool2dWithIndicesBackwardOp
    : public OpConversionPattern<AtenMaxPool2dWithIndicesBackwardOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenMaxPool2dWithIndicesBackwardOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    Location loc = op.getLoc();
    MLIRContext *context = op->getContext();
    Value gradOutput = adaptor.getGradOutput();
    Value input = adaptor.getSelf();
    RankedTensorType gradOutputType =
        gradOutput.getType().cast<RankedTensorType>();
    Type gradOutputElemType = gradOutputType.getElementType();
    RankedTensorType inputType = input.getType().cast<RankedTensorType>();
    Type inputElemType = inputType.getElementType();
    int64_t tensorOperandRank = inputType.getRank();

    // `TMTensor::ScatterOp` expects indices of element type i32.
    Value indices = convertTensorToDtype(
        rewriter, loc, op.getIndices(),
        mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed));
    indices = typeConverter->materializeTargetConversion(
        rewriter, loc, typeConverter->convertType(indices.getType()), indices);
    RankedTensorType indicesType = indices.getType().cast<RankedTensorType>();
    Type indicesElemType = indicesType.getElementType();

    // The element type of the `input` and `grad_output` should be the same.
    if (inputElemType != gradOutputElemType)
      return rewriter.notifyMatchFailure(
          op,
          "Input element type should be same as the grad_output element type.");
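
    // As a concrete (illustrative) example of the index arithmetic used
    // below: with Win = 4, a flat pooling index of 6 decodes to
    // (h, w) = (6 / 4, 6 % 4) = (1, 2).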

    // Since the scatter op requires indices to be a 2-d tensor, we create a
    // new 5-d/4-d tensor (depending on the original indices layout)
    // comprising the index values, which we will then collapse into a 2-d
    // tensor. The algorithm for the creation of the updated indices tensor
    // is as follows:
    //
    // for i in range(N):
    //   for j in range(C):
    //     for k in range(H):
    //       for l in range(W):
    //         for m in range(4):
    //           if m == 0:
    //             updatedIndices[i][j][k][l][0] = i
    //           if m == 1:
    //             updatedIndices[i][j][k][l][1] = j
    //           if m == 2:
    //             updatedIndices[i][j][k][l][2] =
    //                 originalIndices[i, j, k, l] / Win
    //           if m == 3:
    //             updatedIndices[i][j][k][l][3] =
    //                 originalIndices[i, j, k, l] % Win
    //
    // OR
    //
    // for i in range(C):
    //   for j in range(H):
    //     for k in range(W):
    //       for m in range(3):
    //         if m == 0:
    //           updatedIndices[i][j][k][0] = i
    //         if m == 1:
    //           updatedIndices[i][j][k][1] = originalIndices[i, j, k] / Win
    //         if m == 2:
    //           updatedIndices[i][j][k][2] = originalIndices[i, j, k] % Win

    SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);

    SmallVector<AffineExpr> originalIndicesDimExprs, updatedIndicesDimExprs;
    for (int64_t i = 0; i < tensorOperandRank; i++) {
      originalIndicesDimExprs.push_back(rewriter.getAffineDimExpr(i));
      updatedIndicesDimExprs.push_back(rewriter.getAffineDimExpr(i));
    }
    updatedIndicesDimExprs.push_back(
        rewriter.getAffineDimExpr(tensorOperandRank));

    SmallVector<AffineMap> indexingMaps = AffineMap::inferFromExprList(
        {originalIndicesDimExprs, updatedIndicesDimExprs});
    SmallVector<utils::IteratorType> iteratorTypes(
        tensorOperandRank + 1, utils::IteratorType::parallel);

    SmallVector<OpFoldResult> updatedIndicesShape =
        getAsOpFoldResult(getTensorSizes(rewriter, loc, indices));
    updatedIndicesShape.push_back(rewriter.getIndexAttr(tensorOperandRank));

    Value initTensor = rewriter.create<tensor::EmptyOp>(
        loc, updatedIndicesShape, indicesElemType);

    Value wIn = inputShape[tensorOperandRank - 1];
    SmallVector<Value> cstValues;
    for (int64_t i = 0; i < tensorOperandRank; i++)
      cstValues.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    Value updatedIndices =
        rewriter
            .create<linalg::GenericOp>(
                loc, initTensor.getType(), indices, initTensor, indexingMaps,
                iteratorTypes,
                [tensorOperandRank, wIn, cstValues,
                 indicesElemType](OpBuilder &b, Location loc, ValueRange args) {
                  Value index = castIntToIndex(b, loc, args[0]);
                  Value updatedIndex = cstValues[0];
                  Value lastDim =
                      b.create<linalg::IndexOp>(loc, tensorOperandRank);

                  for (int64_t i = tensorOperandRank - 1; i >= 0; i--) {
                    Value result;
                    if (i == tensorOperandRank - 1)
                      result = b.create<arith::RemSIOp>(loc, index, wIn);
                    if (i == tensorOperandRank - 2)
                      result = b.create<arith::DivSIOp>(loc, index, wIn);
                    if (i == tensorOperandRank - 3 ||
                        i == tensorOperandRank - 4)
                      result = b.create<linalg::IndexOp>(loc, i);

                    Value pred = b.create<arith::CmpIOp>(
                        loc, arith::CmpIPredicate::eq, lastDim, cstValues[i]);
                    Value addAmount = b.create<arith::SelectOp>(
                        loc, pred, result, cstValues[0]);
                    updatedIndex =
                        b.create<arith::AddIOp>(loc, updatedIndex, addAmount);
                  }

                  updatedIndex = b.create<arith::IndexCastOp>(
                      loc, indicesElemType, updatedIndex);
                  b.create<linalg::YieldOp>(loc, updatedIndex);
                })
            .getResult(0);

    // Creating a new tensor initialized with zeros whose size is the same as
    // that of the input tensor.
    Value outputTensor =
        createZeroInitTensor(rewriter, loc, inputShape, inputElemType);

    // Collapsing `gradOutput` into a 1-d tensor.
    SmallVector<ReassociationIndices> reassociationCollapse(1);
    for (auto i = 0; i < gradOutputType.getRank(); i++)
      reassociationCollapse[0].push_back(i);
    RankedTensorType gradOutputFlattenedType;
    int64_t numelGradOutput = getNumberOfElements(gradOutputType);
    gradOutputFlattenedType = RankedTensorType::get(
        makeShapeLLVMCompatible({numelGradOutput}), gradOutputElemType);
    Value gradOutputFlattened = rewriter.create<tensor::CollapseShapeOp>(
        loc, gradOutputFlattenedType, gradOutput, reassociationCollapse);
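
    // E.g. (illustrative shapes): a [2, 3, 4, 5] gradOutput collapses to
    // [120] here, and the [2, 3, 4, 5, 4] updated-indices tensor collapses
    // to [120, 4] below, pairing each flattened gradient element with its
    // 4-d target coordinate.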

    // Collapsing updated indices into a 2-d tensor.
    SmallVector<ReassociationIndices> reassociationCollapseIndices(2);
    for (auto i = 0; i < tensorOperandRank; i++)
      reassociationCollapseIndices[0].push_back(i);
    reassociationCollapseIndices[1].push_back(tensorOperandRank);
    int64_t numelIndices = getNumberOfElements(indicesType);
    Value indicesCollapsed = rewriter.create<tensor::CollapseShapeOp>(
        loc,
        RankedTensorType::get(
            makeShapeLLVMCompatible({numelIndices, tensorOperandRank}),
            indicesElemType),
        updatedIndices, reassociationCollapseIndices);

    bool invalidInputTypeFound = false;
    Value scatterOp = createTMTensorScatterOp(
        rewriter, loc, /*updates=*/gradOutputFlattened,
        /*indices=*/indicesCollapsed, /*original=*/outputTensor,
        /*uniqueIndices=*/false,
        [&](OpBuilder &b, Location loc, Value valuesElement,
            Value inputElement) {
          Value yieldValue = valuesElement;
          if (inputElement.getType().isa<mlir::IntegerType>()) {
            yieldValue =
                b.create<arith::AddIOp>(loc, inputElement, valuesElement);
          } else if (inputElement.getType().isa<mlir::FloatType>()) {
            yieldValue =
                b.create<arith::AddFOp>(loc, inputElement, valuesElement);
          } else {
            invalidInputTypeFound = true;
            return;
          }
          b.create<TMTensor::YieldOp>(loc, yieldValue);
        });

    if (invalidInputTypeFound) {
      return rewriter.notifyMatchFailure(
          op,
          "unimplemented: input tensor must be of integer type or float type");
    }

    Type newResultType = getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, scatterOp);
    return success();
  }
};
} // namespace

namespace {
class ConvertAtenCumsumOp : public OpConversionPattern<AtenCumsumOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenCumsumOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value input = adaptor.getSelf();
    auto resultType = input.getType().cast<RankedTensorType>();
    Type elementType = resultType.getElementType();
    int64_t inputRank = resultType.getRank();
    Location loc = op->getLoc();
    Value dtype = op.getDtype();
    if (!dtype.getType().isa<Torch::NoneType>())
      return rewriter.notifyMatchFailure(
          op, "unsupported: dtype argument not supported");

    int64_t dim;
    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
      return rewriter.notifyMatchFailure(
          op, "unimplemented: only constant dim value is supported");
    dim = toPositiveDim(dim, inputRank);
    if (!isValidDim(dim, inputRank))
      return rewriter.notifyMatchFailure(op, "invalid dim");

    SmallVector<Value> sizes = getTensorSizes(rewriter, loc, input);
    Value output = createZeroInitTensor(rewriter, loc, sizes, elementType);
    output = rewriter.create<tensor::CastOp>(loc, resultType, output);

    SmallVector<Value> accSizes(sizes);
    accSizes.erase(accSizes.begin() + dim);
    SmallVector<int64_t> accStatic(
        makeShapeTorchCompatible(resultType.getShape()));
    accStatic.erase(accStatic.begin() + dim);
    Value acc = createZeroInitTensor(rewriter, loc, accSizes, elementType);
    Type accType =
        RankedTensorType::get(makeShapeLLVMCompatible(accStatic), elementType);
    acc = rewriter.create<tensor::CastOp>(loc, accType, acc);
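
    // As a worked (illustrative) example: an inclusive scan of input
    // [1, 2, 3] along dim 0 writes output [1, 3, 6], with the 0-d
    // accumulator carrying the running total (6 on exit).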

    Value result = createTMTensorScanOp(
        rewriter, loc, input, output, acc, dim, /*inclusive=*/true,
        [](OpBuilder &b, Location loc, Value input, Value acc) {
          Value sum =
              (input.getType().isa<mlir::FloatType>()
                   ? b.create<arith::AddFOp>(loc, input, acc)
                   : b.create<arith::AddIOp>(loc, input, acc))
                  ->getResult(0);
          b.create<TMTensor::YieldOp>(loc, sum);
        });

    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
    return success();
  }
};
} // namespace

// -----------------------------------------------------------------------------
// The pass
// -----------------------------------------------------------------------------

namespace {
class ConvertTorchToTMTensor
    : public ConvertTorchToTMTensorBase<ConvertTorchToTMTensor> {
public:
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<linalg::LinalgDialect>();
    registry.insert<func::FuncDialect>();
    registry.insert<arith::ArithDialect>();
    registry.insert<tensor::TensorDialect>();
    registry.insert<TMTensor::TMTensorDialect>();
    TorchConversion::getBackendTypeConversionDependentDialects(registry);
  }

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ConversionTarget target(*context);
    target.addLegalDialect<linalg::LinalgDialect, func::FuncDialect,
                           arith::ArithDialect, tensor::TensorDialect,
                           TMTensor::TMTensorDialect>();

    TypeConverter typeConverter;
    typeConverter.addConversion([](Type type) { return type; });
    TorchConversion::setupBackendTypeConversion(target, typeConverter);

    RewritePatternSet patterns(context);
    target.addIllegalOp<AtenBincountOp>();
    patterns.add<ConvertAtenBincountOp>(typeConverter, context);
    target.addIllegalOp<Aten_IndexPutImplOp>();
    patterns.add<ConvertAten_IndexPutImplOp>(typeConverter, context);
    target.addIllegalOp<AtenMaxPool2dWithIndicesBackwardOp>();
    patterns.add<ConvertAtenMaxPool2dWithIndicesBackwardOp>(typeConverter,
                                                            context);
    target.addIllegalOp<AtenCumsumOp>();
    patterns.add<ConvertAtenCumsumOp>(typeConverter, context);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      return signalPassFailure();
  }
};
} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToTMTensorPass() {
  return std::make_unique<ConvertTorchToTMTensor>();
}
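
// To exercise this pass in isolation, something like the following should
// work (the flag name is assumed from the usual pass-registration
// convention, not verified here):
//
//   torch-mlir-opt --convert-torch-to-tmtensor input.mlir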