diff --git a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
index 03eded344..554ef263f 100644
--- a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
+++ b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
@@ -50,7 +50,25 @@ public:
       if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(user)) {
         rewriter.replaceOp(copyToValueTensor, {currentlyHeldValueTensor});
       } else if (auto overwriteTensor = dyn_cast<OverwriteTensorOp>(user)) {
-        currentlyHeldValueTensor = overwriteTensor.value();
+        Location loc = user->getLoc();
+        Value overwriter = overwriteTensor.value();
+        Value overwritten = overwriteTensor.overwritten();
+        Type overwrittenType = overwritten.getType()
+                                   .dyn_cast<NonValueTensorType>()
+                                   .getWithValueSemantics();
+
+        // Sometimes the overwriter tensor has a different type from the
+        // overwritten tensor. This can happen, for example, if one tensor
+        // has a dynamic shape and the other has a static shape. Since the
+        // goal of this pattern is to replace uses of the overwritten tensor
+        // with the overwriter tensor, we cast the overwriter to the type of
+        // the overwritten tensor to avoid type mismatches later in the graph.
+        auto savedIP = rewriter.saveInsertionPoint();
+        rewriter.setInsertionPointAfter(overwriteTensor);
+        currentlyHeldValueTensor = rewriter.create<TensorStaticInfoCastOp>(
+            loc, overwrittenType, overwriter);
+        rewriter.restoreInsertionPoint(savedIP);
+
         rewriter.eraseOp(overwriteTensor);
       } else {
         llvm_unreachable("only those ops supported!");
diff --git a/test/Dialect/Torch/maximize-value-semantics.mlir b/test/Dialect/Torch/maximize-value-semantics.mlir
index fb489cdfb..151944b42 100644
--- a/test/Dialect/Torch/maximize-value-semantics.mlir
+++ b/test/Dialect/Torch/maximize-value-semantics.mlir
@@ -55,6 +55,22 @@ func @unmodeled_mutation(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch
   return %result : !torch.vtensor
 }
 
+// CHECK-LABEL: func @mutation_with_tensor_of_different_type(
+// CHECK-SAME:    %[[T_0:.*]]: !torch.vtensor<[2],f32>,
+// CHECK-SAME:    %[[T_1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
+// CHECK:         %[[ONE:.*]] = torch.constant.int 1
+// CHECK:         %[[ADD:.*]] = torch.aten.add.Tensor %[[T_0]], %[[T_1]], %[[ONE]] : !torch.vtensor<[2],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32>
+// CHECK:         %[[STATIC_CAST:.*]] = torch.tensor_static_info_cast %[[ADD]] : !torch.vtensor<[?],f32> to !torch.vtensor<[2],f32>
+// CHECK:         return %[[STATIC_CAST]] : !torch.vtensor<[2],f32>
+func @mutation_with_tensor_of_different_type(%t0: !torch.vtensor<[2],f32>, %t1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
+  %one = torch.constant.int 1
+  %t0_copy = torch.copy.to_tensor %t0 : !torch.tensor<[2],f32>
+  %add = torch.aten.add.Tensor %t0, %t1, %one : !torch.vtensor<[2],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32>
+  torch.overwrite.tensor %add overwrites %t0_copy : !torch.vtensor<[?],f32>, !torch.tensor<[2],f32>
+  %t0_value_copy = torch.copy.to_vtensor %t0_copy : !torch.vtensor<[2],f32>
+  return %t0_value_copy : !torch.vtensor<[2],f32>
+}
+
 // We don't yet handle nontrivial cases involving control flow.
 // CHECK-LABEL: func @unimplemented_control_flow(
 // CHECK: torch.copy.to_vtensor
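For illustration only (not part of the patch): the IR below is a sketch of what the pass is expected to produce for the new test case, reconstructed from the CHECK lines above. SSA value names are placeholders, and it assumes the file's usual RUN invocation of torch-mlir-opt with -torch-maximize-value-semantics.

  func @mutation_with_tensor_of_different_type(%t0: !torch.vtensor<[2],f32>, %t1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
    %int1 = torch.constant.int 1
    // The add now operates purely on value tensors; the torch.copy.to_tensor,
    // torch.overwrite.tensor, and torch.copy.to_vtensor ops are gone.
    %add = torch.aten.add.Tensor %t0, %t1, %int1 : !torch.vtensor<[2],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32>
    // The cast inserted by the new code reconciles the dynamically shaped
    // result with the statically shaped type of the overwritten tensor.
    %cast = torch.tensor_static_info_cast %add : !torch.vtensor<[?],f32> to !torch.vtensor<[2],f32>
    return %cast : !torch.vtensor<[2],f32>
  }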