//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" #include "../PassDetail.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/IR/MLIRContext.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; using namespace mlir::torch::TorchConversion; using namespace mlir::torch::TMTensor; // ----------------------------------------------------------------------------- // Patterns (as this grows, it should be organized into multiple files) // ----------------------------------------------------------------------------- // This is going to eventually be O(#aten ops), which is in the 100s. // // Most of these patterns consist of: // 1. Checking that the operand/result types and other static properties are // good-enough to create a valid linalg op (such as operands being of // ranks/dtypes acceptable to the linalg op). // 2. Creating dynamic error guards, usually checking a predicate on the // compatibility of operand shapes. // 3. Creating init tensors for the computation op. Usually this involves // reifying IR for a shape transfer function based on the operand shapes. // 4. Creating a named linalg op to replace the original op. // // TODO: Use linalg OpDSL to autogenerate at least 1)/2)/3) such // that these patterns become mostly mechanical associations of // "aten.foo -> linalg.foo". namespace { // aten::bincount op counts the frequency of each value in a 1-d input tensor of // non-negative ints. class ConvertAtenBincountOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenBincountOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op.getLoc(); MLIRContext *context = op->getContext(); TypeConverter *typeConverter = getTypeConverter(); Value input = adaptor.self(); Value torchTypeInput = op.self(); Value minlength = adaptor.minlength(); Value weights = adaptor.weights(); // TODO: Add a check to verify that the input tensor elements are all // non-negative. // Check whether the input is a 1-d tensor of integer type or not. RankedTensorType inputType = input.getType().cast(); if (inputType.getRank() != 1 || !inputType.getElementType().isa()) return rewriter.notifyMatchFailure( op, "Input tensor has to be a one-dimensional tensor of integer type."); // Check whether the input tensor element type is i64 or not. IntegerType inputIntegerType = inputType.getElementType().cast(); if (inputIntegerType.getWidth() != 64) return rewriter.notifyMatchFailure( op, "Unimplemented: Integer width not equal to 64 are not supported."); // TODO: Incorporate the weight argument. if (!weights.getType().isa()) return rewriter.notifyMatchFailure( op, "Unimplemented, the weights operand is not incorporated."); // Finding the maximum value in the input tensor. SmallVector maxTensorSizes; ValueTensorType maxTensorType = ValueTensorType::get( context, llvm::makeArrayRef(maxTensorSizes), torchTypeInput.getType().cast().getDtype()); Value maxTensor = rewriter.create(loc, maxTensorType, torchTypeInput); maxTensor = typeConverter->materializeTargetConversion( rewriter, loc, typeConverter->convertType(maxTensor.getType()), maxTensor); // `maxTensor` is a 0-d tensor, extracting its only element and // storing it in `maxInput`. Value maxInput = rewriter.create(loc, maxTensor); // Creating a tm_tensor.scatter op with the following mapping: // 1.) `input` tensor maps to the indices in scatter op. `input` is // expanded from 1-d to 2-d, and its element type is set to i32 as required // for the scatter op. // 2.) `updates` is a 1-d dummy tensor with the size equivalent to the // `input`. // 3.) `bincount` a 1-d tensor maps to the original in scatter op // with size equal to the max(max(input) + 1, minlength). SmallVector expandedInputSizes{inputType.getShape()[0], 1}; ValueTensorType expandInputType = ValueTensorType::get( context, llvm::makeArrayRef(expandedInputSizes), torchTypeInput.getType().cast().getDtype()); Value torchCstOne = rewriter.create( loc, rewriter.getI64IntegerAttr(1)); Value expandedInputTensor = rewriter.create( loc, expandInputType, torchTypeInput, torchCstOne); // Converting the input element type to i32. Value indices = convertTensorToDtype( rewriter, loc, expandedInputTensor, mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed)); indices = typeConverter->materializeTargetConversion( rewriter, loc, typeConverter->convertType(indices.getType()), indices); Type resultElemType = typeConverter->convertType(op->getResult(0).getType()) .cast() .getElementType(); SmallVector inputSizeDynamic = getTensorSizesUntilDim(rewriter, loc, input, 0); Value updatesTensor = rewriter.create( loc, getAsOpFoldResult(inputSizeDynamic), resultElemType); Value constantZero = rewriter.create( loc, rewriter.getZeroAttr(resultElemType)); Value constantOne = rewriter.create( loc, 1, resultElemType.getIntOrFloatBitWidth()); // Bincount size = max(max(input) + 1, minlength) Value maxInputPlusOne = rewriter.create(loc, maxInput, constantOne); Value bincountSize = rewriter.create(loc, maxInputPlusOne, minlength); bincountSize = castIntToIndex(rewriter, loc, bincountSize); Value bincountTensor = createInitTensor(rewriter, loc, {bincountSize}, resultElemType, constantZero); auto scatterOp = rewriter.create( loc, bincountTensor.getType(), ValueRange{updatesTensor, indices}, ValueRange{bincountTensor}, /*unique_indices=*/false); Region &scatterOpRegion = scatterOp.region(); auto &scatterOpBlock = scatterOpRegion.emplaceBlock(); scatterOpBlock.addArguments(TypeRange{resultElemType, resultElemType}, {loc, loc}); auto blockArgs = scatterOpBlock.getArguments(); // Creating an add instruction inside the scatter op region to increment the // frequency counter with one. OpBuilder regionBuilder(scatterOpRegion); Value add = regionBuilder.create(loc, /*bincount=*/blockArgs[1], constantOne); regionBuilder.create(loc, add); rewriter.replaceOp(op, scatterOp->getResult(0)); return success(); } }; } // namespace namespace { class ConvertValsemVariantAtenIndexPutImplOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ValsemVariantAtenIndexPutImplOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op.getLoc(); MLIRContext *context = op->getContext(); Value input = adaptor.self(); Value values = adaptor.values(); RankedTensorType inputType = input.getType().cast(); RankedTensorType valuesType = values.getType().cast(); Type resultElemType = typeConverter->convertType(op->getResult(0).getType()) .cast() .getElementType(); // TODO: Add support for the input with rank other than one. if (inputType.getRank() != 1) return rewriter.notifyMatchFailure( op, "unimplemented: input rank other than one is not supported"); // The unsafe should be either `False` or `none`. if (!op.unsafe().getType().isa()) { bool unsafe; if (!matchPattern(op.unsafe(), m_TorchConstantBool(&unsafe))) return rewriter.notifyMatchFailure( op, "unimplemented: unsafe must be a constant"); else if (unsafe) return rewriter.notifyMatchFailure( op, "unimplemented: unsafe is expected to be false"); } // The accumulate should be a torch constant of boolean type. bool accumulate; if (!matchPattern(op.accumulate(), m_TorchConstantBool(&accumulate))) return rewriter.notifyMatchFailure( op, "Expected accumulate to be constant bool."); // The element type of the `input` and `values` should be same. if (inputType.getElementType() != valuesType.getElementType()) return rewriter.notifyMatchFailure( op, "Input element type should be same as the values element type."); SmallVector indicesList; getListConstructElements(adaptor.indices(), indicesList); // The size of the list of the index tensors should not be greater than the // input rank. if ((int64_t)indicesList.size() > inputType.getRank()) return rewriter.notifyMatchFailure( op, "Indices list size should not be greater than the input rank."); // TODO: Add support for cases with indices list size smaller than the input // rank. if ((int64_t)indicesList.size() < inputType.getRank()) return rewriter.notifyMatchFailure( op, "Unimplemented, Indices list size smaller than input rank"); if (indicesList[0].getType().isa()) return rewriter.notifyMatchFailure(op, "Indices tensor must not be none."); // TODO: Add support for the index with rank other than one. int64_t indexRank = typeConverter->convertType(indicesList[0].getType()) .cast() .getRank(); if (indexRank != 1) return rewriter.notifyMatchFailure( op, "unimplemented: index rank other than one is not supported"); // Creating a tm_tensor.scatter op with the following mapping: // 1.) Index tensor from the `indicesList` maps to the indices in scatter // op. Index tensor is expanded from 1-d to 2-d, and its element type is set // to i32 as required for the scatter op. // 2.) `values` is mapped to `updates` in scatter op. // 3.) `input` is mapped to `original` in scatter op. ValueTensorType indexType = indicesList[0].getType().cast(); SmallVector expandedIndexSizes{indexType.getSizes()[0], 1}; ValueTensorType expandedIndexType = ValueTensorType::get( context, llvm::makeArrayRef(expandedIndexSizes), indexType.getDtype()); Value torchCstOne = rewriter.create( loc, rewriter.getI64IntegerAttr(1)); Value expandedIndexTensor = rewriter.create( loc, expandedIndexType, indicesList[0], torchCstOne); // Converting the index element type to i32. Value indices = convertTensorToDtype( rewriter, loc, expandedIndexTensor, mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed)); indices = typeConverter->materializeTargetConversion( rewriter, loc, typeConverter->convertType(indices.getType()), indices); auto scatterOp = rewriter.create( loc, input.getType(), ValueRange{values, indices}, ValueRange{input}, /*unique_indices=*/false); Region &scatterOpRegion = scatterOp.region(); auto &scatterOpBlock = scatterOpRegion.emplaceBlock(); scatterOpBlock.addArguments(TypeRange{resultElemType, resultElemType}, {loc, loc}); auto blockArgs = scatterOpBlock.getArguments(); OpBuilder regionBuilder(scatterOpRegion); Value update = blockArgs[0]; Value original = blockArgs[1]; Value yieldValue = update; // Create an add instruction inside the scatter op region to increment the // `original` value with the value from `updates` if the accumulate flag is // true. if (accumulate) { if (inputType.getElementType().isa()) yieldValue = regionBuilder.create(loc, original, update); else if (inputType.getElementType().isa()) yieldValue = regionBuilder.create(loc, original, update); } regionBuilder.create(loc, yieldValue); rewriter.replaceOp(op, scatterOp->getResult(0)); return success(); } }; } // namespace // ----------------------------------------------------------------------------- // The pass // ----------------------------------------------------------------------------- namespace { class ConvertTorchToTMTensor : public ConvertTorchToTMTensorBase { public: void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); registry.insert(); registry.insert(); registry.insert(); registry.insert(); TorchConversion::getBackendTypeConversionDependentDialects(registry); } void runOnOperation() override { MLIRContext *context = &getContext(); ConversionTarget target(*context); target.addLegalDialect(); TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); TorchConversion::setupBackendTypeConversion(target, typeConverter); RewritePatternSet patterns(context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure(); } }; } // namespace std::unique_ptr> mlir::torch::createConvertTorchToTMTensorPass() { return std::make_unique(); }