From 51e267aa372c32f037fb45813fc39a5ece86e273 Mon Sep 17 00:00:00 2001 From: Ramiro Leal-Cavazos Date: Thu, 10 Mar 2022 09:36:52 -0800 Subject: [PATCH] Combine maximize-value-semantics rewrite patterns into one pattern (#642) This commit replaces the two rewrite patterns of maximize-value-semantics with a single pattern that captures the behavior of both as well as other edge cases previously not supported. The new pattern works by first performing alias analysis on a subgraph to see if pattern is applicable, then rewriting all non-value tensors to value tensors in a single go. --- .../Transforms/MaximizeValueSemantics.cpp | 194 +++++++++++++----- .../Torch/maximize-value-semantics.mlir | 31 +++ 2 files changed, 172 insertions(+), 53 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp index 2aacf1d9e..4af0f4baf 100644 --- a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp +++ b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp @@ -37,68 +37,154 @@ class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(CopyToNonValueTensorOp copy, - PatternRewriter &rewriter) const override { - SmallVector users; - bool foundNonViewLikeOpUser = false; - // See if our limited form of analysis is even applicatble. - for (Operation *user : copy.getResult().getUsers()) { - // We can only analyze within a single basic block. - if (user->getBlock() != copy->getBlock()) - return failure(); - // We can only analyze these ops or view-like ops. - if (isa(user)) - foundNonViewLikeOpUser = true; - else if (!isViewLikeOp(user)) - return failure(); - users.push_back(user); - } - // If all users found are view-like ops, then there is nothing to do - // here. The `RewriteViewLikeSubgraph` will take care of turning - // these ops into ops with value semantics. - if (!foundNonViewLikeOpUser) - return failure(); + struct InterpretedOps { + SmallVector copyLikeOps; + SmallVector viewLikeOps; + SmallVector overwriteTensorContentsOps; + }; + // Check that graph rewriting is possible by doing an abstract + // interpretation within a single basic block. If rewriting is + // possible, the interpreted ops are returned split into their + // respective categories. + static FailureOr + abstractlyInterpretSlice(CopyToNonValueTensorOp copyToNonValueTensor, + SmallVector nonValueTensorUsers, + PatternRewriter &rewriter) { // Sort by order in the block, so we can abstractly interpret the ops. - llvm::sort(users, [](Operation *lhs, Operation *rhs) { + llvm::sort(nonValueTensorUsers, [](Operation *lhs, Operation *rhs) { return lhs->isBeforeInBlock(rhs); }); - // Do an abstract interpretation within the block. - // We track the current value tensor that holds the same contents as the - // non-value tensor at each program point as we walk forward. - Value currentlyHeldValueTensor = copy.getOperand(); - for (Operation *user : users) { - if (auto copyToValueTensor = dyn_cast(user)) { - rewriter.replaceOp(copyToValueTensor, {currentlyHeldValueTensor}); - } else if (auto overwriteTensorContents = - dyn_cast(user)) { - currentlyHeldValueTensor = overwriteTensorContents.value(); - rewriter.eraseOp(overwriteTensorContents); - } else if (isViewLikeOp(user)) { - // This case currently only handles view-like ops that have one tensor - // input and one tensor output. - // - // The goal here is to transform view-like ops that depend on an - // overwritten tensor into ops that don't, so that the `overwrite` op - // can be removed. - Location loc = user->getLoc(); - Type currentlyHeldValueType = currentlyHeldValueTensor.getType() - .dyn_cast() - .getWithoutValueSemantics(); - { - PatternRewriter::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(user); - Value newInput = rewriter.create( - loc, currentlyHeldValueType, currentlyHeldValueTensor); - user->setOperands(/*start*/0, /*length*/1, {newInput}); + // We track the available aliases at each point as well as split the + // users into view-like, copy-to-value, and overwrite ops as we walk + // forward. + InterpretedOps result; + result.copyLikeOps.push_back(copyToNonValueTensor); + DenseSet availableAliases{copyToNonValueTensor.result()}; + for (Operation *user : nonValueTensorUsers) { + if (isViewLikeOp(user)) { + Value operand = user->getOperand(0); + if (!availableAliases.contains(operand)) { + return rewriter.notifyMatchFailure( + copyToNonValueTensor, + "operand of view-like op is not a valid tensor alias"); } + + // View-like ops produce a new alias available to later ops. + availableAliases.insert(user->getResult(0)); + result.viewLikeOps.push_back(user); + } else if (auto copyToValueTensor = dyn_cast(user)) { + if (!availableAliases.contains(copyToValueTensor.operand())) { + return rewriter.notifyMatchFailure( + copyToNonValueTensor, + "operand of copyToValueTensorOp is not a valid tensor alias"); + } + result.copyLikeOps.push_back(copyToValueTensor); + } else if (auto overwrite = dyn_cast(user)) { + Value overwritten = overwrite.overwritten(); + if (!availableAliases.contains(overwritten)) { + return rewriter.notifyMatchFailure( + copyToNonValueTensor, "overwritten tensor is not a valid alias"); + } + + // To simplify the analysis, we only support the case where the + // only aliases used after an overwrite are the aliases generated + // after plus the alias being overwritten. + availableAliases.clear(); + availableAliases.insert(overwritten); + result.overwriteTensorContentsOps.push_back(overwrite); } else { - llvm_unreachable("only those ops supported!"); + return rewriter.notifyMatchFailure( + copyToNonValueTensor, + "unsupported op encountered during abstract analysis"); } } - rewriter.eraseOp(copy); + return result; + } + + // Rewrite slice composed of the interpreted ops so that the slice uses + // value semantics everywhere. + static void rewriteSlice(const InterpretedOps &ops, + PatternRewriter &rewriter) { + // The rewriting for the overwrite op involves replacing all uses of its + // non-value tensor operand with its value tensor operand. Since the + // rewriting of other ops can potentially change the non-value tensor + // operand to a value tensor, this rewriting MUST happen first to avoid + // wrongly replacing operands that were previously not a view of the + // overwritten tensor. + for (OverwriteTensorContentsOp overwrite : + llvm::reverse(ops.overwriteTensorContentsOps)) { + Value overwritten = overwrite.overwritten(); + assert(overwritten.getType().dyn_cast() && + "the analysis assumes that overwritten remains a nonValueTensor " + "throughout the rewriting"); + overwritten.replaceUsesWithIf( + overwrite.value(), [&](const OpOperand &operand) { + return !operand.getOwner()->isBeforeInBlock(overwrite); + }); + rewriter.eraseOp(overwrite); + } + + for (Operation *copyLikeOp : ops.copyLikeOps) + rewriter.replaceOp(copyLikeOp, copyLikeOp->getOperand(0)); + + // Replace return type of view-like ops with value-semantics type variant. + for (Operation *viewLikeOp : ops.viewLikeOps) { + rewriter.updateRootInPlace(viewLikeOp, [&] { + Value result = viewLikeOp->getResult(0); + auto resultType = result.getType().dyn_cast(); + assert(resultType && "all view-like ops considered must have result of " + "type `NonValueTensorType` before rewriting"); + result.setType(resultType.getWithValueSemantics()); + }); + } + } + + LogicalResult matchAndRewrite(CopyToNonValueTensorOp copy, + PatternRewriter &rewriter) const override { + // Find a subgraph starting with this CopyToNonValueTensorOp, and + // terminating at CopyToValueTensorOp's, possibly with intervening view-like + // ops and overwrites. This also catches the special case of a + // CopyToNonValueTensorOp that trivially feeds into CopyToValueTensorOp's. + SmallVector nonValueTensorUsers; + auto workList = llvm::to_vector(copy.result().getUsers()); + while (!workList.empty()) { + Operation *op = workList.pop_back_val(); + if (op->getBlock() != copy->getBlock()) { + return rewriter.notifyMatchFailure( + copy, "can only analyze within a single basic block"); + } + nonValueTensorUsers.push_back(op); + + if (isViewLikeOp(op)) { + auto isTensor = [](const Value operand) { + return operand.getType().isa(); + }; + + // We currently only support view-like ops with one tensor input and one + // tensor output, meaning that the tensor use-def chains form a tree. + // This will not be the case for an op like `torch.aten.view_as`, so + // we will need to add a set to prune duplicate visitation. + if (llvm::count_if(op->getOperands(), isTensor) != 1 || + llvm::count_if(op->getResults(), isTensor) != 1 || + !isTensor(op->getOperand(0)) || !isTensor(op->getResult(0))) { + return rewriter.notifyMatchFailure( + copy, "unsupported: view-like ops must have one tensor input and " + "one tensor output, and the tensor input/output must be " + "the first operand/result"); + } + + llvm::append_range(workList, op->getResult(0).getUsers()); + } + } + + FailureOr interpretedOps = abstractlyInterpretSlice( + copy, std::move(nonValueTensorUsers), rewriter); + if (failed(LogicalResult(interpretedOps))) + return failure(); + rewriteSlice(*interpretedOps, rewriter); return success(); } }; @@ -109,6 +195,9 @@ namespace { // and ending at CopyToValueTensorOp's. If all intervening ops // are just view-like operations (i.e. no mutation), then we can trivially // convert them all to value semantics. +// This pattern handles the case where views span multiple basic blocks, +// which is currently not supported by +// `AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock`. class RewriteViewLikeSubgraph : public OpRewritePattern { public: @@ -157,7 +246,6 @@ public: } // namespace namespace { - class MaximizeValueSemanticsPass : public MaximizeValueSemanticsBase { void runOnOperation() override { diff --git a/test/Dialect/Torch/maximize-value-semantics.mlir b/test/Dialect/Torch/maximize-value-semantics.mlir index 2000dc797..7e2459390 100644 --- a/test/Dialect/Torch/maximize-value-semantics.mlir +++ b/test/Dialect/Torch/maximize-value-semantics.mlir @@ -59,6 +59,28 @@ func @mutation_followed_by_view_like_ops(%value_t: !torch.vtensor, %overwriter: return %value_result : !torch.vtensor } +// CHECK-LABEL: func @mutation_of_view_like_op_result( +// CHECK-SAME: %[[VALUE_T:.*]]: !torch.vtensor, %[[OVERWRITER:.*]]: !torch.vtensor, %[[INT_LIST:.*]]: !torch.list) -> !torch.vtensor { +// CHECK: return %[[OVERWRITER]] : !torch.vtensor +func @mutation_of_view_like_op_result(%value_t: !torch.vtensor, %overwriter: !torch.vtensor, %int_list: !torch.list) -> !torch.vtensor { + %t = torch.copy.to_tensor %value_t : !torch.tensor + %view = torch.aten.view %t, %int_list : !torch.tensor, !torch.list -> !torch.tensor + torch.overwrite.tensor.contents %overwriter overwrites %view : !torch.vtensor, !torch.tensor + %result = torch.copy.to_vtensor %view : !torch.vtensor + return %result : !torch.vtensor +} + +// CHECK-LABEL: func @value_tensor_used_after_copy_was_mutated( +// CHECK-SAME: %[[VALUE_T:.*]]: !torch.vtensor, +// CHECK-SAME: %[[OVERWRITER:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) { +// CHECK: return %[[VALUE_T]], %[[OVERWRITER]] : !torch.vtensor, !torch.vtensor +func @value_tensor_used_after_copy_was_mutated(%value_t: !torch.vtensor, %overwriter: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) { + %t = torch.copy.to_tensor %value_t : !torch.tensor + torch.overwrite.tensor.contents %overwriter overwrites %t : !torch.vtensor, !torch.tensor + %value_mutated_t = torch.copy.to_vtensor %t : !torch.vtensor + return %value_t, %value_mutated_t : !torch.vtensor, !torch.vtensor +} + // CHECK-LABEL: func @unmodeled_mutation( // CHECK: torch.overwrite.tensor.contents func @unmodeled_mutation(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor { @@ -85,6 +107,15 @@ func @unimplemented_control_flow(%arg0: !torch.vtensor, %arg1: !torch.vtensor, % return %equal_to_arg0, %equal_to_arg1 : !torch.vtensor, !torch.vtensor } +// CHECK-LABEL: func @non_value_tensor_returned( +// CHECK-SAME: %[[VALUE_T:.*]]: !torch.vtensor) -> !torch.tensor { +// CHECK: %[[T:.*]] = torch.copy.to_tensor %[[VALUE_T]] : !torch.tensor +// CHECK: return %[[T]] : !torch.tensor +func @non_value_tensor_returned(%value_t: !torch.vtensor) -> !torch.tensor { + %t = torch.copy.to_tensor %value_t : !torch.tensor + return %t : !torch.tensor +} + // CHECK-LABEL: func @viewlike$basic_unsqueeze( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor { // CHECK: %[[INT0:.*]] = torch.constant.int 0