diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td
index 77d9beb15..1ce6d2b86 100644
--- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td
@@ -884,8 +884,10 @@ def Torch_CopyToValueTensorOp : Torch_Op<"copy.to_vtensor", [
   let verifier = "return ::verify(*this);";
 }
 
-def Torch_OverwriteTensorOp : Torch_Op<"overwrite.tensor", [
-    AllowsTypeRefinement
+def Torch_OverwriteTensorContentsOp : Torch_Op<"overwrite.tensor.contents", [
+    TypesMatchWith<"overwritten tensor type is corresponding !torch.tensor of value tensor type",
+                   "value", "overwritten",
+                   "$_self.cast<ValueTensorType>().getWithoutValueSemantics()">
   ]> {
   let summary = "Overwrite the contents of a tensor with values from another.";
   let description = [{
@@ -895,10 +897,12 @@
     Immediately after this op has completed, indexing `overwritten` will
     result in identical values as indexing into `value`. Of course, later ops
     might mutate `overwritten`, so this relationship need not hold for the
-    entire program.
+    entire program. This op only updates the tensor data (not metadata).
+    In other words, it cannot change the (dynamic) shape of the overwritten tensor.
 
-    This op has undefined behavior if the two tensors have different
-    shapes or dtypes.
+    This op does not have the AllowsTypeRefinement trait because the types of the
+    two operands are coupled. Only places that know how to simultaneously update
+    both types should be changing the type of this op.
   }];
   let arguments = (ins
     Torch_ValueTensorType:$value,
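To make the new constraint concrete, here is a hedged sketch (hypothetical SSA values, not part of the patch) of IR that the verifier now accepts and rejects: the overwritten operand must have exactly the !torch.tensor counterpart of the value operand's !torch.vtensor type, not merely a shape/dtype-compatible one.

    // Accepted: !torch.tensor<[2],f32> is the non-value-semantic counterpart
    // of !torch.vtensor<[2],f32>.
    torch.overwrite.tensor.contents %value overwrites %dest : !torch.vtensor<[2],f32>, !torch.tensor<[2],f32>
    // Rejected by the verifier: the operand types do not correspond.
    torch.overwrite.tensor.contents %other overwrites %dest : !torch.vtensor<[?],f32>, !torch.tensor<[2],f32>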
diff --git a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
index 8c237d235..2aacf1d9e 100644
--- a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
+++ b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
@@ -47,7 +47,7 @@ public:
       if (user->getBlock() != copy->getBlock())
         return failure();
       // We can only analyze these ops or view-like ops.
-      if (isa<CopyToValueTensorOp, OverwriteTensorOp>(user))
+      if (isa<CopyToValueTensorOp, OverwriteTensorContentsOp>(user))
        foundNonViewLikeOpUser = true;
       else if (!isViewLikeOp(user))
         return failure();
@@ -71,9 +71,10 @@ public:
     for (Operation *user : users) {
       if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(user)) {
         rewriter.replaceOp(copyToValueTensor, {currentlyHeldValueTensor});
-      } else if (auto overwriteTensor = dyn_cast<OverwriteTensorOp>(user)) {
-        currentlyHeldValueTensor = overwriteTensor.value();
-        rewriter.eraseOp(overwriteTensor);
+      } else if (auto overwriteTensorContents =
+                     dyn_cast<OverwriteTensorContentsOp>(user)) {
+        currentlyHeldValueTensor = overwriteTensorContents.value();
+        rewriter.eraseOp(overwriteTensorContents);
       } else if (isViewLikeOp(user)) {
         // This case currently only handles view-like ops that have one tensor
         // input and one tensor output.
diff --git a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp
index 6a20d455f..79fbc7ad0 100644
--- a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp
+++ b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp
@@ -18,6 +18,24 @@ using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
 
+// Create an overwrite in a manner that preserves the
+// `OverwriteTensorContentsOp` invariant that both arguments
+// must have the same shape and dtype.
+static void createOverwriteTensorContents(PatternRewriter &rewriter,
+                                          Location loc, Value overwriterTensor,
+                                          Value overwrittenTensor) {
+  Type overwriterTensorType = overwriterTensor.getType();
+  Type overwrittenTensorType = overwrittenTensor.getType()
+                                   .dyn_cast<NonValueTensorType>()
+                                   .getWithValueSemantics();
+  if (overwriterTensorType != overwrittenTensorType) {
+    overwriterTensor = rewriter.create<TensorStaticInfoCastOp>(
+        loc, overwrittenTensorType, overwriterTensor);
+  }
+  rewriter.create<OverwriteTensorContentsOp>(loc, overwriterTensor,
+                                             overwrittenTensor);
+}
+
 namespace {
 // Convert value semantic ops operating on mutable arrays to instead operate on
 // immutable tensors.
@@ -143,7 +161,7 @@ public:
 
     auto tensor =
         rewriter.create<CopyToNonValueTensorOp>(loc, newOp->getResult(0));
-    rewriter.create<OverwriteTensorOp>(loc, tensor, op->getOperand(0));
+    createOverwriteTensorContents(rewriter, loc, tensor, op->getOperand(0));
     rewriter.replaceOp(op, op->getOperand(0));
     return success();
   }
@@ -180,7 +198,8 @@ public:
     Operation *newOp = rewriter.createOperation(state);
     auto tensor = rewriter.create<CopyToNonValueTensorOp>(
        op->getLoc(), newOp->getResult(0));
-    rewriter.create<OverwriteTensorOp>(op->getLoc(), tensor, op->getOperand(0));
+    createOverwriteTensorContents(rewriter, op->getLoc(), tensor,
+                                  op->getOperand(0));
     rewriter.replaceOp(op, op->getOperand(0));
 
     return success();
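As a hedged sketch of the IR this helper produces when the operand types disagree (hypothetical values; the cast is skipped entirely when the types already match):

    // The overwriter is first cast to the overwritten tensor's value-semantic
    // counterpart so that the op's type invariant holds:
    %cast = torch.tensor_static_info_cast %overwriter : !torch.vtensor<[2],f32> to !torch.vtensor
    torch.overwrite.tensor.contents %cast overwrites %dest : !torch.vtensor, !torch.tensor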
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 7ebdc35cf..5a31584c8 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -2105,7 +2105,44 @@ void optimize(FuncOp func, TypeAnalyzer &analyzer) {
         if (isSafeToRefineOperandInPlace(use, refinedType)) {
           use->set(newTypedValue);
           continue;
+        } else if (auto overwriteTensorContents =
+                       dyn_cast<OverwriteTensorContentsOp>(
+                           use->getOwner())) {
+          // `OverwriteTensorContentsOp` has special handling here because
+          // it requires that both of its operands always have the same
+          // shape and dtype.
+          //
+          // WARNING: In order to simplify the implementation, the type
+          // used for both operands is the type of the overwritten tensor.
+          // A better way of doing this would be to join the two operand
+          // types to create the most specific type possible and use that
+          // for both arguments, allowing static sizes to always propagate.
+          const unsigned overwriterOperandIndex = 0;
+          const unsigned overwrittenOperandIndex = 1;
+          unsigned operandNumber = use->getOperandNumber();
+          if (operandNumber != overwrittenOperandIndex)
+            continue;
+
+          Location loc = overwriteTensorContents.getLoc();
+          Value overwriterTensor = overwriteTensorContents.value();
+          Type overwriterTensorType = overwriterTensor.getType();
+          Type overwrittenTensorType = newTypedValue.getType()
+                                           .dyn_cast<NonValueTensorType>()
+                                           .getWithValueSemantics();
+          if (overwriterTensorType == overwrittenTensorType)
+            continue;
+
+          {
+            OpBuilder::InsertionGuard guard(b);
+            b.setInsertionPoint(overwriteTensorContents);
+            Value castedOverwriterTensor = b.create<TensorStaticInfoCastOp>(
+                loc, overwrittenTensorType, overwriterTensor);
+            overwriteTensorContents.setOperand(overwriterOperandIndex,
+                                               castedOverwriterTensor);
+          }
+          continue;
         }
+
         // If needed, create a value of the original type to appease users
         // that cannot accept the new type.
         if (!oldTypedValue) {
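In IR terms, this special case amounts to the following rewrite (a hedged sketch mirroring the refine-types tests added below): when the overwritten tensor's type is refined, the overwriter operand is cast to the matching value-semantic type instead of being left with a now-mismatched type.

    // Before: %overwritten has been refined to !torch.tensor<[2],f32>.
    // After: a cast keeps the two operand types coupled.
    %cast = torch.tensor_static_info_cast %overwriter : !torch.vtensor to !torch.vtensor<[2],f32>
    torch.overwrite.tensor.contents %cast overwrites %overwritten : !torch.vtensor<[2],f32>, !torch.tensor<[2],f32>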
diff --git a/test/Dialect/Torch/invalid.mlir b/test/Dialect/Torch/invalid.mlir
index fa63c78c3..f0003eb45 100644
--- a/test/Dialect/Torch/invalid.mlir
+++ b/test/Dialect/Torch/invalid.mlir
@@ -169,3 +169,13 @@ builtin.func @torch.prim.ListConstruct() {
   torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<!torch.int>
   return
 }
+
+// -----
+
+builtin.func @torch.overwrite.tensor.contents(%arg0: !torch.vtensor<[1],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[1],f32> {
+  %0 = torch.copy.to_tensor %arg0 : !torch.tensor<[1],f32>
+  // expected-error@+1 {{'torch.overwrite.tensor.contents' op failed to verify that overwritten tensor type is corresponding !torch.tensor of value tensor type}}
+  torch.overwrite.tensor.contents %arg1 overwrites %0 : !torch.vtensor<[?],f32>, !torch.tensor<[1],f32>
+  %1 = torch.copy.to_vtensor %0 : !torch.vtensor<[1],f32>
+  return %1 : !torch.vtensor<[1],f32>
+}
diff --git a/test/Dialect/Torch/maximize-value-semantics.mlir b/test/Dialect/Torch/maximize-value-semantics.mlir
index 56b4d45e5..2000dc797 100644
--- a/test/Dialect/Torch/maximize-value-semantics.mlir
+++ b/test/Dialect/Torch/maximize-value-semantics.mlir
@@ -17,7 +17,7 @@ func @torch.copy.tensor$basic(%arg0: !torch.vtensor) -> (!torch.vtensor, !torch.
 func @one_mutation_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
   %0 = torch.copy.to_tensor %arg0 : !torch.tensor
   %equal_to_arg0 = torch.copy.to_vtensor %0 : !torch.vtensor
-  torch.overwrite.tensor %arg1 overwrites %0 : !torch.vtensor, !torch.tensor
+  torch.overwrite.tensor.contents %arg1 overwrites %0 : !torch.vtensor, !torch.tensor
   %equal_to_arg1 = torch.copy.to_vtensor %0 : !torch.vtensor
   return %equal_to_arg0, %equal_to_arg1 : !torch.vtensor, !torch.vtensor
 }
@@ -34,12 +34,12 @@ func @multiple_mutations_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor
   %equal_to_arg0 = torch.copy.to_vtensor %tensor : !torch.vtensor
 
   // Overwrite with %arg1
-  torch.overwrite.tensor %arg1 overwrites %tensor : !torch.vtensor, !torch.tensor
+  torch.overwrite.tensor.contents %arg1 overwrites %tensor : !torch.vtensor, !torch.tensor
   %equal_to_arg1 = torch.copy.to_vtensor %tensor : !torch.vtensor
   %equal_to_arg1_again = torch.copy.to_vtensor %tensor : !torch.vtensor
 
   // Overwrite with %arg2
-  torch.overwrite.tensor %arg2 overwrites %tensor : !torch.vtensor, !torch.tensor
+  torch.overwrite.tensor.contents %arg2 overwrites %tensor : !torch.vtensor, !torch.tensor
   %equal_to_arg2 = torch.copy.to_vtensor %tensor : !torch.vtensor
   return %equal_to_arg0, %equal_to_arg1, %equal_to_arg1_again, %equal_to_arg2
       : !torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor
@@ -52,7 +52,7 @@ func @multiple_mutations_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor
 // CHECK:           return %[[RESULT]] : !torch.vtensor
 func @mutation_followed_by_view_like_ops(%value_t: !torch.vtensor, %overwriter: !torch.vtensor, %int_list: !torch.list<!torch.int>) -> !torch.vtensor {
   %t = torch.copy.to_tensor %value_t : !torch.tensor
-  torch.overwrite.tensor %overwriter overwrites %t : !torch.vtensor, !torch.tensor
+  torch.overwrite.tensor.contents %overwriter overwrites %t : !torch.vtensor, !torch.tensor
   %view = torch.aten.view %t, %int_list : !torch.tensor, !torch.list<!torch.int> -> !torch.tensor
   %result = torch.aten.permute %view, %int_list : !torch.tensor, !torch.list<!torch.int> -> !torch.tensor
   %value_result = torch.copy.to_vtensor %result : !torch.vtensor
   return %value_result : !torch.vtensor
 }
 
 // CHECK-LABEL:   func @unmodeled_mutation(
-// CHECK:           torch.overwrite.tensor
+// CHECK:           torch.overwrite.tensor.contents
 func @unmodeled_mutation(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
   %0 = torch.copy.to_tensor %arg0 : !torch.tensor
-  torch.overwrite.tensor %arg1 overwrites %0 : !torch.vtensor, !torch.tensor
+  torch.overwrite.tensor.contents %arg1 overwrites %0 : !torch.vtensor, !torch.tensor
   "some.op"(%0) : (!torch.tensor) -> ()
   %result = torch.copy.to_vtensor %0 : !torch.vtensor
   return %result : !torch.vtensor
@@ -76,7 +76,7 @@ func @unimplemented_control_flow(%arg0: !torch.vtensor, %arg1: !torch.vtensor, %
   %tensor = torch.copy.to_tensor %arg0 : !torch.tensor
   %equal_to_arg0 = torch.copy.to_vtensor %tensor : !torch.vtensor
   torch.prim.If %cond -> () {
-    torch.overwrite.tensor %arg1 overwrites %tensor : !torch.vtensor, !torch.tensor
+    torch.overwrite.tensor.contents %arg1 overwrites %tensor : !torch.vtensor, !torch.tensor
     torch.prim.If.yield
   } else {
     torch.prim.If.yield
diff --git a/test/Dialect/Torch/reduce-op-variants.mlir b/test/Dialect/Torch/reduce-op-variants.mlir
index 136022201..3168941cf 100644
--- a/test/Dialect/Torch/reduce-op-variants.mlir
+++ b/test/Dialect/Torch/reduce-op-variants.mlir
@@ -95,7 +95,7 @@ func @convert_to_value_semantic_tensors_optional(%t: !torch.tensor,
 // being applied in sequence.
 // CHECK:     %[[ARRAY_RESULT:.*]] = torch.copy.to_tensor %[[TENSOR_RESULT]] : !torch.tensor<[2,2],f32>
 // CHECK:     %[[TENSOR_AGAIN:.*]] = torch.copy.to_vtensor %[[ARRAY_RESULT]] : !torch.vtensor<[2,2],f32>
-// CHECK:     torch.overwrite.tensor %[[TENSOR_AGAIN]] overwrites %[[ARG0]] : !torch.vtensor<[2,2],f32>, !torch.tensor<[2,2],f32>
+// CHECK:     torch.overwrite.tensor.contents %[[TENSOR_AGAIN]] overwrites %[[ARG0]] : !torch.vtensor<[2,2],f32>, !torch.tensor<[2,2],f32>
 // CHECK:     return %[[ARG0]], %[[ARG0]] : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>
 func @reduce_trailing_underscore_inplace_variant(%arg0: !torch.tensor<[2,2],f32>, %arg1: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) {
   %c1 = torch.constant.int 1
@@ -138,7 +138,7 @@ func @convert_to_value_semantic_tensors_optional_list(%self: !torch.tensor<[5],f
 // CHECK-SAME:    !torch.vtensor, !torch.float, !torch.float, !torch.none -> !torch.vtensor
 // CHECK:     %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
 // CHECK:     %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
-// CHECK:     torch.overwrite.tensor %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
+// CHECK:     torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
 // CHECK:     return %[[T]] : !torch.tensor
 func @torch.aten.uniform_(%t: !torch.tensor, %min: !torch.float, %max: !torch.float, %generator: !torch.none) -> !torch.tensor {
   %ret = torch.aten.uniform_ %t, %min, %max, %generator: !torch.tensor, !torch.float, !torch.float, !torch.none -> !torch.tensor
@@ -153,7 +153,7 @@ func @torch.aten.uniform_(%t: !torch.tensor, %min: !torch.float, %max: !torch.fl
 // CHECK:     %[[VRET:.*]] = torch.pseudo.aten.bernoulli.float %[[T_VTENSOR]], %[[P]], %[[GENERATOR]] : !torch.vtensor, !torch.float, !torch.none -> !torch.vtensor
 // CHECK:     %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
 // CHECK:     %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
-// CHECK:     torch.overwrite.tensor %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
+// CHECK:     torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
 // CHECK:     return %[[T]] : !torch.tensor
 func @torch.aten.bernoulli_.float(%t: !torch.tensor) -> !torch.tensor {
   %generator = torch.constant.none
@@ -169,7 +169,7 @@ func @torch.aten.bernoulli_.float(%t: !torch.tensor) -> !torch.tensor {
 // CHECK:     %[[VRET:.*]] = torch.pseudo.aten.fill.Scalar %[[T_VTENSOR]], %[[VALUE]] : !torch.vtensor, !torch.int -> !torch.vtensor
 // CHECK:     %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
 // CHECK:     %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
-// CHECK:     torch.overwrite.tensor %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
+// CHECK:     torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
 // CHECK:     return %[[T]] : !torch.tensor
 func @torch.aten.fill_.Scalar(%t: !torch.tensor) -> !torch.tensor {
   %value = torch.constant.int 1
diff --git a/test/Dialect/Torch/refine-types.mlir b/test/Dialect/Torch/refine-types.mlir
index 84df11b2e..625113a3b 100644
--- a/test/Dialect/Torch/refine-types.mlir
+++ b/test/Dialect/Torch/refine-types.mlir
@@ -1157,3 +1157,35 @@ func @torch.aten.BinaryBroadcasting(%arg0: !torch.vtensor<[5,4,3,3,1],f32>, %arg
   %0 = torch.aten.add.Tensor %arg0, %arg1, %arg2: !torch.vtensor<[5,4,3,3,1],f32>, !torch.vtensor<[?,3,1,2],f32>, !torch.int -> !torch.tensor
   return %0 : !torch.tensor
 }
+
+// -----
+// CHECK-LABEL:   func @torch.overwrite.tensor.contents$dynamic_overwrites_static(
+// CHECK-SAME:      %[[STATIC:.*]]: !torch.vtensor<[2],f32>,
+// CHECK-SAME:      %[[DYNAMIC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
+// CHECK:           %[[CAST:.*]] = torch.tensor_static_info_cast %[[DYNAMIC_COPY:.*]] : !torch.vtensor<[?],f32> to !torch.vtensor<[2],f32>
+// CHECK:           torch.overwrite.tensor.contents %[[CAST]] overwrites %[[STATIC_COPY:.*]] : !torch.vtensor<[2],f32>, !torch.tensor<[2],f32>
+func @torch.overwrite.tensor.contents$dynamic_overwrites_static(%static: !torch.vtensor<[2],f32>, %dynamic: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
+  %static_no_type = torch.tensor_static_info_cast %static : !torch.vtensor<[2],f32> to !torch.vtensor
+  %static_copy = torch.copy.to_tensor %static_no_type : !torch.tensor
+  %dynamic_no_type = torch.tensor_static_info_cast %dynamic : !torch.vtensor<[?],f32> to !torch.vtensor
+  torch.overwrite.tensor.contents %dynamic_no_type overwrites %static_copy : !torch.vtensor, !torch.tensor
+  %static_value_copy = torch.copy.to_vtensor %static_copy : !torch.vtensor
+  %result = torch.tensor_static_info_cast %static_value_copy : !torch.vtensor to !torch.vtensor<[2],f32>
+  return %result : !torch.vtensor<[2],f32>
+}
+
+// -----
+// CHECK-LABEL:   func @torch.overwrite.tensor.contents$static_overwrites_dynamic(
+// CHECK-SAME:      %[[STATIC:.*]]: !torch.vtensor<[2],f32>,
+// CHECK-SAME:      %[[DYNAMIC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
+// CHECK:           %[[CAST:.*]] = torch.tensor_static_info_cast %[[STATIC_COPY:.*]] : !torch.vtensor<[2],f32> to !torch.vtensor<[?],f32>
+// CHECK:           torch.overwrite.tensor.contents %[[CAST]] overwrites %[[DYNAMIC_COPY:.*]] : !torch.vtensor<[?],f32>, !torch.tensor<[?],f32>
+func @torch.overwrite.tensor.contents$static_overwrites_dynamic(%static: !torch.vtensor<[2],f32>, %dynamic: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
+  %static_no_type = torch.tensor_static_info_cast %static : !torch.vtensor<[2],f32> to !torch.vtensor
+  %dynamic_no_type = torch.tensor_static_info_cast %dynamic : !torch.vtensor<[?],f32> to !torch.vtensor
+  %dynamic_copy = torch.copy.to_tensor %dynamic_no_type : !torch.tensor
+  torch.overwrite.tensor.contents %static_no_type overwrites %dynamic_copy : !torch.vtensor, !torch.tensor
+  %dynamic_value_copy = torch.copy.to_vtensor %dynamic_copy : !torch.vtensor
+  %result = torch.tensor_static_info_cast %dynamic_value_copy : !torch.vtensor to !torch.vtensor<[?],f32>
+  return %result : !torch.vtensor<[?],f32>
+}
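Taken together, the two passes compose as follows (a hedged sketch based on the @one_mutation_in_a_block test above): reduce-op-variants introduces overwrite ops to model mutation, and maximize-value-semantics then eliminates them by tracking which value tensor the mutable tensor currently holds.

    // Input (mutation modeled with an overwrite):
    %0 = torch.copy.to_tensor %arg0 : !torch.tensor
    %equal_to_arg0 = torch.copy.to_vtensor %0 : !torch.vtensor
    torch.overwrite.tensor.contents %arg1 overwrites %0 : !torch.vtensor, !torch.tensor
    %equal_to_arg1 = torch.copy.to_vtensor %0 : !torch.vtensor
    // After maximize-value-semantics the copies and the overwrite disappear:
    // %equal_to_arg0 is replaced by %arg0 and %equal_to_arg1 by %arg1.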