Add operand type invariant to `torch.overwrite.tensor.contents` (#606)

This commit adds the invariant to the op `torch.overwrite.tensor.contents` that
both of its operands have the same shape and size. In order to
maintain the invariant, special handling of this op is added to the
`RefineTypes` pass.
pull/619/head snapshot-20220222.284
Ramiro Leal-Cavazos 2022-02-22 11:41:46 -08:00 committed by GitHub
parent 5dbace239b
commit ba29d4f250
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 125 additions and 22 deletions

View File

@ -884,8 +884,10 @@ def Torch_CopyToValueTensorOp : Torch_Op<"copy.to_vtensor", [
let verifier = "return ::verify(*this);"; let verifier = "return ::verify(*this);";
} }
def Torch_OverwriteTensorOp : Torch_Op<"overwrite.tensor", [ def Torch_OverwriteTensorContentsOp : Torch_Op<"overwrite.tensor.contents", [
AllowsTypeRefinement TypesMatchWith<"overwritten tensor type is corresponding !torch.tensor of value tensor type",
"value", "overwritten",
"$_self.cast<ValueTensorType>().getWithoutValueSemantics()">
]> { ]> {
let summary = "Ovewrite the contents of tensor with values from another."; let summary = "Ovewrite the contents of tensor with values from another.";
let description = [{ let description = [{
@ -895,10 +897,12 @@ def Torch_OverwriteTensorOp : Torch_Op<"overwrite.tensor", [
Immediately after this op has completed, indexing `overwritten` will result Immediately after this op has completed, indexing `overwritten` will result
in identical values as indexing into `value`. Of course, later ops in identical values as indexing into `value`. Of course, later ops
might mutate `overwritten`, so this relationship need not hold for the might mutate `overwritten`, so this relationship need not hold for the
entire program. entire program. This op only updates the tensor data (not metadata).
In other words, it cannot change the (dynamic) shape of the overwritten tensor.
This op has undefined behavior if the two tensors have different This op does not have the AllowsTypeRefinement trait because the types of the
shapes or dtypes. two operands are coupled. Only places that know how to simultaneously update
both types should be changing the type of this op.
}]; }];
let arguments = (ins let arguments = (ins
Torch_ValueTensorType:$value, Torch_ValueTensorType:$value,

View File

@ -47,7 +47,7 @@ public:
if (user->getBlock() != copy->getBlock()) if (user->getBlock() != copy->getBlock())
return failure(); return failure();
// We can only analyze these ops or view-like ops. // We can only analyze these ops or view-like ops.
if (isa<CopyToValueTensorOp, OverwriteTensorOp>(user)) if (isa<CopyToValueTensorOp, OverwriteTensorContentsOp>(user))
foundNonViewLikeOpUser = true; foundNonViewLikeOpUser = true;
else if (!isViewLikeOp(user)) else if (!isViewLikeOp(user))
return failure(); return failure();
@ -71,9 +71,10 @@ public:
for (Operation *user : users) { for (Operation *user : users) {
if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(user)) { if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(user)) {
rewriter.replaceOp(copyToValueTensor, {currentlyHeldValueTensor}); rewriter.replaceOp(copyToValueTensor, {currentlyHeldValueTensor});
} else if (auto overwriteTensor = dyn_cast<OverwriteTensorOp>(user)) { } else if (auto overwriteTensorContents =
currentlyHeldValueTensor = overwriteTensor.value(); dyn_cast<OverwriteTensorContentsOp>(user)) {
rewriter.eraseOp(overwriteTensor); currentlyHeldValueTensor = overwriteTensorContents.value();
rewriter.eraseOp(overwriteTensorContents);
} else if (isViewLikeOp(user)) { } else if (isViewLikeOp(user)) {
// This case currently only handles view-like ops that have one tensor // This case currently only handles view-like ops that have one tensor
// input and one tensor output. // input and one tensor output.

View File

@ -18,6 +18,24 @@ using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
using namespace mlir::torch::Torch; using namespace mlir::torch::Torch;
// Create an overwrite in a manner that preserves the
// `OverwriteTensorContentsOp` invariant that both arguments
// must have the same shape and dtype.
static void createOverwriteTensorContents(PatternRewriter &rewriter,
Location loc, Value overwriterTensor,
Value overwrittenTensor) {
Type overwriterTensorType = overwriterTensor.getType();
Type overwrittenTensorType = overwrittenTensor.getType()
.dyn_cast<NonValueTensorType>()
.getWithValueSemantics();
if (overwriterTensorType != overwrittenTensorType) {
overwriterTensor = rewriter.create<TensorStaticInfoCastOp>(
loc, overwrittenTensorType, overwriterTensor);
}
rewriter.create<OverwriteTensorContentsOp>(loc, overwriterTensor,
overwrittenTensor);
}
namespace { namespace {
// Convert value semantic ops operating on mutable arrays to instead operate on // Convert value semantic ops operating on mutable arrays to instead operate on
// immutable tensors. // immutable tensors.
@ -143,7 +161,7 @@ public:
auto tensor = auto tensor =
rewriter.create<CopyToValueTensorOp>(loc, newOp->getResult(0)); rewriter.create<CopyToValueTensorOp>(loc, newOp->getResult(0));
rewriter.create<OverwriteTensorOp>(loc, tensor, op->getOperand(0)); createOverwriteTensorContents(rewriter, loc, tensor, op->getOperand(0));
rewriter.replaceOp(op, op->getOperand(0)); rewriter.replaceOp(op, op->getOperand(0));
return success(); return success();
} }
@ -180,7 +198,8 @@ public:
Operation *newOp = rewriter.createOperation(state); Operation *newOp = rewriter.createOperation(state);
auto tensor = auto tensor =
rewriter.create<CopyToValueTensorOp>(op->getLoc(), newOp->getResult(0)); rewriter.create<CopyToValueTensorOp>(op->getLoc(), newOp->getResult(0));
rewriter.create<OverwriteTensorOp>(op->getLoc(), tensor, op->getOperand(0)); createOverwriteTensorContents(rewriter, op->getLoc(), tensor,
op->getOperand(0));
rewriter.replaceOp(op, op->getOperand(0)); rewriter.replaceOp(op, op->getOperand(0));
return success(); return success();

View File

@ -2105,7 +2105,44 @@ void optimize(FuncOp func, TypeAnalyzer &analyzer) {
if (isSafeToRefineOperandInPlace(use, refinedType)) { if (isSafeToRefineOperandInPlace(use, refinedType)) {
use->set(newTypedValue); use->set(newTypedValue);
continue; continue;
} else if (auto overwriteTensorContents =
dyn_cast<OverwriteTensorContentsOp>(
use->getOwner())) {
// `OverwriteTensorContentsOp` has special handling here because
// it requires that both of its operands always have the same
// shape and dtype.
//
// WARNING: In order to simplify the implementation, the type
// used for both operands is the type of the overwritten tensor.
// A better way of doing this would be to join the two operand
// types to create the most specific type possible and use that
// for both arguments, allowing static sizes to always propagate.
const unsigned overwriterOperandIndex = 0;
const unsigned overwrittenOperandIndex = 1;
unsigned operandNumber = use->getOperandNumber();
if (operandNumber != overwrittenOperandIndex)
continue;
Location loc = overwriteTensorContents.getLoc();
Value overwriterTensor = overwriteTensorContents.value();
Type overwriterTensorType = overwriterTensor.getType();
Type overwrittenTensorType = newTypedValue.getType()
.dyn_cast<NonValueTensorType>()
.getWithValueSemantics();
if (overwriterTensorType == overwrittenTensorType)
continue;
{
OpBuilder::InsertionGuard guard(b);
b.setInsertionPoint(overwriteTensorContents);
Value castedOverwriterTensor = b.create<TensorStaticInfoCastOp>(
loc, overwrittenTensorType, overwriterTensor);
overwriteTensorContents.setOperand(overwriterOperandIndex,
castedOverwriterTensor);
}
continue;
} }
// If needed, create a value of the original type to appease users // If needed, create a value of the original type to appease users
// that cannot accept the new type. // that cannot accept the new type.
if (!oldTypedValue) { if (!oldTypedValue) {

View File

@ -169,3 +169,13 @@ builtin.func @torch.prim.ListConstruct() {
torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<!torch.tensor> torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<!torch.tensor>
return return
} }
// -----
builtin.func @torch.overwrite.tensor.contents(%arg0: !torch.vtensor<[1],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[1],f32> {
%0 = torch.copy.to_tensor %arg0 : !torch.tensor<[1],f32>
// expected-error@+1 {{'torch.overwrite.tensor.contents' op failed to verify that overwritten tensor type is corresponding !torch.tensor of value tensor type}}
torch.overwrite.tensor.contents %arg1 overwrites %0 : !torch.vtensor<[?],f32>, !torch.tensor<[1],f32>
%1 = torch.copy.to_vtensor %0 : !torch.vtensor<[1],f32>
return %1 : !torch.vtensor<[1],f32>
}

View File

@ -17,7 +17,7 @@ func @torch.copy.tensor$basic(%arg0: !torch.vtensor) -> (!torch.vtensor, !torch.
func @one_mutation_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) { func @one_mutation_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
%0 = torch.copy.to_tensor %arg0 : !torch.tensor %0 = torch.copy.to_tensor %arg0 : !torch.tensor
%equal_to_arg0 = torch.copy.to_vtensor %0 : !torch.vtensor %equal_to_arg0 = torch.copy.to_vtensor %0 : !torch.vtensor
torch.overwrite.tensor %arg1 overwrites %0 : !torch.vtensor, !torch.tensor torch.overwrite.tensor.contents %arg1 overwrites %0 : !torch.vtensor, !torch.tensor
%equal_to_arg1 = torch.copy.to_vtensor %0 : !torch.vtensor %equal_to_arg1 = torch.copy.to_vtensor %0 : !torch.vtensor
return %equal_to_arg0, %equal_to_arg1 : !torch.vtensor, !torch.vtensor return %equal_to_arg0, %equal_to_arg1 : !torch.vtensor, !torch.vtensor
} }
@ -34,12 +34,12 @@ func @multiple_mutations_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor
%equal_to_arg0 = torch.copy.to_vtensor %tensor : !torch.vtensor %equal_to_arg0 = torch.copy.to_vtensor %tensor : !torch.vtensor
// Overwrite with %arg1 // Overwrite with %arg1
torch.overwrite.tensor %arg1 overwrites %tensor : !torch.vtensor, !torch.tensor torch.overwrite.tensor.contents %arg1 overwrites %tensor : !torch.vtensor, !torch.tensor
%equal_to_arg1 = torch.copy.to_vtensor %tensor : !torch.vtensor %equal_to_arg1 = torch.copy.to_vtensor %tensor : !torch.vtensor
%equal_to_arg1_again = torch.copy.to_vtensor %tensor : !torch.vtensor %equal_to_arg1_again = torch.copy.to_vtensor %tensor : !torch.vtensor
// Overwrite with %arg2 // Overwrite with %arg2
torch.overwrite.tensor %arg2 overwrites %tensor : !torch.vtensor, !torch.tensor torch.overwrite.tensor.contents %arg2 overwrites %tensor : !torch.vtensor, !torch.tensor
%equal_to_arg2 = torch.copy.to_vtensor %tensor : !torch.vtensor %equal_to_arg2 = torch.copy.to_vtensor %tensor : !torch.vtensor
return %equal_to_arg0, %equal_to_arg1, %equal_to_arg1_again, %equal_to_arg2 : !torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor return %equal_to_arg0, %equal_to_arg1, %equal_to_arg1_again, %equal_to_arg2 : !torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor
@ -52,7 +52,7 @@ func @multiple_mutations_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor
// CHECK: return %[[RESULT]] : !torch.vtensor // CHECK: return %[[RESULT]] : !torch.vtensor
func @mutation_followed_by_view_like_ops(%value_t: !torch.vtensor, %overwriter: !torch.vtensor, %int_list: !torch.list<!torch.int>) -> !torch.vtensor { func @mutation_followed_by_view_like_ops(%value_t: !torch.vtensor, %overwriter: !torch.vtensor, %int_list: !torch.list<!torch.int>) -> !torch.vtensor {
%t = torch.copy.to_tensor %value_t : !torch.tensor %t = torch.copy.to_tensor %value_t : !torch.tensor
torch.overwrite.tensor %overwriter overwrites %t : !torch.vtensor, !torch.tensor torch.overwrite.tensor.contents %overwriter overwrites %t : !torch.vtensor, !torch.tensor
%view = torch.aten.view %t, %int_list : !torch.tensor, !torch.list<!torch.int> -> !torch.tensor %view = torch.aten.view %t, %int_list : !torch.tensor, !torch.list<!torch.int> -> !torch.tensor
%result = torch.aten.permute %view, %int_list : !torch.tensor, !torch.list<!torch.int> -> !torch.tensor %result = torch.aten.permute %view, %int_list : !torch.tensor, !torch.list<!torch.int> -> !torch.tensor
%value_result = torch.copy.to_vtensor %result : !torch.vtensor %value_result = torch.copy.to_vtensor %result : !torch.vtensor
@ -60,10 +60,10 @@ func @mutation_followed_by_view_like_ops(%value_t: !torch.vtensor, %overwriter:
} }
// CHECK-LABEL: func @unmodeled_mutation( // CHECK-LABEL: func @unmodeled_mutation(
// CHECK: torch.overwrite.tensor // CHECK: torch.overwrite.tensor.contents
func @unmodeled_mutation(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor { func @unmodeled_mutation(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
%0 = torch.copy.to_tensor %arg0 : !torch.tensor %0 = torch.copy.to_tensor %arg0 : !torch.tensor
torch.overwrite.tensor %arg1 overwrites %0 : !torch.vtensor, !torch.tensor torch.overwrite.tensor.contents %arg1 overwrites %0 : !torch.vtensor, !torch.tensor
"some.op"(%0) : (!torch.tensor) -> () "some.op"(%0) : (!torch.tensor) -> ()
%result = torch.copy.to_vtensor %0 : !torch.vtensor %result = torch.copy.to_vtensor %0 : !torch.vtensor
return %result : !torch.vtensor return %result : !torch.vtensor
@ -76,7 +76,7 @@ func @unimplemented_control_flow(%arg0: !torch.vtensor, %arg1: !torch.vtensor, %
%tensor = torch.copy.to_tensor %arg0 : !torch.tensor %tensor = torch.copy.to_tensor %arg0 : !torch.tensor
%equal_to_arg0 = torch.copy.to_vtensor %tensor : !torch.vtensor %equal_to_arg0 = torch.copy.to_vtensor %tensor : !torch.vtensor
torch.prim.If %cond -> () { torch.prim.If %cond -> () {
torch.overwrite.tensor %arg1 overwrites %tensor : !torch.vtensor, !torch.tensor torch.overwrite.tensor.contents %arg1 overwrites %tensor : !torch.vtensor, !torch.tensor
torch.prim.If.yield torch.prim.If.yield
} else { } else {
torch.prim.If.yield torch.prim.If.yield

View File

@ -95,7 +95,7 @@ func @convert_to_value_semantic_tensors_optional(%t: !torch.tensor,
// being applied in sequence. // being applied in sequence.
// CHECK: %[[ARRAY_RESULT:.*]] = torch.copy.to_tensor %[[TENSOR_RESULT]] : !torch.tensor<[2,2],f32> // CHECK: %[[ARRAY_RESULT:.*]] = torch.copy.to_tensor %[[TENSOR_RESULT]] : !torch.tensor<[2,2],f32>
// CHECK: %[[TENSOR_AGAIN:.*]] = torch.copy.to_vtensor %[[ARRAY_RESULT]] : !torch.vtensor<[2,2],f32> // CHECK: %[[TENSOR_AGAIN:.*]] = torch.copy.to_vtensor %[[ARRAY_RESULT]] : !torch.vtensor<[2,2],f32>
// CHECK: torch.overwrite.tensor %[[TENSOR_AGAIN]] overwrites %[[ARG0]] : !torch.vtensor<[2,2],f32>, !torch.tensor<[2,2],f32> // CHECK: torch.overwrite.tensor.contents %[[TENSOR_AGAIN]] overwrites %[[ARG0]] : !torch.vtensor<[2,2],f32>, !torch.tensor<[2,2],f32>
// CHECK: return %[[ARG0]], %[[ARG0]] : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32> // CHECK: return %[[ARG0]], %[[ARG0]] : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>
func @reduce_trailing_underscore_inplace_variant(%arg0: !torch.tensor<[2,2],f32>, %arg1: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) { func @reduce_trailing_underscore_inplace_variant(%arg0: !torch.tensor<[2,2],f32>, %arg1: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) {
%c1 = torch.constant.int 1 %c1 = torch.constant.int 1
@ -138,7 +138,7 @@ func @convert_to_value_semantic_tensors_optional_list(%self: !torch.tensor<[5],f
// CHECK-SAME: !torch.vtensor, !torch.float, !torch.float, !torch.none -> !torch.vtensor // CHECK-SAME: !torch.vtensor, !torch.float, !torch.float, !torch.none -> !torch.vtensor
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor // CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor // CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
// CHECK: torch.overwrite.tensor %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor // CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
// CHECK: return %[[T]] : !torch.tensor // CHECK: return %[[T]] : !torch.tensor
func @torch.aten.uniform_(%t: !torch.tensor, %min: !torch.float, %max: !torch.float, %generator: !torch.none) -> !torch.tensor { func @torch.aten.uniform_(%t: !torch.tensor, %min: !torch.float, %max: !torch.float, %generator: !torch.none) -> !torch.tensor {
%ret = torch.aten.uniform_ %t, %min, %max, %generator: !torch.tensor, !torch.float, !torch.float, !torch.none -> !torch.tensor %ret = torch.aten.uniform_ %t, %min, %max, %generator: !torch.tensor, !torch.float, !torch.float, !torch.none -> !torch.tensor
@ -153,7 +153,7 @@ func @torch.aten.uniform_(%t: !torch.tensor, %min: !torch.float, %max: !torch.fl
// CHECK: %[[VRET:.*]] = torch.pseudo.aten.bernoulli.float %[[T_VTENSOR]], %[[P]], %[[GENERATOR]] : !torch.vtensor, !torch.float, !torch.none -> !torch.vtensor // CHECK: %[[VRET:.*]] = torch.pseudo.aten.bernoulli.float %[[T_VTENSOR]], %[[P]], %[[GENERATOR]] : !torch.vtensor, !torch.float, !torch.none -> !torch.vtensor
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor // CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor // CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
// CHECK: torch.overwrite.tensor %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor // CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
// CHECK: return %[[T]] : !torch.tensor // CHECK: return %[[T]] : !torch.tensor
func @torch.aten.bernoulli_.float(%t: !torch.tensor) -> !torch.tensor { func @torch.aten.bernoulli_.float(%t: !torch.tensor) -> !torch.tensor {
%generator = torch.constant.none %generator = torch.constant.none
@ -169,7 +169,7 @@ func @torch.aten.bernoulli_.float(%t: !torch.tensor) -> !torch.tensor {
// CHECK: %[[VRET:.*]] = torch.pseudo.aten.fill.Scalar %[[T_VTENSOR]], %[[VALUE]] : !torch.vtensor, !torch.int -> !torch.vtensor // CHECK: %[[VRET:.*]] = torch.pseudo.aten.fill.Scalar %[[T_VTENSOR]], %[[VALUE]] : !torch.vtensor, !torch.int -> !torch.vtensor
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor // CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor // CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
// CHECK: torch.overwrite.tensor %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor // CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
// CHECK: return %[[T]] : !torch.tensor // CHECK: return %[[T]] : !torch.tensor
func @torch.aten.fill_.Scalar(%t: !torch.tensor) -> !torch.tensor { func @torch.aten.fill_.Scalar(%t: !torch.tensor) -> !torch.tensor {
%value = torch.constant.int 1 %value = torch.constant.int 1

View File

@ -1157,3 +1157,35 @@ func @torch.aten.BinaryBroadcasting(%arg0: !torch.vtensor<[5,4,3,3,1],f32>, %arg
%0 = torch.aten.add.Tensor %arg0, %arg1, %arg2: !torch.vtensor<[5,4,3,3,1],f32>, !torch.vtensor<[?,3,1,2],f32>, !torch.int -> !torch.tensor %0 = torch.aten.add.Tensor %arg0, %arg1, %arg2: !torch.vtensor<[5,4,3,3,1],f32>, !torch.vtensor<[?,3,1,2],f32>, !torch.int -> !torch.tensor
return %0 : !torch.tensor return %0 : !torch.tensor
} }
// -----
// CHECK-LABEL: func @torch.overwrite.tensor.contents$dynamic_overwrites_static(
// CHECK-SAME: %[[STATIC:.*]]: !torch.vtensor<[2],f32>,
// CHECK-SAME: %[[DYNAMIC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[DYNAMIC_COPY:.*]] : !torch.vtensor<[?],f32> to !torch.vtensor<[2],f32>
// CHECK: torch.overwrite.tensor.contents %[[CAST]] overwrites %[[STATIC_COPY:.*]] : !torch.vtensor<[2],f32>, !torch.tensor<[2],f32>
func @torch.overwrite.tensor.contents$dynamic_overwrites_static(%static: !torch.vtensor<[2],f32>, %dynamic: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
%static_no_type = torch.tensor_static_info_cast %static : !torch.vtensor<[2],f32> to !torch.vtensor
%static_copy = torch.copy.to_tensor %static_no_type : !torch.tensor
%dynamic_no_type = torch.tensor_static_info_cast %dynamic : !torch.vtensor<[?],f32> to !torch.vtensor
torch.overwrite.tensor.contents %dynamic_no_type overwrites %static_copy : !torch.vtensor, !torch.tensor
%static_value_copy = torch.copy.to_vtensor %static_copy : !torch.vtensor
%result = torch.tensor_static_info_cast %static_value_copy : !torch.vtensor to !torch.vtensor<[2],f32>
return %result : !torch.vtensor<[2],f32>
}
// -----
// CHECK-LABEL: func @torch.overwrite.tensor.contents$static_overwrites_dynamic(
// CHECK-SAME: %[[STATIC:.*]]: !torch.vtensor<[2],f32>,
// CHECK-SAME: %[[DYNAMIC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[STATIC_COPY:.*]] : !torch.vtensor<[2],f32> to !torch.vtensor<[?],f32>
// CHECK: torch.overwrite.tensor.contents %[[CAST]] overwrites %[[DYNAMIC_COPY:.*]] : !torch.vtensor<[?],f32>, !torch.tensor<[?],f32>
func @torch.overwrite.tensor.contents$static_overwrites_dynamic(%static: !torch.vtensor<[2],f32>, %dynamic: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
%static_no_type = torch.tensor_static_info_cast %static : !torch.vtensor<[2],f32> to !torch.vtensor
%dynamic_no_type = torch.tensor_static_info_cast %dynamic : !torch.vtensor<[?],f32> to !torch.vtensor
%dynamic_copy = torch.copy.to_tensor %dynamic_no_type : !torch.tensor
torch.overwrite.tensor.contents %static_no_type overwrites %dynamic_copy : !torch.vtensor, !torch.tensor
%dynamic_value_copy = torch.copy.to_vtensor %dynamic_copy : !torch.vtensor
%result = torch.tensor_static_info_cast %dynamic_value_copy : !torch.vtensor to !torch.vtensor<[?],f32>
return %result : !torch.vtensor<[?],f32>
}