Combine maximize-value-semantics rewrite patterns into one pattern (#642)

This commit replaces the two rewrite patterns of
maximize-value-semantics with a single pattern that captures the
behavior of both as well as other edge cases previously not
supported. The new pattern works by first performing alias analysis on
a subgraph to see if pattern is applicable, then rewriting all
non-value tensors to value tensors in a single go.
pull/654/head snapshot-20220310.316
Ramiro Leal-Cavazos 2022-03-10 09:36:52 -08:00 committed by GitHub
parent 3510b2ba9d
commit 51e267aa37
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 172 additions and 53 deletions

View File

@ -37,68 +37,154 @@ class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
: public OpRewritePattern<CopyToNonValueTensorOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(CopyToNonValueTensorOp copy,
PatternRewriter &rewriter) const override {
SmallVector<Operation *> users;
bool foundNonViewLikeOpUser = false;
// See if our limited form of analysis is even applicatble.
for (Operation *user : copy.getResult().getUsers()) {
// We can only analyze within a single basic block.
if (user->getBlock() != copy->getBlock())
return failure();
// We can only analyze these ops or view-like ops.
if (isa<CopyToValueTensorOp, OverwriteTensorContentsOp>(user))
foundNonViewLikeOpUser = true;
else if (!isViewLikeOp(user))
return failure();
users.push_back(user);
}
// If all users found are view-like ops, then there is nothing to do
// here. The `RewriteViewLikeSubgraph` will take care of turning
// these ops into ops with value semantics.
if (!foundNonViewLikeOpUser)
return failure();
struct InterpretedOps {
SmallVector<Operation *> copyLikeOps;
SmallVector<Operation *> viewLikeOps;
SmallVector<OverwriteTensorContentsOp> overwriteTensorContentsOps;
};
// Check that graph rewriting is possible by doing an abstract
// interpretation within a single basic block. If rewriting is
// possible, the interpreted ops are returned split into their
// respective categories.
static FailureOr<InterpretedOps>
abstractlyInterpretSlice(CopyToNonValueTensorOp copyToNonValueTensor,
SmallVector<Operation *> nonValueTensorUsers,
PatternRewriter &rewriter) {
// Sort by order in the block, so we can abstractly interpret the ops.
llvm::sort(users, [](Operation *lhs, Operation *rhs) {
llvm::sort(nonValueTensorUsers, [](Operation *lhs, Operation *rhs) {
return lhs->isBeforeInBlock(rhs);
});
// Do an abstract interpretation within the block.
// We track the current value tensor that holds the same contents as the
// non-value tensor at each program point as we walk forward.
Value currentlyHeldValueTensor = copy.getOperand();
for (Operation *user : users) {
if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(user)) {
rewriter.replaceOp(copyToValueTensor, {currentlyHeldValueTensor});
} else if (auto overwriteTensorContents =
dyn_cast<OverwriteTensorContentsOp>(user)) {
currentlyHeldValueTensor = overwriteTensorContents.value();
rewriter.eraseOp(overwriteTensorContents);
} else if (isViewLikeOp(user)) {
// This case currently only handles view-like ops that have one tensor
// input and one tensor output.
//
// The goal here is to transform view-like ops that depend on an
// overwritten tensor into ops that don't, so that the `overwrite` op
// can be removed.
Location loc = user->getLoc();
Type currentlyHeldValueType = currentlyHeldValueTensor.getType()
.dyn_cast<ValueTensorType>()
.getWithoutValueSemantics();
{
PatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(user);
Value newInput = rewriter.create<CopyToNonValueTensorOp>(
loc, currentlyHeldValueType, currentlyHeldValueTensor);
user->setOperands(/*start*/0, /*length*/1, {newInput});
// We track the available aliases at each point as well as split the
// users into view-like, copy-to-value, and overwrite ops as we walk
// forward.
InterpretedOps result;
result.copyLikeOps.push_back(copyToNonValueTensor);
DenseSet<Value> availableAliases{copyToNonValueTensor.result()};
for (Operation *user : nonValueTensorUsers) {
if (isViewLikeOp(user)) {
Value operand = user->getOperand(0);
if (!availableAliases.contains(operand)) {
return rewriter.notifyMatchFailure(
copyToNonValueTensor,
"operand of view-like op is not a valid tensor alias");
}
// View-like ops produce a new alias available to later ops.
availableAliases.insert(user->getResult(0));
result.viewLikeOps.push_back(user);
} else if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(user)) {
if (!availableAliases.contains(copyToValueTensor.operand())) {
return rewriter.notifyMatchFailure(
copyToNonValueTensor,
"operand of copyToValueTensorOp is not a valid tensor alias");
}
result.copyLikeOps.push_back(copyToValueTensor);
} else if (auto overwrite = dyn_cast<OverwriteTensorContentsOp>(user)) {
Value overwritten = overwrite.overwritten();
if (!availableAliases.contains(overwritten)) {
return rewriter.notifyMatchFailure(
copyToNonValueTensor, "overwritten tensor is not a valid alias");
}
// To simplify the analysis, we only support the case where the
// only aliases used after an overwrite are the aliases generated
// after plus the alias being overwritten.
availableAliases.clear();
availableAliases.insert(overwritten);
result.overwriteTensorContentsOps.push_back(overwrite);
} else {
llvm_unreachable("only those ops supported!");
return rewriter.notifyMatchFailure(
copyToNonValueTensor,
"unsupported op encountered during abstract analysis");
}
}
rewriter.eraseOp(copy);
return result;
}
// Rewrite slice composed of the interpreted ops so that the slice uses
// value semantics everywhere.
static void rewriteSlice(const InterpretedOps &ops,
PatternRewriter &rewriter) {
// The rewriting for the overwrite op involves replacing all uses of its
// non-value tensor operand with its value tensor operand. Since the
// rewriting of other ops can potentially change the non-value tensor
// operand to a value tensor, this rewriting MUST happen first to avoid
// wrongly replacing operands that were previously not a view of the
// overwritten tensor.
for (OverwriteTensorContentsOp overwrite :
llvm::reverse(ops.overwriteTensorContentsOps)) {
Value overwritten = overwrite.overwritten();
assert(overwritten.getType().dyn_cast<NonValueTensorType>() &&
"the analysis assumes that overwritten remains a nonValueTensor "
"throughout the rewriting");
overwritten.replaceUsesWithIf(
overwrite.value(), [&](const OpOperand &operand) {
return !operand.getOwner()->isBeforeInBlock(overwrite);
});
rewriter.eraseOp(overwrite);
}
for (Operation *copyLikeOp : ops.copyLikeOps)
rewriter.replaceOp(copyLikeOp, copyLikeOp->getOperand(0));
// Replace return type of view-like ops with value-semantics type variant.
for (Operation *viewLikeOp : ops.viewLikeOps) {
rewriter.updateRootInPlace(viewLikeOp, [&] {
Value result = viewLikeOp->getResult(0);
auto resultType = result.getType().dyn_cast<NonValueTensorType>();
assert(resultType && "all view-like ops considered must have result of "
"type `NonValueTensorType` before rewriting");
result.setType(resultType.getWithValueSemantics());
});
}
}
LogicalResult matchAndRewrite(CopyToNonValueTensorOp copy,
PatternRewriter &rewriter) const override {
// Find a subgraph starting with this CopyToNonValueTensorOp, and
// terminating at CopyToValueTensorOp's, possibly with intervening view-like
// ops and overwrites. This also catches the special case of a
// CopyToNonValueTensorOp that trivially feeds into CopyToValueTensorOp's.
SmallVector<Operation *> nonValueTensorUsers;
auto workList = llvm::to_vector(copy.result().getUsers());
while (!workList.empty()) {
Operation *op = workList.pop_back_val();
if (op->getBlock() != copy->getBlock()) {
return rewriter.notifyMatchFailure(
copy, "can only analyze within a single basic block");
}
nonValueTensorUsers.push_back(op);
if (isViewLikeOp(op)) {
auto isTensor = [](const Value operand) {
return operand.getType().isa<BaseTensorType>();
};
// We currently only support view-like ops with one tensor input and one
// tensor output, meaning that the tensor use-def chains form a tree.
// This will not be the case for an op like `torch.aten.view_as`, so
// we will need to add a set to prune duplicate visitation.
if (llvm::count_if(op->getOperands(), isTensor) != 1 ||
llvm::count_if(op->getResults(), isTensor) != 1 ||
!isTensor(op->getOperand(0)) || !isTensor(op->getResult(0))) {
return rewriter.notifyMatchFailure(
copy, "unsupported: view-like ops must have one tensor input and "
"one tensor output, and the tensor input/output must be "
"the first operand/result");
}
llvm::append_range(workList, op->getResult(0).getUsers());
}
}
FailureOr<InterpretedOps> interpretedOps = abstractlyInterpretSlice(
copy, std::move(nonValueTensorUsers), rewriter);
if (failed(LogicalResult(interpretedOps)))
return failure();
rewriteSlice(*interpretedOps, rewriter);
return success();
}
};
@ -109,6 +195,9 @@ namespace {
// and ending at CopyToValueTensorOp's. If all intervening ops
// are just view-like operations (i.e. no mutation), then we can trivially
// convert them all to value semantics.
// This pattern handles the case where views span multiple basic blocks,
// which is currently not supported by
// `AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock`.
class RewriteViewLikeSubgraph
: public OpRewritePattern<CopyToNonValueTensorOp> {
public:
@ -157,7 +246,6 @@ public:
} // namespace
namespace {
class MaximizeValueSemanticsPass
: public MaximizeValueSemanticsBase<MaximizeValueSemanticsPass> {
void runOnOperation() override {

View File

@ -59,6 +59,28 @@ func @mutation_followed_by_view_like_ops(%value_t: !torch.vtensor, %overwriter:
return %value_result : !torch.vtensor
}
// CHECK-LABEL: func @mutation_of_view_like_op_result(
// CHECK-SAME: %[[VALUE_T:.*]]: !torch.vtensor, %[[OVERWRITER:.*]]: !torch.vtensor, %[[INT_LIST:.*]]: !torch.list<!torch.int>) -> !torch.vtensor {
// CHECK: return %[[OVERWRITER]] : !torch.vtensor
func @mutation_of_view_like_op_result(%value_t: !torch.vtensor, %overwriter: !torch.vtensor, %int_list: !torch.list<!torch.int>) -> !torch.vtensor {
%t = torch.copy.to_tensor %value_t : !torch.tensor
%view = torch.aten.view %t, %int_list : !torch.tensor, !torch.list<!torch.int> -> !torch.tensor
torch.overwrite.tensor.contents %overwriter overwrites %view : !torch.vtensor, !torch.tensor
%result = torch.copy.to_vtensor %view : !torch.vtensor
return %result : !torch.vtensor
}
// CHECK-LABEL: func @value_tensor_used_after_copy_was_mutated(
// CHECK-SAME: %[[VALUE_T:.*]]: !torch.vtensor,
// CHECK-SAME: %[[OVERWRITER:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
// CHECK: return %[[VALUE_T]], %[[OVERWRITER]] : !torch.vtensor, !torch.vtensor
func @value_tensor_used_after_copy_was_mutated(%value_t: !torch.vtensor, %overwriter: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
%t = torch.copy.to_tensor %value_t : !torch.tensor
torch.overwrite.tensor.contents %overwriter overwrites %t : !torch.vtensor, !torch.tensor
%value_mutated_t = torch.copy.to_vtensor %t : !torch.vtensor
return %value_t, %value_mutated_t : !torch.vtensor, !torch.vtensor
}
// CHECK-LABEL: func @unmodeled_mutation(
// CHECK: torch.overwrite.tensor.contents
func @unmodeled_mutation(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
@ -85,6 +107,15 @@ func @unimplemented_control_flow(%arg0: !torch.vtensor, %arg1: !torch.vtensor, %
return %equal_to_arg0, %equal_to_arg1 : !torch.vtensor, !torch.vtensor
}
// CHECK-LABEL: func @non_value_tensor_returned(
// CHECK-SAME: %[[VALUE_T:.*]]: !torch.vtensor) -> !torch.tensor {
// CHECK: %[[T:.*]] = torch.copy.to_tensor %[[VALUE_T]] : !torch.tensor
// CHECK: return %[[T]] : !torch.tensor
func @non_value_tensor_returned(%value_t: !torch.vtensor) -> !torch.tensor {
%t = torch.copy.to_tensor %value_t : !torch.tensor
return %t : !torch.tensor
}
// CHECK-LABEL: func @viewlike$basic_unsqueeze(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: %[[INT0:.*]] = torch.constant.int 0