mirror of https://github.com/llvm/torch-mlir
Fix handling of view-like ops in `maximize-value-semantics`
This commit adds handling to the `maximize-value-semantics` pass for the case where a view-like op depends on a tensor that has been overwritten by a value tensor. The dependency is removed by changing the input of the view-like op to be a copy of the value tensor that is used for the overwrite. This commit also removes `AtenFill_ScalarOp` and `AtenBernoulli_FloatOp` from the list of view-like ops, since these ops now have a corresponding op with value semantics into which they get converted in the `reduce-op-variants` pass.
bert-staging
parent
730f5915bb
commit
016f1859e0
|
@ -20,6 +20,18 @@ using namespace mlir;
|
|||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
/// Returns true if `op` is one of the ops that this pass classifies as
/// "view-like"; returns false for every other op.
static bool isViewLikeOp(Operation *op) {
  // AtenContiguousOp is included on purpose: it may or may not hand back a
  // view, so classifying it as view-like is the conservative choice. A more
  // precise analysis could recognize the cases where no view is returned and
  // treat those as having value semantics.
  const bool isViewLike =
      isa<AtenBroadcastToOp, AtenContiguousOp, AtenExpandOp,
          AtenFlattenUsingIntsOp, AtenPermuteOp, AtenReshapeOp,
          AtenSelectIntOp, AtenSliceTensorOp, AtenSqueezeDimOp, AtenSqueezeOp,
          AtenTOp, AtenToDtypeOp, AtenTransposeIntOp, AtenUnsqueezeOp,
          AtenViewOp, TensorStaticInfoCastOp>(op);
  return isViewLike;
}
|
||||
|
||||
namespace {
|
||||
class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
|
||||
: public OpRewritePattern<CopyToNonValueTensorOp> {
|
||||
|
@ -28,16 +40,26 @@ public:
|
|||
LogicalResult matchAndRewrite(CopyToNonValueTensorOp copy,
|
||||
PatternRewriter &rewriter) const override {
|
||||
SmallVector<Operation *> users;
|
||||
bool foundNonViewLikeOpUser = false;
|
||||
// See if our limited form of analysis is even applicable.
|
||||
for (Operation *user : copy.getResult().getUsers()) {
|
||||
// We can only analyze within a single basic block.
|
||||
if (user->getBlock() != copy->getBlock())
|
||||
return failure();
|
||||
// We can only analyze these ops.
|
||||
if (!isa<CopyToValueTensorOp, OverwriteTensorOp>(user))
|
||||
// We can only analyze these ops or view-like ops.
|
||||
if (isa<CopyToValueTensorOp, OverwriteTensorOp>(user))
|
||||
foundNonViewLikeOpUser = true;
|
||||
else if (!isViewLikeOp(user))
|
||||
return failure();
|
||||
users.push_back(user);
|
||||
}
|
||||
|
||||
// If all users found are view-like ops, then there is nothing to do
|
||||
// here. The `RewriteViewLikeSubgraph` will take care of turning
|
||||
// these ops into ops with value semantics.
|
||||
if (!foundNonViewLikeOpUser)
|
||||
return failure();
|
||||
|
||||
// Sort by order in the block, so we can abstractly interpret the ops.
|
||||
llvm::sort(users, [](Operation *lhs, Operation *rhs) {
|
||||
return lhs->isBeforeInBlock(rhs);
|
||||
|
@ -70,6 +92,35 @@ public:
|
|||
rewriter.restoreInsertionPoint(savedIP);
|
||||
|
||||
rewriter.eraseOp(overwriteTensor);
|
||||
} else if (isViewLikeOp(user)) {
|
||||
// This case currently only handles view-like ops that have one tensor
|
||||
// input and one tensor output.
|
||||
//
|
||||
// The goal here is to transform view-like ops that depend on an
|
||||
// overwritten tensor into ops that don't, so that the `overwrite` op
|
||||
// can be removed. This is achieved as follows:
|
||||
//
|
||||
// If the view-like op has as input a tensor `T` that has been
|
||||
// overwritten by a value tensor `VT`, then the input to the op
|
||||
// is replaced by a `copy.to_tensor` of `VT`, removing the dependence
|
||||
// on the `overwrite` op.
|
||||
//
|
||||
// If the view-like op has as input a non-value tensor that is a copy
|
||||
// of a value tensor, all that happens is that the input gets replaced
|
||||
// by a new copy of the same value tensor, so there is no net effect
|
||||
// on the goal of maximizing value semantics.
|
||||
Location loc = user->getLoc();
|
||||
Type currentlyHeldValueType = currentlyHeldValueTensor.getType()
|
||||
.dyn_cast<ValueTensorType>()
|
||||
.getWithoutValueSemantics();
|
||||
|
||||
{
|
||||
PatternRewriter::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPoint(user);
|
||||
Value newInput = rewriter.create<CopyToNonValueTensorOp>(
|
||||
loc, currentlyHeldValueType, currentlyHeldValueTensor);
|
||||
user->setOperands(/*start*/0, /*length*/1, {newInput});
|
||||
}
|
||||
} else {
|
||||
llvm_unreachable("only those ops supported!");
|
||||
}
|
||||
|
@ -107,16 +158,7 @@ public:
|
|||
Operation *op = workList.pop_back_val();
|
||||
if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(op)) {
|
||||
copyToValueTensorOps.push_back(copyToValueTensor);
|
||||
} else if (isa<AtenSqueezeOp, AtenSqueezeDimOp, AtenUnsqueezeOp,
|
||||
AtenFlattenUsingIntsOp, AtenTransposeIntOp, AtenReshapeOp,
|
||||
TensorStaticInfoCastOp, AtenBroadcastToOp, AtenToDtypeOp,
|
||||
AtenContiguousOp, AtenPermuteOp, AtenViewOp, AtenExpandOp,
|
||||
AtenFill_ScalarOp, AtenSliceTensorOp, AtenSelectIntOp,
|
||||
AtenTOp, AtenBernoulli_FloatOp>(op)) {
|
||||
// AtenContiguousOp might return a view, so this is conservatively
|
||||
// correct. We could potentially be more precise and identify the cases
|
||||
// that it does not return a view and treat those as having value
|
||||
// semantics.
|
||||
} else if (isViewLikeOp(op)) {
|
||||
viewLikeOps.push_back(op);
|
||||
llvm::append_range(workList, op->getResult(0).getUsers());
|
||||
} else {
|
||||
|
|
|
@ -45,6 +45,20 @@ func @multiple_mutations_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor
|
|||
return %equal_to_arg0, %equal_to_arg1, %equal_to_arg1_again, %equal_to_arg2 : !torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor
|
||||
}
|
||||
|
||||
// Test: a non-value tensor `%t` is overwritten by `%overwriter` and then fed
// through a chain of view-like ops (`torch.aten.view`, `torch.aten.permute`)
// before being copied back to a value tensor. The expected output applies
// `torch.aten.view` directly to the overwriting value tensor and keeps the
// whole chain on `!torch.vtensor` types.
// CHECK-LABEL: func @mutation_followed_by_view_like_ops(
|
||||
// CHECK-SAME: %[[VALUE_T:.*]]: !torch.vtensor, %[[OVERWRITER:.*]]: !torch.vtensor, %[[INT_LIST:.*]]: !torch.list<!torch.int>) -> !torch.vtensor {
|
||||
// CHECK: %[[VIEW:.*]] = torch.aten.view %[[OVERWRITER]], %[[INT_LIST]] : !torch.vtensor, !torch.list<!torch.int> -> !torch.vtensor
|
||||
// CHECK: %[[RESULT:.*]] = torch.aten.permute %[[VIEW]], %[[INT_LIST]] : !torch.vtensor, !torch.list<!torch.int> -> !torch.vtensor
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor
|
||||
func @mutation_followed_by_view_like_ops(%value_t: !torch.vtensor, %overwriter: !torch.vtensor, %int_list: !torch.list<!torch.int>) -> !torch.vtensor {
|
||||
%t = torch.copy.to_tensor %value_t : !torch.tensor
|
||||
torch.overwrite.tensor %overwriter overwrites %t : !torch.vtensor, !torch.tensor
|
||||
%view = torch.aten.view %t, %int_list : !torch.tensor, !torch.list<!torch.int> -> !torch.tensor
|
||||
%result = torch.aten.permute %view, %int_list : !torch.tensor, !torch.list<!torch.int> -> !torch.tensor
|
||||
%value_result = torch.copy.to_vtensor %result : !torch.vtensor
|
||||
return %value_result : !torch.vtensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @unmodeled_mutation(
|
||||
// CHECK: torch.overwrite.tensor
|
||||
func @unmodeled_mutation(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
|
||||
|
|
Loading…
Reference in New Issue