Add operand type invariant to `torch.overwrite.tensor.contents` (#606)

This commit adds the invariant to the op `torch.overwrite.tensor.contents` that
both of its operands have the same shape and dtype. In order to
maintain the invariant, special handling of this op is added to the
`RefineTypes` pass.
pull/619/head snapshot-20220222.284
Ramiro Leal-Cavazos 2022-02-22 11:41:46 -08:00 committed by GitHub
parent 5dbace239b
commit ba29d4f250
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 125 additions and 22 deletions

View File

@ -884,8 +884,10 @@ def Torch_CopyToValueTensorOp : Torch_Op<"copy.to_vtensor", [
let verifier = "return ::verify(*this);";
}
def Torch_OverwriteTensorOp : Torch_Op<"overwrite.tensor", [
AllowsTypeRefinement
def Torch_OverwriteTensorContentsOp : Torch_Op<"overwrite.tensor.contents", [
TypesMatchWith<"overwritten tensor type is corresponding !torch.tensor of value tensor type",
"value", "overwritten",
"$_self.cast<ValueTensorType>().getWithoutValueSemantics()">
]> {
let summary = "Overwrite the contents of tensor with values from another.";
let description = [{
@ -895,10 +897,12 @@ def Torch_OverwriteTensorOp : Torch_Op<"overwrite.tensor", [
Immediately after this op has completed, indexing `overwritten` will result
in identical values as indexing into `value`. Of course, later ops
might mutate `overwritten`, so this relationship need not hold for the
entire program.
entire program. This op only updates the tensor data (not metadata).
In other words, it cannot change the (dynamic) shape of the overwritten tensor.
This op has undefined behavior if the two tensors have different
shapes or dtypes.
This op does not have the AllowsTypeRefinement trait because the types of the
two operands are coupled. Only places that know how to simultaneously update
both types should be changing the type of this op.
}];
let arguments = (ins
Torch_ValueTensorType:$value,

View File

@ -47,7 +47,7 @@ public:
if (user->getBlock() != copy->getBlock())
return failure();
// We can only analyze these ops or view-like ops.
if (isa<CopyToValueTensorOp, OverwriteTensorOp>(user))
if (isa<CopyToValueTensorOp, OverwriteTensorContentsOp>(user))
foundNonViewLikeOpUser = true;
else if (!isViewLikeOp(user))
return failure();
@ -71,9 +71,10 @@ public:
for (Operation *user : users) {
if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(user)) {
rewriter.replaceOp(copyToValueTensor, {currentlyHeldValueTensor});
} else if (auto overwriteTensor = dyn_cast<OverwriteTensorOp>(user)) {
currentlyHeldValueTensor = overwriteTensor.value();
rewriter.eraseOp(overwriteTensor);
} else if (auto overwriteTensorContents =
dyn_cast<OverwriteTensorContentsOp>(user)) {
currentlyHeldValueTensor = overwriteTensorContents.value();
rewriter.eraseOp(overwriteTensorContents);
} else if (isViewLikeOp(user)) {
// This case currently only handles view-like ops that have one tensor
// input and one tensor output.

View File

@ -18,6 +18,24 @@ using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
// Create an overwrite in a manner that preserves the
// `OverwriteTensorContentsOp` invariant that both arguments
// must have the same shape and dtype.
//
// `overwriterTensor` is the value tensor supplying the new contents;
// `overwrittenTensor` is the non-value (!torch.tensor) destination. If the
// overwriter's type does not already equal the value-semantic counterpart of
// the overwritten tensor's type, a `TensorStaticInfoCastOp` is inserted to
// reconcile the static shape/dtype info before creating the overwrite.
static void createOverwriteTensorContents(PatternRewriter &rewriter,
                                          Location loc, Value overwriterTensor,
                                          Value overwrittenTensor) {
  Type overwriterTensorType = overwriterTensor.getType();
  // The overwritten tensor is required to be a `NonValueTensorType`; use
  // `cast` rather than `dyn_cast` so a violated precondition asserts loudly
  // instead of invoking a method on a null `Type` (undefined behavior).
  Type overwrittenTensorType = overwrittenTensor.getType()
                                   .cast<NonValueTensorType>()
                                   .getWithValueSemantics();
  if (overwriterTensorType != overwrittenTensorType) {
    overwriterTensor = rewriter.create<TensorStaticInfoCastOp>(
        loc, overwrittenTensorType, overwriterTensor);
  }
  rewriter.create<OverwriteTensorContentsOp>(loc, overwriterTensor,
                                             overwrittenTensor);
}
namespace {
// Convert value semantic ops operating on mutable arrays to instead operate on
// immutable tensors.
@ -143,7 +161,7 @@ public:
auto tensor =
rewriter.create<CopyToValueTensorOp>(loc, newOp->getResult(0));
rewriter.create<OverwriteTensorOp>(loc, tensor, op->getOperand(0));
createOverwriteTensorContents(rewriter, loc, tensor, op->getOperand(0));
rewriter.replaceOp(op, op->getOperand(0));
return success();
}
@ -180,7 +198,8 @@ public:
Operation *newOp = rewriter.createOperation(state);
auto tensor =
rewriter.create<CopyToValueTensorOp>(op->getLoc(), newOp->getResult(0));
rewriter.create<OverwriteTensorOp>(op->getLoc(), tensor, op->getOperand(0));
createOverwriteTensorContents(rewriter, op->getLoc(), tensor,
op->getOperand(0));
rewriter.replaceOp(op, op->getOperand(0));
return success();

View File

@ -2105,7 +2105,44 @@ void optimize(FuncOp func, TypeAnalyzer &analyzer) {
if (isSafeToRefineOperandInPlace(use, refinedType)) {
use->set(newTypedValue);
continue;
} else if (auto overwriteTensorContents =
dyn_cast<OverwriteTensorContentsOp>(
use->getOwner())) {
// `OverwriteTensorContentsOp` has special handling here because
// it requires that both of its operands always have the same
// shape and dtype.
//
// WARNING: In order to simplify the implementation, the type
// used for both operands is the type of the overwritten tensor.
// A better way of doing this would be to join the two operand
// types to create the most specific type possible and use that
// for both arguments, allowing static sizes to always propagate.
const unsigned overwriterOperandIndex = 0;
const unsigned overwrittenOperandIndex = 1;
unsigned operandNumber = use->getOperandNumber();
if (operandNumber != overwrittenOperandIndex)
continue;
Location loc = overwriteTensorContents.getLoc();
Value overwriterTensor = overwriteTensorContents.value();
Type overwriterTensorType = overwriterTensor.getType();
Type overwrittenTensorType = newTypedValue.getType()
.dyn_cast<NonValueTensorType>()
.getWithValueSemantics();
if (overwriterTensorType == overwrittenTensorType)
continue;
{
OpBuilder::InsertionGuard guard(b);
b.setInsertionPoint(overwriteTensorContents);
Value castedOverwriterTensor = b.create<TensorStaticInfoCastOp>(
loc, overwrittenTensorType, overwriterTensor);
overwriteTensorContents.setOperand(overwriterOperandIndex,
castedOverwriterTensor);
}
continue;
}
// If needed, create a value of the original type to appease users
// that cannot accept the new type.
if (!oldTypedValue) {

View File

@ -169,3 +169,13 @@ builtin.func @torch.prim.ListConstruct() {
torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<!torch.tensor>
return
}
// -----
// Verifier test: `torch.overwrite.tensor.contents` requires the overwritten
// operand's type to be the corresponding !torch.tensor of the value operand's
// !torch.vtensor type. Here the value is <[?],f32> but the overwritten tensor
// is <[1],f32>, so the op's TypesMatchWith verifier must reject it.
builtin.func @torch.overwrite.tensor.contents(%arg0: !torch.vtensor<[1],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[1],f32> {
%0 = torch.copy.to_tensor %arg0 : !torch.tensor<[1],f32>
// expected-error@+1 {{'torch.overwrite.tensor.contents' op failed to verify that overwritten tensor type is corresponding !torch.tensor of value tensor type}}
torch.overwrite.tensor.contents %arg1 overwrites %0 : !torch.vtensor<[?],f32>, !torch.tensor<[1],f32>
%1 = torch.copy.to_vtensor %0 : !torch.vtensor<[1],f32>
return %1 : !torch.vtensor<[1],f32>
}

View File

@ -17,7 +17,7 @@ func @torch.copy.tensor$basic(%arg0: !torch.vtensor) -> (!torch.vtensor, !torch.
func @one_mutation_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
%0 = torch.copy.to_tensor %arg0 : !torch.tensor
%equal_to_arg0 = torch.copy.to_vtensor %0 : !torch.vtensor
torch.overwrite.tensor %arg1 overwrites %0 : !torch.vtensor, !torch.tensor
torch.overwrite.tensor.contents %arg1 overwrites %0 : !torch.vtensor, !torch.tensor
%equal_to_arg1 = torch.copy.to_vtensor %0 : !torch.vtensor
return %equal_to_arg0, %equal_to_arg1 : !torch.vtensor, !torch.vtensor
}
@ -34,12 +34,12 @@ func @multiple_mutations_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor
%equal_to_arg0 = torch.copy.to_vtensor %tensor : !torch.vtensor
// Overwrite with %arg1
torch.overwrite.tensor %arg1 overwrites %tensor : !torch.vtensor, !torch.tensor
torch.overwrite.tensor.contents %arg1 overwrites %tensor : !torch.vtensor, !torch.tensor
%equal_to_arg1 = torch.copy.to_vtensor %tensor : !torch.vtensor
%equal_to_arg1_again = torch.copy.to_vtensor %tensor : !torch.vtensor
// Overwrite with %arg2
torch.overwrite.tensor %arg2 overwrites %tensor : !torch.vtensor, !torch.tensor
torch.overwrite.tensor.contents %arg2 overwrites %tensor : !torch.vtensor, !torch.tensor
%equal_to_arg2 = torch.copy.to_vtensor %tensor : !torch.vtensor
return %equal_to_arg0, %equal_to_arg1, %equal_to_arg1_again, %equal_to_arg2 : !torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor
@ -52,7 +52,7 @@ func @multiple_mutations_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor
// CHECK: return %[[RESULT]] : !torch.vtensor
func @mutation_followed_by_view_like_ops(%value_t: !torch.vtensor, %overwriter: !torch.vtensor, %int_list: !torch.list<!torch.int>) -> !torch.vtensor {
%t = torch.copy.to_tensor %value_t : !torch.tensor
torch.overwrite.tensor %overwriter overwrites %t : !torch.vtensor, !torch.tensor
torch.overwrite.tensor.contents %overwriter overwrites %t : !torch.vtensor, !torch.tensor
%view = torch.aten.view %t, %int_list : !torch.tensor, !torch.list<!torch.int> -> !torch.tensor
%result = torch.aten.permute %view, %int_list : !torch.tensor, !torch.list<!torch.int> -> !torch.tensor
%value_result = torch.copy.to_vtensor %result : !torch.vtensor
@ -60,10 +60,10 @@ func @mutation_followed_by_view_like_ops(%value_t: !torch.vtensor, %overwriter:
}
// CHECK-LABEL: func @unmodeled_mutation(
// CHECK: torch.overwrite.tensor
// CHECK: torch.overwrite.tensor.contents
func @unmodeled_mutation(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
%0 = torch.copy.to_tensor %arg0 : !torch.tensor
torch.overwrite.tensor %arg1 overwrites %0 : !torch.vtensor, !torch.tensor
torch.overwrite.tensor.contents %arg1 overwrites %0 : !torch.vtensor, !torch.tensor
"some.op"(%0) : (!torch.tensor) -> ()
%result = torch.copy.to_vtensor %0 : !torch.vtensor
return %result : !torch.vtensor
@ -76,7 +76,7 @@ func @unimplemented_control_flow(%arg0: !torch.vtensor, %arg1: !torch.vtensor, %
%tensor = torch.copy.to_tensor %arg0 : !torch.tensor
%equal_to_arg0 = torch.copy.to_vtensor %tensor : !torch.vtensor
torch.prim.If %cond -> () {
torch.overwrite.tensor %arg1 overwrites %tensor : !torch.vtensor, !torch.tensor
torch.overwrite.tensor.contents %arg1 overwrites %tensor : !torch.vtensor, !torch.tensor
torch.prim.If.yield
} else {
torch.prim.If.yield

View File

@ -95,7 +95,7 @@ func @convert_to_value_semantic_tensors_optional(%t: !torch.tensor,
// being applied in sequence.
// CHECK: %[[ARRAY_RESULT:.*]] = torch.copy.to_tensor %[[TENSOR_RESULT]] : !torch.tensor<[2,2],f32>
// CHECK: %[[TENSOR_AGAIN:.*]] = torch.copy.to_vtensor %[[ARRAY_RESULT]] : !torch.vtensor<[2,2],f32>
// CHECK: torch.overwrite.tensor %[[TENSOR_AGAIN]] overwrites %[[ARG0]] : !torch.vtensor<[2,2],f32>, !torch.tensor<[2,2],f32>
// CHECK: torch.overwrite.tensor.contents %[[TENSOR_AGAIN]] overwrites %[[ARG0]] : !torch.vtensor<[2,2],f32>, !torch.tensor<[2,2],f32>
// CHECK: return %[[ARG0]], %[[ARG0]] : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>
func @reduce_trailing_underscore_inplace_variant(%arg0: !torch.tensor<[2,2],f32>, %arg1: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) {
%c1 = torch.constant.int 1
@ -138,7 +138,7 @@ func @convert_to_value_semantic_tensors_optional_list(%self: !torch.tensor<[5],f
// CHECK-SAME: !torch.vtensor, !torch.float, !torch.float, !torch.none -> !torch.vtensor
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
// CHECK: torch.overwrite.tensor %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
// CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
// CHECK: return %[[T]] : !torch.tensor
func @torch.aten.uniform_(%t: !torch.tensor, %min: !torch.float, %max: !torch.float, %generator: !torch.none) -> !torch.tensor {
%ret = torch.aten.uniform_ %t, %min, %max, %generator: !torch.tensor, !torch.float, !torch.float, !torch.none -> !torch.tensor
@ -153,7 +153,7 @@ func @torch.aten.uniform_(%t: !torch.tensor, %min: !torch.float, %max: !torch.fl
// CHECK: %[[VRET:.*]] = torch.pseudo.aten.bernoulli.float %[[T_VTENSOR]], %[[P]], %[[GENERATOR]] : !torch.vtensor, !torch.float, !torch.none -> !torch.vtensor
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
// CHECK: torch.overwrite.tensor %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
// CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
// CHECK: return %[[T]] : !torch.tensor
func @torch.aten.bernoulli_.float(%t: !torch.tensor) -> !torch.tensor {
%generator = torch.constant.none
@ -169,7 +169,7 @@ func @torch.aten.bernoulli_.float(%t: !torch.tensor) -> !torch.tensor {
// CHECK: %[[VRET:.*]] = torch.pseudo.aten.fill.Scalar %[[T_VTENSOR]], %[[VALUE]] : !torch.vtensor, !torch.int -> !torch.vtensor
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
// CHECK: torch.overwrite.tensor %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
// CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
// CHECK: return %[[T]] : !torch.tensor
func @torch.aten.fill_.Scalar(%t: !torch.tensor) -> !torch.tensor {
%value = torch.constant.int 1

View File

@ -1157,3 +1157,35 @@ func @torch.aten.BinaryBroadcasting(%arg0: !torch.vtensor<[5,4,3,3,1],f32>, %arg
%0 = torch.aten.add.Tensor %arg0, %arg1, %arg2: !torch.vtensor<[5,4,3,3,1],f32>, !torch.vtensor<[?,3,1,2],f32>, !torch.int -> !torch.tensor
return %0 : !torch.tensor
}
// -----
// RefineTypes test: when a dynamically-shaped value overwrites a
// statically-shaped tensor, the pass must insert a tensor_static_info_cast of
// the overwriter (<[?],f32> -> <[2],f32>) so both operands of the overwrite
// keep matching shape/dtype, as the op's invariant requires.
// CHECK-LABEL: func @torch.overwrite.tensor.contents$dynamic_overwrites_static(
// CHECK-SAME: %[[STATIC:.*]]: !torch.vtensor<[2],f32>,
// CHECK-SAME: %[[DYNAMIC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[DYNAMIC_COPY:.*]] : !torch.vtensor<[?],f32> to !torch.vtensor<[2],f32>
// CHECK: torch.overwrite.tensor.contents %[[CAST]] overwrites %[[STATIC_COPY:.*]] : !torch.vtensor<[2],f32>, !torch.tensor<[2],f32>
func @torch.overwrite.tensor.contents$dynamic_overwrites_static(%static: !torch.vtensor<[2],f32>, %dynamic: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
%static_no_type = torch.tensor_static_info_cast %static : !torch.vtensor<[2],f32> to !torch.vtensor
%static_copy = torch.copy.to_tensor %static_no_type : !torch.tensor
%dynamic_no_type = torch.tensor_static_info_cast %dynamic : !torch.vtensor<[?],f32> to !torch.vtensor
torch.overwrite.tensor.contents %dynamic_no_type overwrites %static_copy : !torch.vtensor, !torch.tensor
%static_value_copy = torch.copy.to_vtensor %static_copy : !torch.vtensor
%result = torch.tensor_static_info_cast %static_value_copy : !torch.vtensor to !torch.vtensor<[2],f32>
return %result : !torch.vtensor<[2],f32>
}
// -----
// RefineTypes test: when a statically-shaped value overwrites a
// dynamically-shaped tensor, the overwriter is cast to the overwritten
// tensor's dynamic type (<[2],f32> -> <[?],f32>) — the pass uses the
// overwritten tensor's type for both operands, so static info is not
// propagated in this direction.
// CHECK-LABEL: func @torch.overwrite.tensor.contents$static_overwrites_dynamic(
// CHECK-SAME: %[[STATIC:.*]]: !torch.vtensor<[2],f32>,
// CHECK-SAME: %[[DYNAMIC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[STATIC_COPY:.*]] : !torch.vtensor<[2],f32> to !torch.vtensor<[?],f32>
// CHECK: torch.overwrite.tensor.contents %[[CAST]] overwrites %[[DYNAMIC_COPY:.*]] : !torch.vtensor<[?],f32>, !torch.tensor<[?],f32>
func @torch.overwrite.tensor.contents$static_overwrites_dynamic(%static: !torch.vtensor<[2],f32>, %dynamic: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
%static_no_type = torch.tensor_static_info_cast %static : !torch.vtensor<[2],f32> to !torch.vtensor
%dynamic_no_type = torch.tensor_static_info_cast %dynamic : !torch.vtensor<[?],f32> to !torch.vtensor
%dynamic_copy = torch.copy.to_tensor %dynamic_no_type : !torch.tensor
torch.overwrite.tensor.contents %static_no_type overwrites %dynamic_copy : !torch.vtensor, !torch.tensor
%dynamic_value_copy = torch.copy.to_vtensor %dynamic_copy : !torch.vtensor
%result = torch.tensor_static_info_cast %dynamic_value_copy : !torch.vtensor to !torch.vtensor<[?],f32>
return %result : !torch.vtensor<[?],f32>
}