mirror of https://github.com/llvm/torch-mlir
Fix MaximizeValueSemantics handling of `torch.overwrite.tensor` op
This commit fixes an issue with the way that the `maximize-value-semantics` pass was handling the `torch.overwrite.tensor` op. Before, the overwrite op would simply get deleted and uses of the overwritten tensor would simply get replaced with the overwriter tensor. This becomes a problem when one tensor has a dynamic shape and the other has a static shape because ops that were previously using the overwritten tensor will now get a tensor of a different type. To fix this, this commit adds a static cast to the overwriter tensor so that it has the type of the overwritten tensor, and this casted tensor is then used to replace the uses of the overwritten tensor.bert-staging
parent
fa6cf0bed8
commit
730f5915bb
|
@ -50,7 +50,25 @@ public:
|
||||||
if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(user)) {
|
if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(user)) {
|
||||||
rewriter.replaceOp(copyToValueTensor, {currentlyHeldValueTensor});
|
rewriter.replaceOp(copyToValueTensor, {currentlyHeldValueTensor});
|
||||||
} else if (auto overwriteTensor = dyn_cast<OverwriteTensorOp>(user)) {
|
} else if (auto overwriteTensor = dyn_cast<OverwriteTensorOp>(user)) {
|
||||||
currentlyHeldValueTensor = overwriteTensor.value();
|
Location loc = user->getLoc();
|
||||||
|
Value overwriter = overwriteTensor.value();
|
||||||
|
Value overwritten = overwriteTensor.overwritten();
|
||||||
|
Type overwrittenType = overwritten.getType()
|
||||||
|
.dyn_cast<NonValueTensorType>()
|
||||||
|
.getWithValueSemantics();
|
||||||
|
|
||||||
|
// Sometimes the overwriter tensor has a different type from the
|
||||||
|
// overwritten tensor. This can happen, for example, if one tensor
|
||||||
|
// has dynamic shape and the other has a static shape. Since the goal
|
||||||
|
// of this pattern is to replace uses of the overwritten tensor with
|
||||||
|
// overwriter tensor, here we cast the overwriter to the type of the
|
||||||
|
// overwritten, to avoid type mismatches later on in the graph.
|
||||||
|
auto savedIP = rewriter.saveInsertionPoint();
|
||||||
|
rewriter.setInsertionPointAfter(overwriteTensor);
|
||||||
|
currentlyHeldValueTensor = rewriter.create<TensorStaticInfoCastOp>(
|
||||||
|
loc, overwrittenType, overwriter);
|
||||||
|
rewriter.restoreInsertionPoint(savedIP);
|
||||||
|
|
||||||
rewriter.eraseOp(overwriteTensor);
|
rewriter.eraseOp(overwriteTensor);
|
||||||
} else {
|
} else {
|
||||||
llvm_unreachable("only those ops supported!");
|
llvm_unreachable("only those ops supported!");
|
||||||
|
|
|
@ -55,6 +55,22 @@ func @unmodeled_mutation(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch
|
||||||
return %result : !torch.vtensor
|
return %result : !torch.vtensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @mutation_with_tensor_of_different_type(
|
||||||
|
// CHECK-SAME: %[[T_0:.*]]: !torch.vtensor<[2],f32>,
|
||||||
|
// CHECK-SAME: %[[T_1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
|
||||||
|
// CHECK: %[[ONE:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[T_0]], %[[T_1]], %[[ONE]] : !torch.vtensor<[2],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32>
|
||||||
|
// CHECK: %[[STATIC_CAST:.*]] = torch.tensor_static_info_cast %[[ADD]] : !torch.vtensor<[?],f32> to !torch.vtensor<[2],f32>
|
||||||
|
// CHECK: return %[[STATIC_CAST]] : !torch.vtensor<[2],f32>
|
||||||
|
func @mutation_with_tensor_of_different_type(%t0: !torch.vtensor<[2],f32>, %t1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
|
||||||
|
%one = torch.constant.int 1
|
||||||
|
%t0_copy = torch.copy.to_tensor %t0 : !torch.tensor<[2],f32>
|
||||||
|
%add = torch.aten.add.Tensor %t0, %t1, %one : !torch.vtensor<[2],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32>
|
||||||
|
torch.overwrite.tensor %add overwrites %t0_copy : !torch.vtensor<[?],f32>, !torch.tensor<[2],f32>
|
||||||
|
%t0_value_copy = torch.copy.to_vtensor %t0_copy : !torch.vtensor<[2],f32>
|
||||||
|
return %t0_value_copy : !torch.vtensor<[2],f32>
|
||||||
|
}
|
||||||
|
|
||||||
// We don't yet handle nontrivial cases involving control flow.
|
// We don't yet handle nontrivial cases involving control flow.
|
||||||
// CHECK-LABEL: func @unimplemented_control_flow(
|
// CHECK-LABEL: func @unimplemented_control_flow(
|
||||||
// CHECK: torch.copy.to_vtensor
|
// CHECK: torch.copy.to_vtensor
|
||||||
|
|
Loading…
Reference in New Issue