From 730f5915bbdf00646ae12a75dcc163a1b99bf7fd Mon Sep 17 00:00:00 2001
From: Ramiro Leal-Cavazos
Date: Wed, 16 Feb 2022 23:34:45 +0000
Subject: [PATCH] Fix MaximizeValueSemantics handling of `torch.overwrite.tensor` op

This commit fixes an issue with the way the `maximize-value-semantics`
pass handled the `torch.overwrite.tensor` op. Previously, the overwrite
op was simply deleted, and uses of the overwritten tensor were replaced
with the overwriter tensor. This becomes a problem when one tensor has a
dynamic shape and the other has a static shape, because ops that
previously used the overwritten tensor now receive a tensor of a
different type. To fix this, this commit inserts a static cast from the
overwriter tensor to the type of the overwritten tensor, and uses the
casted tensor to replace the uses of the overwritten tensor.
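
As an illustration (this mirrors the new test case added below), given

  %add = torch.aten.add.Tensor %t0, %t1, %one
      : !torch.vtensor<[2],f32>, !torch.vtensor<[?],f32>, !torch.int
      -> !torch.vtensor<[?],f32>
  torch.overwrite.tensor %add overwrites %t0_copy
      : !torch.vtensor<[?],f32>, !torch.tensor<[2],f32>

the pass used to hand the `!torch.vtensor<[?],f32>` result of the add
directly to ops expecting the overwritten `!torch.vtensor<[2],f32>`
value. With this change, a

  %cast = torch.tensor_static_info_cast %add
      : !torch.vtensor<[?],f32> to !torch.vtensor<[2],f32>

is inserted after the overwrite, and `%cast` replaces those uses
instead.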
---
 .../Transforms/MaximizeValueSemantics.cpp     | 20 ++++++++++++++++++-
 .../Torch/maximize-value-semantics.mlir       | 16 +++++++++++++++
 2 files changed, 35 insertions(+), 1 deletion(-)

diff --git a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
index 03eded344..554ef263f 100644
--- a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
+++ b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
@@ -50,7 +50,25 @@ public:
       if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(user)) {
         rewriter.replaceOp(copyToValueTensor, {currentlyHeldValueTensor});
       } else if (auto overwriteTensor = dyn_cast<OverwriteTensorOp>(user)) {
-        currentlyHeldValueTensor = overwriteTensor.value();
+        Location loc = user->getLoc();
+        Value overwriter = overwriteTensor.value();
+        Value overwritten = overwriteTensor.overwritten();
+        Type overwrittenType = overwritten.getType()
+                                   .dyn_cast<NonValueTensorType>()
+                                   .getWithValueSemantics();
+
+        // Sometimes the overwriter tensor has a different type from the
+        // overwritten tensor. This can happen, for example, if one tensor has
+        // a dynamic shape and the other has a static shape. Since the goal of
+        // this pattern is to replace uses of the overwritten tensor with the
+        // overwriter tensor, here we cast the overwriter to the type of the
+        // overwritten tensor, to avoid type mismatches later on in the graph.
+        auto savedIP = rewriter.saveInsertionPoint();
+        rewriter.setInsertionPointAfter(overwriteTensor);
+        currentlyHeldValueTensor = rewriter.create<TensorStaticInfoCastOp>(
+            loc, overwrittenType, overwriter);
+        rewriter.restoreInsertionPoint(savedIP);
+        rewriter.eraseOp(overwriteTensor);
       } else {
         llvm_unreachable("only those ops supported!");
       }
diff --git a/test/Dialect/Torch/maximize-value-semantics.mlir b/test/Dialect/Torch/maximize-value-semantics.mlir
index fb489cdfb..151944b42 100644
--- a/test/Dialect/Torch/maximize-value-semantics.mlir
+++ b/test/Dialect/Torch/maximize-value-semantics.mlir
@@ -55,6 +55,22 @@ func @unmodeled_mutation(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch
   return %result : !torch.vtensor
 }
 
+// CHECK-LABEL:   func @mutation_with_tensor_of_different_type(
+// CHECK-SAME:        %[[T_0:.*]]: !torch.vtensor<[2],f32>,
+// CHECK-SAME:        %[[T_1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
+// CHECK:           %[[ONE:.*]] = torch.constant.int 1
+// CHECK:           %[[ADD:.*]] = torch.aten.add.Tensor %[[T_0]], %[[T_1]], %[[ONE]] : !torch.vtensor<[2],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32>
+// CHECK:           %[[STATIC_CAST:.*]] = torch.tensor_static_info_cast %[[ADD]] : !torch.vtensor<[?],f32> to !torch.vtensor<[2],f32>
+// CHECK:           return %[[STATIC_CAST]] : !torch.vtensor<[2],f32>
+func @mutation_with_tensor_of_different_type(%t0: !torch.vtensor<[2],f32>, %t1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
+  %one = torch.constant.int 1
+  %t0_copy = torch.copy.to_tensor %t0 : !torch.tensor<[2],f32>
+  %add = torch.aten.add.Tensor %t0, %t1, %one : !torch.vtensor<[2],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32>
+  torch.overwrite.tensor %add overwrites %t0_copy : !torch.vtensor<[?],f32>, !torch.tensor<[2],f32>
+  %t0_value_copy = torch.copy.to_vtensor %t0_copy : !torch.vtensor<[2],f32>
+  return %t0_value_copy : !torch.vtensor<[2],f32>
+}
+
 // We don't yet handle nontrivial cases involving control flow.
 // CHECK-LABEL:   func @unimplemented_control_flow(
 // CHECK:           torch.copy.to_vtensor
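
For reference, assuming the pass is registered under the
`-torch-maximize-value-semantics` flag (as in the in-tree pass
registration; adjust if your checkout differs), the new test case can
be exercised manually with:

  torch-mlir-opt -torch-maximize-value-semantics \
      test/Dialect/Torch/maximize-value-semantics.mlir \
      | FileCheck test/Dialect/Torch/maximize-value-semantics.mlir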