mirror of https://github.com/llvm/torch-mlir
Combine maximize-value-semantics rewrite patterns into one pattern (#642)
This commit replaces the two rewrite patterns of maximize-value-semantics with a single pattern that captures the behavior of both as well as other edge cases previously not supported. The new pattern works by first performing alias analysis on a subgraph to see if pattern is applicable, then rewriting all non-value tensors to value tensors in a single go.pull/654/head snapshot-20220310.316
parent
3510b2ba9d
commit
51e267aa37
|
@ -37,68 +37,154 @@ class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
|
|||
: public OpRewritePattern<CopyToNonValueTensorOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(CopyToNonValueTensorOp copy,
|
||||
PatternRewriter &rewriter) const override {
|
||||
SmallVector<Operation *> users;
|
||||
bool foundNonViewLikeOpUser = false;
|
||||
// See if our limited form of analysis is even applicatble.
|
||||
for (Operation *user : copy.getResult().getUsers()) {
|
||||
// We can only analyze within a single basic block.
|
||||
if (user->getBlock() != copy->getBlock())
|
||||
return failure();
|
||||
// We can only analyze these ops or view-like ops.
|
||||
if (isa<CopyToValueTensorOp, OverwriteTensorContentsOp>(user))
|
||||
foundNonViewLikeOpUser = true;
|
||||
else if (!isViewLikeOp(user))
|
||||
return failure();
|
||||
users.push_back(user);
|
||||
}
|
||||
|
||||
// If all users found are view-like ops, then there is nothing to do
|
||||
// here. The `RewriteViewLikeSubgraph` will take care of turning
|
||||
// these ops into ops with value semantics.
|
||||
if (!foundNonViewLikeOpUser)
|
||||
return failure();
|
||||
struct InterpretedOps {
|
||||
SmallVector<Operation *> copyLikeOps;
|
||||
SmallVector<Operation *> viewLikeOps;
|
||||
SmallVector<OverwriteTensorContentsOp> overwriteTensorContentsOps;
|
||||
};
|
||||
|
||||
// Check that graph rewriting is possible by doing an abstract
|
||||
// interpretation within a single basic block. If rewriting is
|
||||
// possible, the interpreted ops are returned split into their
|
||||
// respective categories.
|
||||
static FailureOr<InterpretedOps>
|
||||
abstractlyInterpretSlice(CopyToNonValueTensorOp copyToNonValueTensor,
|
||||
SmallVector<Operation *> nonValueTensorUsers,
|
||||
PatternRewriter &rewriter) {
|
||||
// Sort by order in the block, so we can abstractly interpret the ops.
|
||||
llvm::sort(users, [](Operation *lhs, Operation *rhs) {
|
||||
llvm::sort(nonValueTensorUsers, [](Operation *lhs, Operation *rhs) {
|
||||
return lhs->isBeforeInBlock(rhs);
|
||||
});
|
||||
// Do an abstract interpretation within the block.
|
||||
// We track the current value tensor that holds the same contents as the
|
||||
// non-value tensor at each program point as we walk forward.
|
||||
Value currentlyHeldValueTensor = copy.getOperand();
|
||||
for (Operation *user : users) {
|
||||
if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(user)) {
|
||||
rewriter.replaceOp(copyToValueTensor, {currentlyHeldValueTensor});
|
||||
} else if (auto overwriteTensorContents =
|
||||
dyn_cast<OverwriteTensorContentsOp>(user)) {
|
||||
currentlyHeldValueTensor = overwriteTensorContents.value();
|
||||
rewriter.eraseOp(overwriteTensorContents);
|
||||
} else if (isViewLikeOp(user)) {
|
||||
// This case currently only handles view-like ops that have one tensor
|
||||
// input and one tensor output.
|
||||
//
|
||||
// The goal here is to transform view-like ops that depend on an
|
||||
// overwritten tensor into ops that don't, so that the `overwrite` op
|
||||
// can be removed.
|
||||
Location loc = user->getLoc();
|
||||
Type currentlyHeldValueType = currentlyHeldValueTensor.getType()
|
||||
.dyn_cast<ValueTensorType>()
|
||||
.getWithoutValueSemantics();
|
||||
|
||||
{
|
||||
PatternRewriter::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPoint(user);
|
||||
Value newInput = rewriter.create<CopyToNonValueTensorOp>(
|
||||
loc, currentlyHeldValueType, currentlyHeldValueTensor);
|
||||
user->setOperands(/*start*/0, /*length*/1, {newInput});
|
||||
// We track the available aliases at each point as well as split the
|
||||
// users into view-like, copy-to-value, and overwrite ops as we walk
|
||||
// forward.
|
||||
InterpretedOps result;
|
||||
result.copyLikeOps.push_back(copyToNonValueTensor);
|
||||
DenseSet<Value> availableAliases{copyToNonValueTensor.result()};
|
||||
for (Operation *user : nonValueTensorUsers) {
|
||||
if (isViewLikeOp(user)) {
|
||||
Value operand = user->getOperand(0);
|
||||
if (!availableAliases.contains(operand)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
copyToNonValueTensor,
|
||||
"operand of view-like op is not a valid tensor alias");
|
||||
}
|
||||
|
||||
// View-like ops produce a new alias available to later ops.
|
||||
availableAliases.insert(user->getResult(0));
|
||||
result.viewLikeOps.push_back(user);
|
||||
} else if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(user)) {
|
||||
if (!availableAliases.contains(copyToValueTensor.operand())) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
copyToNonValueTensor,
|
||||
"operand of copyToValueTensorOp is not a valid tensor alias");
|
||||
}
|
||||
result.copyLikeOps.push_back(copyToValueTensor);
|
||||
} else if (auto overwrite = dyn_cast<OverwriteTensorContentsOp>(user)) {
|
||||
Value overwritten = overwrite.overwritten();
|
||||
if (!availableAliases.contains(overwritten)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
copyToNonValueTensor, "overwritten tensor is not a valid alias");
|
||||
}
|
||||
|
||||
// To simplify the analysis, we only support the case where the
|
||||
// only aliases used after an overwrite are the aliases generated
|
||||
// after plus the alias being overwritten.
|
||||
availableAliases.clear();
|
||||
availableAliases.insert(overwritten);
|
||||
result.overwriteTensorContentsOps.push_back(overwrite);
|
||||
} else {
|
||||
llvm_unreachable("only those ops supported!");
|
||||
return rewriter.notifyMatchFailure(
|
||||
copyToNonValueTensor,
|
||||
"unsupported op encountered during abstract analysis");
|
||||
}
|
||||
}
|
||||
rewriter.eraseOp(copy);
|
||||
return result;
|
||||
}
|
||||
|
||||
// Rewrite slice composed of the interpreted ops so that the slice uses
|
||||
// value semantics everywhere.
|
||||
static void rewriteSlice(const InterpretedOps &ops,
|
||||
PatternRewriter &rewriter) {
|
||||
// The rewriting for the overwrite op involves replacing all uses of its
|
||||
// non-value tensor operand with its value tensor operand. Since the
|
||||
// rewriting of other ops can potentially change the non-value tensor
|
||||
// operand to a value tensor, this rewriting MUST happen first to avoid
|
||||
// wrongly replacing operands that were previously not a view of the
|
||||
// overwritten tensor.
|
||||
for (OverwriteTensorContentsOp overwrite :
|
||||
llvm::reverse(ops.overwriteTensorContentsOps)) {
|
||||
Value overwritten = overwrite.overwritten();
|
||||
assert(overwritten.getType().dyn_cast<NonValueTensorType>() &&
|
||||
"the analysis assumes that overwritten remains a nonValueTensor "
|
||||
"throughout the rewriting");
|
||||
overwritten.replaceUsesWithIf(
|
||||
overwrite.value(), [&](const OpOperand &operand) {
|
||||
return !operand.getOwner()->isBeforeInBlock(overwrite);
|
||||
});
|
||||
rewriter.eraseOp(overwrite);
|
||||
}
|
||||
|
||||
for (Operation *copyLikeOp : ops.copyLikeOps)
|
||||
rewriter.replaceOp(copyLikeOp, copyLikeOp->getOperand(0));
|
||||
|
||||
// Replace return type of view-like ops with value-semantics type variant.
|
||||
for (Operation *viewLikeOp : ops.viewLikeOps) {
|
||||
rewriter.updateRootInPlace(viewLikeOp, [&] {
|
||||
Value result = viewLikeOp->getResult(0);
|
||||
auto resultType = result.getType().dyn_cast<NonValueTensorType>();
|
||||
assert(resultType && "all view-like ops considered must have result of "
|
||||
"type `NonValueTensorType` before rewriting");
|
||||
result.setType(resultType.getWithValueSemantics());
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
LogicalResult matchAndRewrite(CopyToNonValueTensorOp copy,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Find a subgraph starting with this CopyToNonValueTensorOp, and
|
||||
// terminating at CopyToValueTensorOp's, possibly with intervening view-like
|
||||
// ops and overwrites. This also catches the special case of a
|
||||
// CopyToNonValueTensorOp that trivially feeds into CopyToValueTensorOp's.
|
||||
SmallVector<Operation *> nonValueTensorUsers;
|
||||
auto workList = llvm::to_vector(copy.result().getUsers());
|
||||
while (!workList.empty()) {
|
||||
Operation *op = workList.pop_back_val();
|
||||
if (op->getBlock() != copy->getBlock()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
copy, "can only analyze within a single basic block");
|
||||
}
|
||||
nonValueTensorUsers.push_back(op);
|
||||
|
||||
if (isViewLikeOp(op)) {
|
||||
auto isTensor = [](const Value operand) {
|
||||
return operand.getType().isa<BaseTensorType>();
|
||||
};
|
||||
|
||||
// We currently only support view-like ops with one tensor input and one
|
||||
// tensor output, meaning that the tensor use-def chains form a tree.
|
||||
// This will not be the case for an op like `torch.aten.view_as`, so
|
||||
// we will need to add a set to prune duplicate visitation.
|
||||
if (llvm::count_if(op->getOperands(), isTensor) != 1 ||
|
||||
llvm::count_if(op->getResults(), isTensor) != 1 ||
|
||||
!isTensor(op->getOperand(0)) || !isTensor(op->getResult(0))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
copy, "unsupported: view-like ops must have one tensor input and "
|
||||
"one tensor output, and the tensor input/output must be "
|
||||
"the first operand/result");
|
||||
}
|
||||
|
||||
llvm::append_range(workList, op->getResult(0).getUsers());
|
||||
}
|
||||
}
|
||||
|
||||
FailureOr<InterpretedOps> interpretedOps = abstractlyInterpretSlice(
|
||||
copy, std::move(nonValueTensorUsers), rewriter);
|
||||
if (failed(LogicalResult(interpretedOps)))
|
||||
return failure();
|
||||
rewriteSlice(*interpretedOps, rewriter);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -109,6 +195,9 @@ namespace {
|
|||
// and ending at CopyToValueTensorOp's. If all intervening ops
|
||||
// are just view-like operations (i.e. no mutation), then we can trivially
|
||||
// convert them all to value semantics.
|
||||
// This pattern handles the case where views span multiple basic blocks,
|
||||
// which is currently not supported by
|
||||
// `AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock`.
|
||||
class RewriteViewLikeSubgraph
|
||||
: public OpRewritePattern<CopyToNonValueTensorOp> {
|
||||
public:
|
||||
|
@ -157,7 +246,6 @@ public:
|
|||
} // namespace
|
||||
|
||||
namespace {
|
||||
|
||||
class MaximizeValueSemanticsPass
|
||||
: public MaximizeValueSemanticsBase<MaximizeValueSemanticsPass> {
|
||||
void runOnOperation() override {
|
||||
|
|
|
@ -59,6 +59,28 @@ func @mutation_followed_by_view_like_ops(%value_t: !torch.vtensor, %overwriter:
|
|||
return %value_result : !torch.vtensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @mutation_of_view_like_op_result(
|
||||
// CHECK-SAME: %[[VALUE_T:.*]]: !torch.vtensor, %[[OVERWRITER:.*]]: !torch.vtensor, %[[INT_LIST:.*]]: !torch.list<!torch.int>) -> !torch.vtensor {
|
||||
// CHECK: return %[[OVERWRITER]] : !torch.vtensor
|
||||
func @mutation_of_view_like_op_result(%value_t: !torch.vtensor, %overwriter: !torch.vtensor, %int_list: !torch.list<!torch.int>) -> !torch.vtensor {
|
||||
%t = torch.copy.to_tensor %value_t : !torch.tensor
|
||||
%view = torch.aten.view %t, %int_list : !torch.tensor, !torch.list<!torch.int> -> !torch.tensor
|
||||
torch.overwrite.tensor.contents %overwriter overwrites %view : !torch.vtensor, !torch.tensor
|
||||
%result = torch.copy.to_vtensor %view : !torch.vtensor
|
||||
return %result : !torch.vtensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @value_tensor_used_after_copy_was_mutated(
|
||||
// CHECK-SAME: %[[VALUE_T:.*]]: !torch.vtensor,
|
||||
// CHECK-SAME: %[[OVERWRITER:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
|
||||
// CHECK: return %[[VALUE_T]], %[[OVERWRITER]] : !torch.vtensor, !torch.vtensor
|
||||
func @value_tensor_used_after_copy_was_mutated(%value_t: !torch.vtensor, %overwriter: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
|
||||
%t = torch.copy.to_tensor %value_t : !torch.tensor
|
||||
torch.overwrite.tensor.contents %overwriter overwrites %t : !torch.vtensor, !torch.tensor
|
||||
%value_mutated_t = torch.copy.to_vtensor %t : !torch.vtensor
|
||||
return %value_t, %value_mutated_t : !torch.vtensor, !torch.vtensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @unmodeled_mutation(
|
||||
// CHECK: torch.overwrite.tensor.contents
|
||||
func @unmodeled_mutation(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
|
||||
|
@ -85,6 +107,15 @@ func @unimplemented_control_flow(%arg0: !torch.vtensor, %arg1: !torch.vtensor, %
|
|||
return %equal_to_arg0, %equal_to_arg1 : !torch.vtensor, !torch.vtensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @non_value_tensor_returned(
|
||||
// CHECK-SAME: %[[VALUE_T:.*]]: !torch.vtensor) -> !torch.tensor {
|
||||
// CHECK: %[[T:.*]] = torch.copy.to_tensor %[[VALUE_T]] : !torch.tensor
|
||||
// CHECK: return %[[T]] : !torch.tensor
|
||||
func @non_value_tensor_returned(%value_t: !torch.vtensor) -> !torch.tensor {
|
||||
%t = torch.copy.to_tensor %value_t : !torch.tensor
|
||||
return %t : !torch.tensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @viewlike$basic_unsqueeze(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor {
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
|
|
Loading…
Reference in New Issue