//===----------------------------------------------------------------------===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "ReifyAbstractInterpCalculationsUtils.h" #include "mlir/Parser/Parser.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/ErrorOr.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/SourceMgr.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; std::string mlir::torch::Torch::getLibraryFunctionPrefix(LibraryFunctionKind libFuncKind) { if (libFuncKind == LibraryFunctionKind::ShapeFunction) return "__torch_mlir_shape_fn."; else if (libFuncKind == LibraryFunctionKind::DtypeFunction) return "__torch_mlir_dtype_fn."; else if (libFuncKind == LibraryFunctionKind::HasValueSemantics) return "__torch_mlir_has_value_semantics_fn."; llvm_unreachable( "`getLibraryFunctionPrefix` called with an unsupported `CalculateOp`"); } static Operation *createCalculateOp(OpBuilder &b, Location loc, TypeRange resultTypes, LibraryFunctionKind libFuncKind) { if (libFuncKind == LibraryFunctionKind::ShapeFunction) return b.create(loc, resultTypes); else if (libFuncKind == LibraryFunctionKind::DtypeFunction) return b.create(loc, resultTypes); llvm_unreachable( "`createCalculateOp` called with an unsupported `LibraryFunctionKind`"); } static Operation *createCalculateYieldOp(OpBuilder &b, Location loc, ValueRange results, LibraryFunctionKind libFuncKind) { if (libFuncKind == LibraryFunctionKind::ShapeFunction) return b.create(loc, results); else if (libFuncKind == LibraryFunctionKind::DtypeFunction) return b.create(loc, results); llvm_unreachable("`createCalculateYieldOp` called with an unsupported " "`LibraryFunctionKind`"); } static Operation * createCalculateYieldCalculationOp(OpBuilder &b, Location loc, ValueRange results, LibraryFunctionKind libFuncKind) { if (libFuncKind == LibraryFunctionKind::ShapeFunction) return b.create(loc, results); else if (libFuncKind == LibraryFunctionKind::DtypeFunction) return b.create(loc, results); llvm_unreachable("`createCalculateYieldCalculationOp` called with an " "unsupported `LibraryFunctionKind`"); } LogicalResult Torch::wrapWithCalculateOpIfLibraryFunctionAvailable( Operation *op, ModuleOp library, LibraryFunctionKind libFuncKind, SmallVector &libFuncNamesUsed, function_ref>(OpBuilder &, Location, ValueRange, func::FuncOp)> libFuncArgsBuilder) { Location loc = op->getLoc(); MLIRContext *context = op->getContext(); auto name = op->getName().stripDialect(); // For value-semantic variant ops, i.e. valsem-ops (ops that are // mechanically consistent with existing torch conventions of in-place vs. // out-of-place (value-semantic) variants), remove the prefix when // looking them up in the library. if (name.starts_with("valsem.")) name = name.drop_front(strlen("valsem.")); if (isa(op)) name = cast(cast(op)->getAttr("name")).getValue(); std::string libFuncName = (getLibraryFunctionPrefix(libFuncKind) + Twine(name)).str(); auto libFunc = library.lookupSymbol(libFuncName); if (!libFunc) return success(); libFuncNamesUsed.push_back(libFuncName); OpBuilder b(op); Operation *calculate = createCalculateOp(b, loc, op->getResultTypes(), libFuncKind); op->replaceAllUsesWith(calculate); { // Move the op into the body of the `torch.{libFuncType}.calculate` op // and yield its results. OpBuilder b(context); Block *bodyBlock = b.createBlock(&calculate->getRegion(0)); op->moveBefore(bodyBlock, bodyBlock->end()); b.setInsertionPointAfter(op); createCalculateYieldOp(b, loc, op->getResults(), libFuncKind); } { OpBuilder b(context); b.createBlock(&calculate->getRegion(1)); // Create the call to the library function! FailureOr> libFuncArgs = libFuncArgsBuilder(b, loc, op->getOperands(), libFunc); if (failed(libFuncArgs)) return failure(); auto call = b.create(loc, libFunc, *libFuncArgs); // Python models multiple results with a tuple, so we need to unpack it // if the op has multiple results. SmallVector unpackedResults; assert(call.getNumResults() == 1 && "Multiple results are packed in a tuple in Python!"); Value result = call.getResult(0); if (auto tupleType = dyn_cast(result.getType())) { auto unpack = b.create( loc, tupleType.getContainedTypes(), result); llvm::append_range(unpackedResults, unpack.getResults()); } else { unpackedResults.push_back(result); } // Terminate the region. createCalculateYieldCalculationOp(b, loc, unpackedResults, libFuncKind); } return success(); } void Torch::importLibraryFunctions(ModuleOp module, ModuleOp library, SmallVector functionsNeeded) { // Import just the functions we need. This includes transitive callees, // so we use a worklist algorithm. llvm::StringSet<> importedFunctions; while (!functionsNeeded.empty()) { std::string symName = functionsNeeded.pop_back_val(); if (importedFunctions.contains(symName)) continue; auto func = library.lookupSymbol(symName); assert(func && "broken library"); // Move the function from the library to the module this pass // is running on. (this mutates the library, but we re-parse it each time // so this is safe to do). func->moveBefore(&module.getBody()->front()); // Set the visibility to private so that the functions go away // nicely after we are done with them. func.setVisibility(SymbolTable::Visibility::Private); // Continue the DFS. importedFunctions.insert(symName); func.walk([&](func::CallOp op) { functionsNeeded.push_back(op.getCallee().str()); }); } } FailureOr Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand, Type desiredType, function_ref baseTransformation) { operand = baseTransformation(b, loc, operand, desiredType); // No need for adjustment if they already match. auto operandType = operand.getType(); if (operandType == desiredType) return operand; if (isa(desiredType)) { // Generator's are currently passed as Any because TorchScript cannot // compile a function with Generator type arguments. // Ignoring that hack, this is a correct handling of Any type should we need // to actually support it in the future. return b.create(loc, desiredType, operand).getResult(); } // The type `!torch.number` can be an `int`, `float`, or `complex`. // TODO: Add a new type `Torch::ComplexType` to handle the complex case. if (isa(desiredType) && isa(operandType)) { return b.create(loc, desiredType, operand).getResult(); } // !torch.union is the type used for optional // `Scalar` inputs. At compile time, such inputs will usually be // resolved to an `int`, `float`, or `None` so we need to derefine // to match the library function signature. if (auto unionType = dyn_cast(desiredType)) { if (llvm::all_of(unionType.getContainedTypes(), [](Type containedType) { return isa( containedType); })) return b.create(loc, desiredType, operand).getResult(); } // Operands with type `!torch.none` correspond to library function inputs with // types like `!torch.optional<...>` or `!torch.union<..., none>`, so here the // type is derefined to match the expected type of the library function. if (isa(operandType)) { assert(!isa(desiredType) && "Don't expect library functions to have NoneType parameters"); return b.create(loc, desiredType, operand).getResult(); } // To keep things simple in shape functions, `Scalar` inputs are considered // `float`s. This is safe since output shape of torch ops never depends on the // dtype of input scalars. However, this also means we sometimes have to // manually turn `Scalar`s into `float`s when inserting the shape functions // into the IR. if (isa(operandType) && isa(desiredType)) { return b.create(loc, desiredType, operand).getResult(); } // If the operand type is statically !torch.optional, then we need to do // different things for the None and non-None cases. // For the None case, we just need to derefine it to the desired type. // For the non-None case, we need to unwrap the optional type and then adjust // it recursively (which also takes care of derefining it to ultimate desired // type). // A case where this happens is `!torch.optional` -> // `!torch.optional>>`. if (auto operandOptionalType = dyn_cast(operandType)) { if (isa(desiredType)) { // if optional is None: // return derefine(None) // else: // return adjust(unchecked_cast(optional)) auto none = b.create(loc); auto isNone = b.create(loc, operand, none); auto primIf = b.create(loc, desiredType, isNone); { Region &thenRegion = primIf.getThenRegion(); b.createBlock(&thenRegion, thenRegion.end()); auto derefineNone = b.create(loc, desiredType, none); b.create(loc, ValueRange{derefineNone}); } { Region &elseRegion = primIf.getElseRegion(); b.createBlock(&elseRegion, elseRegion.end()); auto downcasted = b.create( loc, operandOptionalType.getContainedType(), operand); FailureOr adjusted = adjustFunctionArg( b, loc, downcasted, desiredType, baseTransformation); if (failed(adjusted)) return failure(); b.create(loc, *adjusted); } b.setInsertionPointAfter(primIf); return primIf.getResult(0); } } // If the desired type is OptionalType, then recursively adjust the operand to // the contained type, then derefine it to `!torch.optional`. For example, // `!torch.vtensor -> !torch.optional>>`. if (auto desiredOptionalType = dyn_cast(desiredType)) { FailureOr adjusted = adjustFunctionArg( b, loc, operand, desiredOptionalType.getContainedType(), baseTransformation); if (failed(adjusted)) return failure(); return b.create(loc, desiredType, *adjusted).getResult(); } if (auto desiredListType = dyn_cast(desiredType)) { // Pseudocode: // // operand = ... // adjusted_list = [] // for i in range(len(operand)): // adjusted_list.append(adjust(operand[i])) // return adjusted_list auto providedType = cast(operand.getType()); Value adjustedList = b.create(loc, desiredListType, ValueRange({})); // Create a for-like PrimLoopOp. Value maxTripCount = b.create(loc, operand); Value cTrue = b.create(loc, true); auto loop = b.create(loc, TypeRange({}), maxTripCount, /*initialCondition=*/cTrue, /*iterArgsInit=*/ValueRange({})); // Create the loop body. { OpBuilder::InsertionGuard guard(b); Block *body = b.createBlock(&loop.getRegion(), loop.getRegion().begin(), TypeRange({b.getType()}), {loc}); Value iterationNumber = body->getArgument(0); Value element = b.create( loc, providedType.getContainedType(), operand, iterationNumber); FailureOr adjustedElement = adjustFunctionArg(b, loc, element, desiredListType.getContainedType(), baseTransformation); if (failed(adjustedElement)) return failure(); b.create(loc, adjustedList.getType(), adjustedList, *adjustedElement); b.create(loc, /*shouldContinue=*/cTrue, /*iterArgs=*/ValueRange({})); } return adjustedList; } // The library functions use `float` where the operator // signature uses `Scalar` (see comments in torch_ods_gen.py for // explanation). if (isa(desiredType) && isa(operand.getType())) { return b.create(loc, desiredType, operand).getResult(); } // Pass the operand as-is. return operand; } LogicalResult mlir::torch::Torch::loadExtraLibrary(const std::string &filename, OwningOpRef &moduleToAppendTo) { auto ctx = moduleToAppendTo->getContext(); assert(ctx && "Module should be fully initialized."); llvm::ErrorOr> fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(filename); if (std::error_code ec = fileOrErr.getError()) { llvm::errs() << "Could not open input file: " << ec.message() << "\n"; return failure(); } llvm::SourceMgr sourceMgr; sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); OwningOpRef module_ = mlir::parseSourceFile(sourceMgr, ctx); if (!module_) { llvm::errs() << "Error can't load file " << filename << "\n"; return failure(); } assert((moduleToAppendTo->getBodyRegion().empty() || moduleToAppendTo->getBodyRegion().hasOneBlock()) && "Module should have at most one block."); if (moduleToAppendTo->getBodyRegion().empty()) { moduleToAppendTo = std::move(module_); } else { Block *block = moduleToAppendTo->getBody(0); block->getOperations().splice(block->end(), module_->getBody(0)->getOperations()); } return success(); }