mirror of https://github.com/llvm/torch-mlir
Fix handling of view-like ops in `maximize-value-semantics`
This commit adds handling to the `maximize-value-semantics` pass for the case where a view-like op depends on a tensor that has been overwritten by a value tensor. The dependency is removed by changing the input of the view-like op to be a copy of the value tensor that is used for the overwrite. This commit also removes `AtenFill_ScalarOp` and `AtenBernoulli_FloatOp` from the list of view-like ops, since these ops now have a corresponding op with value semantics into which they get converted in the `reduce-op-variants` pass.
bert-staging
parent
730f5915bb
commit
016f1859e0
|
@ -20,6 +20,18 @@ using namespace mlir;
|
||||||
using namespace mlir::torch;
|
using namespace mlir::torch;
|
||||||
using namespace mlir::torch::Torch;
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
|
// Returns true when `op` is an instance of one of the op classes named in
// the `isa<...>` check below; those are the ops this file treats as
// "view-like". Any other op yields false.
static bool isViewLikeOp(Operation *op) {
|
||||||
|
// AtenContiguousOp might return a view, so this is conservatively
|
||||||
|
// correct. We could potentially be more precise and identify the cases
|
||||||
|
// that it does not return a view and treat those as having value
|
||||||
|
// semantics.
|
||||||
|
return isa<AtenBroadcastToOp, AtenContiguousOp, AtenExpandOp,
|
||||||
|
AtenFlattenUsingIntsOp, AtenPermuteOp, AtenReshapeOp,
|
||||||
|
AtenSelectIntOp, AtenSliceTensorOp, AtenSqueezeDimOp,
|
||||||
|
AtenSqueezeOp, AtenTOp, AtenToDtypeOp, AtenTransposeIntOp,
|
||||||
|
AtenUnsqueezeOp, AtenViewOp, TensorStaticInfoCastOp>(op);
|
||||||
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
|
class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
|
||||||
: public OpRewritePattern<CopyToNonValueTensorOp> {
|
: public OpRewritePattern<CopyToNonValueTensorOp> {
|
||||||
|
@ -28,16 +40,26 @@ public:
|
||||||
LogicalResult matchAndRewrite(CopyToNonValueTensorOp copy,
|
LogicalResult matchAndRewrite(CopyToNonValueTensorOp copy,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
SmallVector<Operation *> users;
|
SmallVector<Operation *> users;
|
||||||
|
bool foundNonViewLikeOpUser = false;
|
||||||
// See if our limited form of analysis is even applicable.
|
// See if our limited form of analysis is even applicable.
|
||||||
for (Operation *user : copy.getResult().getUsers()) {
|
for (Operation *user : copy.getResult().getUsers()) {
|
||||||
// We can only analyze within a single basic block.
|
// We can only analyze within a single basic block.
|
||||||
if (user->getBlock() != copy->getBlock())
|
if (user->getBlock() != copy->getBlock())
|
||||||
return failure();
|
return failure();
|
||||||
// We can only analyze these ops.
|
// We can only analyze these ops or view-like ops.
|
||||||
if (!isa<CopyToValueTensorOp, OverwriteTensorOp>(user))
|
if (isa<CopyToValueTensorOp, OverwriteTensorOp>(user))
|
||||||
|
foundNonViewLikeOpUser = true;
|
||||||
|
else if (!isViewLikeOp(user))
|
||||||
return failure();
|
return failure();
|
||||||
users.push_back(user);
|
users.push_back(user);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If all users found are view-like ops, then there is nothing to do
|
||||||
|
// here. The `RewriteViewLikeSubgraph` will take care of turning
|
||||||
|
// these ops into ops with value semantics.
|
||||||
|
if (!foundNonViewLikeOpUser)
|
||||||
|
return failure();
|
||||||
|
|
||||||
// Sort by order in the block, so we can abstractly interpret the ops.
|
// Sort by order in the block, so we can abstractly interpret the ops.
|
||||||
llvm::sort(users, [](Operation *lhs, Operation *rhs) {
|
llvm::sort(users, [](Operation *lhs, Operation *rhs) {
|
||||||
return lhs->isBeforeInBlock(rhs);
|
return lhs->isBeforeInBlock(rhs);
|
||||||
|
@ -70,6 +92,35 @@ public:
|
||||||
rewriter.restoreInsertionPoint(savedIP);
|
rewriter.restoreInsertionPoint(savedIP);
|
||||||
|
|
||||||
rewriter.eraseOp(overwriteTensor);
|
rewriter.eraseOp(overwriteTensor);
|
||||||
|
} else if (isViewLikeOp(user)) {
|
||||||
|
// This case currently only handles view-like ops that have one tensor
|
||||||
|
// input and one tensor output.
|
||||||
|
//
|
||||||
|
// The goal here is to transform view-like ops that depend on an
|
||||||
|
// overwritten tensor into ops that don't, so that the `overwrite` op
|
||||||
|
// can be removed. This is achieved as follows:
|
||||||
|
//
|
||||||
|
// If the view-like op has as input a tensor `T` that has been
|
||||||
|
// overwritten by a value tensor `VT`, then the input to the op
|
||||||
|
// is replaced by a `copy.to_tensor` of `VT`, removing the dependence
|
||||||
|
// on the `overwrite` op.
|
||||||
|
//
|
||||||
|
// If the view-like op has as input a non-value tensor that is a copy
|
||||||
|
// of a value tensor, all that happens is that the input gets replaced
|
||||||
|
// by a new copy of the same value tensor, so there is no net effect
|
||||||
|
// on the goal of maximizing value semantics.
|
||||||
|
Location loc = user->getLoc();
|
||||||
|
Type currentlyHeldValueType = currentlyHeldValueTensor.getType()
|
||||||
|
.dyn_cast<ValueTensorType>()
|
||||||
|
.getWithoutValueSemantics();
|
||||||
|
|
||||||
|
{
|
||||||
|
PatternRewriter::InsertionGuard guard(rewriter);
|
||||||
|
rewriter.setInsertionPoint(user);
|
||||||
|
Value newInput = rewriter.create<CopyToNonValueTensorOp>(
|
||||||
|
loc, currentlyHeldValueType, currentlyHeldValueTensor);
|
||||||
|
user->setOperands(/*start*/0, /*length*/1, {newInput});
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
llvm_unreachable("only those ops supported!");
|
llvm_unreachable("only those ops supported!");
|
||||||
}
|
}
|
||||||
|
@ -107,16 +158,7 @@ public:
|
||||||
Operation *op = workList.pop_back_val();
|
Operation *op = workList.pop_back_val();
|
||||||
if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(op)) {
|
if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(op)) {
|
||||||
copyToValueTensorOps.push_back(copyToValueTensor);
|
copyToValueTensorOps.push_back(copyToValueTensor);
|
||||||
} else if (isa<AtenSqueezeOp, AtenSqueezeDimOp, AtenUnsqueezeOp,
|
} else if (isViewLikeOp(op)) {
|
||||||
AtenFlattenUsingIntsOp, AtenTransposeIntOp, AtenReshapeOp,
|
|
||||||
TensorStaticInfoCastOp, AtenBroadcastToOp, AtenToDtypeOp,
|
|
||||||
AtenContiguousOp, AtenPermuteOp, AtenViewOp, AtenExpandOp,
|
|
||||||
AtenFill_ScalarOp, AtenSliceTensorOp, AtenSelectIntOp,
|
|
||||||
AtenTOp, AtenBernoulli_FloatOp>(op)) {
|
|
||||||
// AtenContiguousOp might return a view, so this is conservatively
|
|
||||||
// correct. We could potentially be more precise and identify the cases
|
|
||||||
// that it does not return a view and treat those as having value
|
|
||||||
// semantics.
|
|
||||||
viewLikeOps.push_back(op);
|
viewLikeOps.push_back(op);
|
||||||
llvm::append_range(workList, op->getResult(0).getUsers());
|
llvm::append_range(workList, op->getResult(0).getUsers());
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -45,6 +45,20 @@ func @multiple_mutations_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor
|
||||||
return %equal_to_arg0, %equal_to_arg1, %equal_to_arg1_again, %equal_to_arg2 : !torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor
|
return %equal_to_arg0, %equal_to_arg1, %equal_to_arg1_again, %equal_to_arg2 : !torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test case: the non-value tensor %t is overwritten by %overwriter and is then
// read through a view -> permute chain. The expected output below applies the
// view and permute directly to the overwriting value tensor, with every op
// using value-tensor (!torch.vtensor) types.
// CHECK-LABEL: func @mutation_followed_by_view_like_ops(
|
||||||
|
// CHECK-SAME: %[[VALUE_T:.*]]: !torch.vtensor, %[[OVERWRITER:.*]]: !torch.vtensor, %[[INT_LIST:.*]]: !torch.list<!torch.int>) -> !torch.vtensor {
|
||||||
|
// CHECK: %[[VIEW:.*]] = torch.aten.view %[[OVERWRITER]], %[[INT_LIST]] : !torch.vtensor, !torch.list<!torch.int> -> !torch.vtensor
|
||||||
|
// CHECK: %[[RESULT:.*]] = torch.aten.permute %[[VIEW]], %[[INT_LIST]] : !torch.vtensor, !torch.list<!torch.int> -> !torch.vtensor
|
||||||
|
// CHECK: return %[[RESULT]] : !torch.vtensor
|
||||||
|
func @mutation_followed_by_view_like_ops(%value_t: !torch.vtensor, %overwriter: !torch.vtensor, %int_list: !torch.list<!torch.int>) -> !torch.vtensor {
|
||||||
|
%t = torch.copy.to_tensor %value_t : !torch.tensor
|
||||||
|
torch.overwrite.tensor %overwriter overwrites %t : !torch.vtensor, !torch.tensor
|
||||||
|
%view = torch.aten.view %t, %int_list : !torch.tensor, !torch.list<!torch.int> -> !torch.tensor
|
||||||
|
%result = torch.aten.permute %view, %int_list : !torch.tensor, !torch.list<!torch.int> -> !torch.tensor
|
||||||
|
%value_result = torch.copy.to_vtensor %result : !torch.vtensor
|
||||||
|
return %value_result : !torch.vtensor
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @unmodeled_mutation(
|
// CHECK-LABEL: func @unmodeled_mutation(
|
||||||
// CHECK: torch.overwrite.tensor
|
// CHECK: torch.overwrite.tensor
|
||||||
func @unmodeled_mutation(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
|
func @unmodeled_mutation(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
|
||||||
|
|
Loading…
Reference in New Issue