Fix handling of view-like ops in `maximize-value-semantics`

This commit adds handling to the `maximize-value-semantics` pass for
the case where a view-like op depends on a tensor that has been
overwritten by a value tensor. The approach for removing the
dependency is to change the input to the view-like op to be a copy of
the value tensor that is being used to overwrite.

This commit also removes `AtenFill_ScalarOp` and
`AtenBernoulli_FloatOp` from the list of view-like ops, since these
ops now have a corresponding op with value semantics into which they
get converted in the `reduce-op-variants` pass.
bert-staging
Ramiro Leal-Cavazos 2022-02-18 01:52:58 +00:00
parent 730f5915bb
commit 016f1859e0
2 changed files with 68 additions and 12 deletions

View File

@ -20,6 +20,18 @@ using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
static bool isViewLikeOp(Operation *op) {
// AtenContiguousOp might return a view, so this is conservatively
// correct. We could potentially be more precise and identify the cases
// that it does not return a view and treat those as having value
// semantics.
return isa<AtenBroadcastToOp, AtenContiguousOp, AtenExpandOp,
AtenFlattenUsingIntsOp, AtenPermuteOp, AtenReshapeOp,
AtenSelectIntOp, AtenSliceTensorOp, AtenSqueezeDimOp,
AtenSqueezeOp, AtenTOp, AtenToDtypeOp, AtenTransposeIntOp,
AtenUnsqueezeOp, AtenViewOp, TensorStaticInfoCastOp>(op);
}
namespace {
class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
: public OpRewritePattern<CopyToNonValueTensorOp> {
@ -28,16 +40,26 @@ public:
LogicalResult matchAndRewrite(CopyToNonValueTensorOp copy,
PatternRewriter &rewriter) const override {
SmallVector<Operation *> users;
bool foundNonViewLikeOpUser = false;
// See if our limited form of analysis is even applicatble.
for (Operation *user : copy.getResult().getUsers()) {
// We can only analyze within a single basic block.
if (user->getBlock() != copy->getBlock())
return failure();
// We can only analyze these ops.
if (!isa<CopyToValueTensorOp, OverwriteTensorOp>(user))
// We can only analyze these ops or view-like ops.
if (isa<CopyToValueTensorOp, OverwriteTensorOp>(user))
foundNonViewLikeOpUser = true;
else if (!isViewLikeOp(user))
return failure();
users.push_back(user);
}
// If all users found are view-like ops, then there is nothing to do
// here. The `RewriteViewLikeSubgraph` will take care of turning
// these ops into ops with value semantics.
if (!foundNonViewLikeOpUser)
return failure();
// Sort by order in the block, so we can abstractly interpret the ops.
llvm::sort(users, [](Operation *lhs, Operation *rhs) {
return lhs->isBeforeInBlock(rhs);
@ -70,6 +92,35 @@ public:
rewriter.restoreInsertionPoint(savedIP);
rewriter.eraseOp(overwriteTensor);
} else if (isViewLikeOp(user)) {
// This case currently only handles view-like ops that have one tensor
// input and one tensor output.
//
// The goal here is to transform view-like ops that depend on an
// overwritten tensor into ops that don't, so that the `overwrite` op
// can be removed. This is achieved as follows:
//
// If the view-like op has as input a tensor `T` that has been
// overwritten by a value tensor `VT`, then the input to the op
// is replaced by a `copy.to_tensor` of `VT`, removing the dependence
// on the `overwrite` op.
//
// If the view-like op has as input a non-value tensor that is a copy
// of a value tensor, all that happens is that the input gets replaced
// by a new copy of the same value tensor, so there is no net effect
// on the goal of maximizing value semantics.
Location loc = user->getLoc();
Type currentlyHeldValueType = currentlyHeldValueTensor.getType()
.dyn_cast<ValueTensorType>()
.getWithoutValueSemantics();
{
PatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(user);
Value newInput = rewriter.create<CopyToNonValueTensorOp>(
loc, currentlyHeldValueType, currentlyHeldValueTensor);
user->setOperands(/*start*/0, /*length*/1, {newInput});
}
} else {
llvm_unreachable("only those ops supported!");
}
@ -107,16 +158,7 @@ public:
Operation *op = workList.pop_back_val();
if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(op)) {
copyToValueTensorOps.push_back(copyToValueTensor);
} else if (isa<AtenSqueezeOp, AtenSqueezeDimOp, AtenUnsqueezeOp,
AtenFlattenUsingIntsOp, AtenTransposeIntOp, AtenReshapeOp,
TensorStaticInfoCastOp, AtenBroadcastToOp, AtenToDtypeOp,
AtenContiguousOp, AtenPermuteOp, AtenViewOp, AtenExpandOp,
AtenFill_ScalarOp, AtenSliceTensorOp, AtenSelectIntOp,
AtenTOp, AtenBernoulli_FloatOp>(op)) {
// AtenContiguousOp might return a view, so this is conservatively
// correct. We could potentially be more precise and identify the cases
// that it does not return a view and treat those as having value
// semantics.
} else if (isViewLikeOp(op)) {
viewLikeOps.push_back(op);
llvm::append_range(workList, op->getResult(0).getUsers());
} else {

View File

@ -45,6 +45,20 @@ func @multiple_mutations_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor
return %equal_to_arg0, %equal_to_arg1, %equal_to_arg1_again, %equal_to_arg2 : !torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor
}
// CHECK-LABEL: func @mutation_followed_by_view_like_ops(
// CHECK-SAME: %[[VALUE_T:.*]]: !torch.vtensor, %[[OVERWRITER:.*]]: !torch.vtensor, %[[INT_LIST:.*]]: !torch.list<!torch.int>) -> !torch.vtensor {
// CHECK: %[[VIEW:.*]] = torch.aten.view %[[OVERWRITER]], %[[INT_LIST]] : !torch.vtensor, !torch.list<!torch.int> -> !torch.vtensor
// CHECK: %[[RESULT:.*]] = torch.aten.permute %[[VIEW]], %[[INT_LIST]] : !torch.vtensor, !torch.list<!torch.int> -> !torch.vtensor
// CHECK: return %[[RESULT]] : !torch.vtensor
func @mutation_followed_by_view_like_ops(%value_t: !torch.vtensor, %overwriter: !torch.vtensor, %int_list: !torch.list<!torch.int>) -> !torch.vtensor {
%t = torch.copy.to_tensor %value_t : !torch.tensor
torch.overwrite.tensor %overwriter overwrites %t : !torch.vtensor, !torch.tensor
%view = torch.aten.view %t, %int_list : !torch.tensor, !torch.list<!torch.int> -> !torch.tensor
%result = torch.aten.permute %view, %int_list : !torch.tensor, !torch.list<!torch.int> -> !torch.tensor
%value_result = torch.copy.to_vtensor %result : !torch.vtensor
return %value_result : !torch.vtensor
}
// CHECK-LABEL: func @unmodeled_mutation(
// CHECK: torch.overwrite.tensor
func @unmodeled_mutation(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {