//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "PassDetail.h" #include "ReifyAbstractInterpCalculationsUtils.h" #include "mlir/Parser/Parser.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; static FailureOr> shapeFunctionArgsBuilder(OpBuilder &b, Location loc, ValueRange originalOperands, func::FuncOp shapeFunc) { // Massage the op operands to match the shape function signature. // The shape function generally takes the same operands as the op, with a few // systematic modifications, such as replacing tensors with their shapes. SmallVector shapeFuncArgs; for (auto operandAndDesiredType : llvm::zip(originalOperands, shapeFunc.getArgumentTypes())) { Value operand; Type desiredType; std::tie(operand, desiredType) = operandAndDesiredType; FailureOr shapeFuncArg = adjustFunctionArg( b, loc, operand, desiredType, [](OpBuilder &b, Location loc, Value operand, Type desiredType) -> Value { // The shape library functions have tensor operands replaced with // `!torch.list` types for the shape. Get the sizes. auto desiredListType = desiredType.dyn_cast(); if (!desiredListType) return operand; if (operand.getType().isa() && desiredListType.getContainedType().isa()) { return b.create(loc, desiredType, operand); } return operand; }); if (failed(shapeFuncArg)) return failure(); shapeFuncArgs.push_back(*shapeFuncArg); } return shapeFuncArgs; } namespace { class ReifyShapeCalculationsPass : public ReifyShapeCalculationsBase { void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp module = getOperation(); // TODO: Find a way to not have to parse this every time. // The library is O(#ops we know about), and this pass should be // O(#ops in the program) ideally. OwningOpRef library = parseSourceString(getAbstractInterpLibrary(), context); // Walk all the operations, and if we have a shape function, wrap the op // in a `torch.shape.calculate` op. SmallVector functionsNeeded; WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult { return wrapWithCalculateOpIfLibraryFunctionAvailable( op, *library, LibraryFunctionKind::ShapeFunction, functionsNeeded, shapeFunctionArgsBuilder); }); if (walkResult.wasInterrupted()) return signalPassFailure(); importLibraryFunctions(module, *library, std::move(functionsNeeded)); } }; } // namespace std::unique_ptr> mlir::torch::Torch::createReifyShapeCalculationsPass() { return std::make_unique(); }