//===- MaximizeValueSemantics.cpp --------------------------------*- C++-*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "PassDetail.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; static bool isViewLikeOp(Operation *op) { // AtenContiguousOp might return a view, so this is conservatively // correct. We could potentially be more precise and identify the cases // that it does not return a view and treat those as having value // semantics. return isa(op); } namespace { class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; struct InterpretedOps { SmallVector copyLikeOps; SmallVector viewLikeOps; SmallVector overwriteTensorContentsOps; }; // Check that graph rewriting is possible by doing an abstract // interpretation within a single basic block. If rewriting is // possible, the interpreted ops are returned split into their // respective categories. static FailureOr abstractlyInterpretSlice(CopyToNonValueTensorOp copyToNonValueTensor, SmallVector nonValueTensorUsers, PatternRewriter &rewriter) { // Sort by order in the block, so we can abstractly interpret the ops. llvm::sort(nonValueTensorUsers, [](Operation *lhs, Operation *rhs) { return lhs->isBeforeInBlock(rhs); }); // We track the available aliases at each point as well as split the // users into view-like, copy-to-value, and overwrite ops as we walk // forward. InterpretedOps result; result.copyLikeOps.push_back(copyToNonValueTensor); DenseSet availableAliases{copyToNonValueTensor.result()}; for (Operation *user : nonValueTensorUsers) { if (isViewLikeOp(user)) { Value operand = user->getOperand(0); if (!availableAliases.contains(operand)) { return rewriter.notifyMatchFailure( copyToNonValueTensor, "operand of view-like op is not a valid tensor alias"); } // View-like ops produce a new alias available to later ops. availableAliases.insert(user->getResult(0)); result.viewLikeOps.push_back(user); } else if (auto copyToValueTensor = dyn_cast(user)) { if (!availableAliases.contains(copyToValueTensor.operand())) { return rewriter.notifyMatchFailure( copyToNonValueTensor, "operand of copyToValueTensorOp is not a valid tensor alias"); } result.copyLikeOps.push_back(copyToValueTensor); } else if (auto overwrite = dyn_cast(user)) { Value overwritten = overwrite.overwritten(); if (!availableAliases.contains(overwritten)) { return rewriter.notifyMatchFailure( copyToNonValueTensor, "overwritten tensor is not a valid alias"); } // To simplify the analysis, we only support the case where the // only aliases used after an overwrite are the aliases generated // after plus the alias being overwritten. availableAliases.clear(); availableAliases.insert(overwritten); result.overwriteTensorContentsOps.push_back(overwrite); } else { return rewriter.notifyMatchFailure( copyToNonValueTensor, "unsupported op encountered during abstract analysis"); } } return result; } // Rewrite slice composed of the interpreted ops so that the slice uses // value semantics everywhere. static void rewriteSlice(const InterpretedOps &ops, PatternRewriter &rewriter) { // The rewriting for the overwrite op involves replacing all uses of its // non-value tensor operand with its value tensor operand. Since the // rewriting of other ops can potentially change the non-value tensor // operand to a value tensor, this rewriting MUST happen first to avoid // wrongly replacing operands that were previously not a view of the // overwritten tensor. for (OverwriteTensorContentsOp overwrite : llvm::reverse(ops.overwriteTensorContentsOps)) { Value overwritten = overwrite.overwritten(); assert(overwritten.getType().dyn_cast() && "the analysis assumes that overwritten remains a nonValueTensor " "throughout the rewriting"); overwritten.replaceUsesWithIf( overwrite.value(), [&](const OpOperand &operand) { return !operand.getOwner()->isBeforeInBlock(overwrite); }); rewriter.eraseOp(overwrite); } for (Operation *copyLikeOp : ops.copyLikeOps) rewriter.replaceOp(copyLikeOp, copyLikeOp->getOperand(0)); // Replace return type of view-like ops with value-semantics type variant. for (Operation *viewLikeOp : ops.viewLikeOps) { rewriter.updateRootInPlace(viewLikeOp, [&] { Value result = viewLikeOp->getResult(0); auto resultType = result.getType().dyn_cast(); assert(resultType && "all view-like ops considered must have result of " "type `NonValueTensorType` before rewriting"); result.setType(resultType.getWithValueSemantics()); }); } } LogicalResult matchAndRewrite(CopyToNonValueTensorOp copy, PatternRewriter &rewriter) const override { // Find a subgraph starting with this CopyToNonValueTensorOp, and // terminating at CopyToValueTensorOp's, possibly with intervening view-like // ops and overwrites. This also catches the special case of a // CopyToNonValueTensorOp that trivially feeds into CopyToValueTensorOp's. SmallVector nonValueTensorUsers; auto workList = llvm::to_vector(copy.result().getUsers()); while (!workList.empty()) { Operation *op = workList.pop_back_val(); if (op->getBlock() != copy->getBlock()) { return rewriter.notifyMatchFailure( copy, "can only analyze within a single basic block"); } nonValueTensorUsers.push_back(op); if (isViewLikeOp(op)) { auto isTensor = [](const Value operand) { return operand.getType().isa(); }; // We currently only support view-like ops with one tensor input and one // tensor output, meaning that the tensor use-def chains form a tree. // This will not be the case for an op like `torch.aten.view_as`, so // we will need to add a set to prune duplicate visitation. if (llvm::count_if(op->getOperands(), isTensor) != 1 || llvm::count_if(op->getResults(), isTensor) != 1 || !isTensor(op->getOperand(0)) || !isTensor(op->getResult(0))) { return rewriter.notifyMatchFailure( copy, "unsupported: view-like ops must have one tensor input and " "one tensor output, and the tensor input/output must be " "the first operand/result"); } llvm::append_range(workList, op->getResult(0).getUsers()); } } FailureOr interpretedOps = abstractlyInterpretSlice( copy, std::move(nonValueTensorUsers), rewriter); if (failed(LogicalResult(interpretedOps))) return failure(); rewriteSlice(*interpretedOps, rewriter); return success(); } }; } // namespace namespace { // Calculate a forward slice starting from a CopyToNonValueTensorOp // and ending at CopyToValueTensorOp's. If all intervening ops // are just view-like operations (i.e. no mutation), then we can trivially // convert them all to value semantics. // This pattern handles the case where views span multiple basic blocks, // which is currently not supported by // `AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock`. class RewriteViewLikeSubgraph : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CopyToNonValueTensorOp copy, PatternRewriter &rewriter) const override { // Find a subgraph starting with this CopyToNonValueTensorOp, and // terminating at CopyToValueTensorOp's, possibly with intervening view-like // ops. // This also catches the special case of a CopyToNonValueTensorOp that // trivially feeds into CopyToValueTensorOp's. SmallVector viewLikeOps; SmallVector copyToValueTensorOps; auto workList = llvm::to_vector<6>(copy.getResult().getUsers()); // We currently only support view-like ops with one tensor input and one // tensor output, meaning that the tensor use-def chains form a tree. // This will not be the case for an op like `torch.aten.view_as`, so // we will need to add a set to prune duplicate visitation. while (!workList.empty()) { Operation *op = workList.pop_back_val(); if (auto copyToValueTensor = dyn_cast(op)) { copyToValueTensorOps.push_back(copyToValueTensor); } else if (isViewLikeOp(op)) { viewLikeOps.push_back(op); llvm::append_range(workList, op->getResult(0).getUsers()); } else { return rewriter.notifyMatchFailure( copy, "can only handle these transitive user ops"); } } copy.replaceAllUsesWith(copy.getOperand()); for (CopyToValueTensorOp op : copyToValueTensorOps) rewriter.replaceOp(op, op.getOperand()); for (Operation *op : viewLikeOps) { rewriter.updateRootInPlace(op, [&]() { if (auto nonValueTensorType = op->getResult(0).getType().dyn_cast()) { op->getResult(0).setType(nonValueTensorType.getWithValueSemantics()); } }); } return success(); } }; } // namespace namespace { class MaximizeValueSemanticsPass : public MaximizeValueSemanticsBase { void runOnOperation() override { MLIRContext *context = &getContext(); auto func = getOperation(); RewritePatternSet patterns(context); patterns.insert(context); (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); } }; } // namespace std::unique_ptr> mlir::torch::Torch::createMaximizeValueSemanticsPass() { return std::make_unique(); }