Fix MaximizeValueSemantics handling of `torch.overwrite.tensor` op

This commit fixes how the `maximize-value-semantics` pass handles the
`torch.overwrite.tensor` op. Previously, the overwrite op was simply
deleted and all uses of the overwritten tensor were replaced with the
overwriter tensor. This is a problem when one tensor has a dynamic
shape and the other has a static shape, because ops that previously
used the overwritten tensor now receive a tensor of a different type.

To fix this, this commit inserts a `torch.tensor_static_info_cast` op
that casts the overwriter tensor to the type of the overwritten tensor,
and the result of that cast is then used to replace the uses of the
overwritten tensor.
bert-staging
Ramiro Leal-Cavazos 2022-02-16 23:34:45 +00:00
parent fa6cf0bed8
commit 730f5915bb
2 changed files with 35 additions and 1 deletions

View File

@ -50,7 +50,25 @@ public:
if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(user)) {
rewriter.replaceOp(copyToValueTensor, {currentlyHeldValueTensor});
} else if (auto overwriteTensor = dyn_cast<OverwriteTensorOp>(user)) {
currentlyHeldValueTensor = overwriteTensor.value();
Location loc = user->getLoc();
Value overwriter = overwriteTensor.value();
Value overwritten = overwriteTensor.overwritten();
Type overwrittenType = overwritten.getType()
.dyn_cast<NonValueTensorType>()
.getWithValueSemantics();
// Sometimes the overwriter tensor has a different type from the
// overwritten tensor. This can happen, for example, if one tensor
// has a dynamic shape and the other has a static shape. Since the goal
// of this pattern is to replace uses of the overwritten tensor with the
// overwriter tensor, here we cast the overwriter to the type of the
// overwritten tensor, to avoid type mismatches later on in the graph.
auto savedIP = rewriter.saveInsertionPoint();
rewriter.setInsertionPointAfter(overwriteTensor);
currentlyHeldValueTensor = rewriter.create<TensorStaticInfoCastOp>(
loc, overwrittenType, overwriter);
rewriter.restoreInsertionPoint(savedIP);
rewriter.eraseOp(overwriteTensor);
} else {
llvm_unreachable("only those ops supported!");

View File

@ -55,6 +55,22 @@ func @unmodeled_mutation(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch
return %result : !torch.vtensor
}
// Overwriting a tensor with a value of a different type: here a
// statically shaped `[2]` tensor is overwritten by a dynamically shaped `[?]`
// value. The expected output keeps the function's `[2]` result type by
// returning a `torch.tensor_static_info_cast` of the add result, and no longer
// contains the copy and overwrite ops.
// CHECK-LABEL: func @mutation_with_tensor_of_different_type(
// CHECK-SAME: %[[T_0:.*]]: !torch.vtensor<[2],f32>,
// CHECK-SAME: %[[T_1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
// CHECK: %[[ONE:.*]] = torch.constant.int 1
// CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[T_0]], %[[T_1]], %[[ONE]] : !torch.vtensor<[2],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32>
// CHECK: %[[STATIC_CAST:.*]] = torch.tensor_static_info_cast %[[ADD]] : !torch.vtensor<[?],f32> to !torch.vtensor<[2],f32>
// CHECK: return %[[STATIC_CAST]] : !torch.vtensor<[2],f32>
func @mutation_with_tensor_of_different_type(%t0: !torch.vtensor<[2],f32>, %t1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
%one = torch.constant.int 1
%t0_copy = torch.copy.to_tensor %t0 : !torch.tensor<[2],f32>
%add = torch.aten.add.Tensor %t0, %t1, %one : !torch.vtensor<[2],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32>
torch.overwrite.tensor %add overwrites %t0_copy : !torch.vtensor<[?],f32>, !torch.tensor<[2],f32>
%t0_value_copy = torch.copy.to_vtensor %t0_copy : !torch.vtensor<[2],f32>
return %t0_value_copy : !torch.vtensor<[2],f32>
}
// We don't yet handle nontrivial cases involving control flow.
// CHECK-LABEL: func @unimplemented_control_flow(
// CHECK: torch.copy.to_vtensor