mirror of https://github.com/llvm/torch-mlir
Add operand type invariant to `torch.overwrite.tensor.contents` (#606)
This commit adds the invariant to the op `torch.overwrite.tensor.contents` that both of its operands have the same shape and dtype. In order to maintain the invariant, special handling of this op is added to the `RefineTypes` pass.
parent 5dbace239b
commit ba29d4f250
@@ -884,8 +884,10 @@ def Torch_CopyToValueTensorOp : Torch_Op<"copy.to_vtensor", [
   let verifier = "return ::verify(*this);";
 }
 
-def Torch_OverwriteTensorOp : Torch_Op<"overwrite.tensor", [
-    AllowsTypeRefinement
+def Torch_OverwriteTensorContentsOp : Torch_Op<"overwrite.tensor.contents", [
+    TypesMatchWith<"overwritten tensor type is corresponding !torch.tensor of value tensor type",
+                   "value", "overwritten",
+                   "$_self.cast<ValueTensorType>().getWithoutValueSemantics()">
   ]> {
   let summary = "Ovewrite the contents of tensor with values from another.";
   let description = [{
@@ -895,10 +897,12 @@ def Torch_OverwriteTensorOp : Torch_Op<"overwrite.tensor", [
     Immediately after this op has completed, indexing `overwritten` will result
     in identical values as indexing into `value`. Of course, later ops
     might mutate `overwritten`, so this relationship need not hold for the
-    entire program.
+    entire program. This op only updates the tensor data (not metadata).
+    In other words, it cannot change the (dynamic) shape of the overwritten tensor.
 
-    This op has undefined behavior if the two tensors have different
-    shapes or dtypes.
+    This op does not have the AllowsTypeRefinement trait because the types of the
+    two operands are coupled. Only places that know how to simultaneously update
+    both types should be changing the type of this op.
   }];
   let arguments = (ins
     Torch_ValueTensorType:$value,
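With the new TypesMatchWith constraint, the overwritten `!torch.tensor` type must be exactly the value operand's `!torch.vtensor` type with value semantics stripped. A minimal well-formed use would look like the following (illustrative names, not part of this patch):

func @illustrative_use(%value: !torch.vtensor<[2,3],f32>, %seed: !torch.vtensor<[2,3],f32>) {
  // A mutable tensor whose type is the value operand's type minus value semantics.
  %overwritten = torch.copy.to_tensor %seed : !torch.tensor<[2,3],f32>
  // Legal: !torch.tensor<[2,3],f32> corresponds to !torch.vtensor<[2,3],f32>.
  torch.overwrite.tensor.contents %value overwrites %overwritten : !torch.vtensor<[2,3],f32>, !torch.tensor<[2,3],f32>
  return
}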
@@ -47,7 +47,7 @@ public:
       if (user->getBlock() != copy->getBlock())
         return failure();
       // We can only analyze these ops or view-like ops.
-      if (isa<CopyToValueTensorOp, OverwriteTensorOp>(user))
+      if (isa<CopyToValueTensorOp, OverwriteTensorContentsOp>(user))
         foundNonViewLikeOpUser = true;
       else if (!isViewLikeOp(user))
         return failure();
@@ -71,9 +71,10 @@ public:
     for (Operation *user : users) {
       if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(user)) {
         rewriter.replaceOp(copyToValueTensor, {currentlyHeldValueTensor});
-      } else if (auto overwriteTensor = dyn_cast<OverwriteTensorOp>(user)) {
-        currentlyHeldValueTensor = overwriteTensor.value();
-        rewriter.eraseOp(overwriteTensor);
+      } else if (auto overwriteTensorContents =
+                     dyn_cast<OverwriteTensorContentsOp>(user)) {
+        currentlyHeldValueTensor = overwriteTensorContents.value();
+        rewriter.eraseOp(overwriteTensorContents);
       } else if (isViewLikeOp(user)) {
         // This case currently only handles view-like ops that have one tensor
         // input and one tensor output.
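The rewrite above treats an overwrite as a redefinition of the value currently held by the mutable tensor: the op is erased and later reads observe the overwriter. A sketch of the intended transformation, mirroring the `@one_mutation_in_a_block` test updated later in this diff (names are illustrative):

func @illustrative_one_mutation(%a: !torch.vtensor, %b: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
  %t = torch.copy.to_tensor %a : !torch.tensor
  %before = torch.copy.to_vtensor %t : !torch.vtensor
  torch.overwrite.tensor.contents %b overwrites %t : !torch.vtensor, !torch.tensor
  %after = torch.copy.to_vtensor %t : !torch.vtensor
  return %before, %after : !torch.vtensor, !torch.vtensor
}
// Expected result of maximize-value-semantics: the copies and the overwrite
// fold away, leaving roughly `return %a, %b : !torch.vtensor, !torch.vtensor`.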
@@ -18,6 +18,24 @@ using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
 
+// Create an overwrite in a manner that preserves the
+// `OverwriteTensorContentsOp` invariant that both arguments
+// must have the same shape and dtype.
+static void createOverwriteTensorContents(PatternRewriter &rewriter,
+                                          Location loc, Value overwriterTensor,
+                                          Value overwrittenTensor) {
+  Type overwriterTensorType = overwriterTensor.getType();
+  Type overwrittenTensorType = overwrittenTensor.getType()
+                                   .dyn_cast<NonValueTensorType>()
+                                   .getWithValueSemantics();
+  if (overwriterTensorType != overwrittenTensorType) {
+    overwriterTensor = rewriter.create<TensorStaticInfoCastOp>(
+        loc, overwrittenTensorType, overwriterTensor);
+  }
+  rewriter.create<OverwriteTensorContentsOp>(loc, overwriterTensor,
+                                             overwrittenTensor);
+}
+
 namespace {
 // Convert value semantic ops operating on mutable arrays to instead operate on
 // immutable tensors.
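The helper above inserts a `torch.tensor_static_info_cast` whenever the overwriter's type does not already match the overwritten tensor's type. For a trailing-underscore (in-place) op, the rewrites below produce IR roughly like the following, condensed from the reduce-op-variants.mlir CHECK lines later in this diff (names are illustrative):

// Value-semantic computation on immutable copies of the operands:
%vresult = torch.aten.add.Tensor %lhs_vtensor, %rhs_vtensor, %alpha : !torch.vtensor, !torch.vtensor, !torch.int -> !torch.vtensor
%result_tensor = torch.copy.to_tensor %vresult : !torch.tensor
%result_vtensor = torch.copy.to_vtensor %result_tensor : !torch.vtensor
// The helper emits the overwrite (preceded by a torch.tensor_static_info_cast if the types differ):
torch.overwrite.tensor.contents %result_vtensor overwrites %lhs_tensor : !torch.vtensor, !torch.tensor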
@@ -143,7 +161,7 @@ public:
 
     auto tensor =
         rewriter.create<CopyToValueTensorOp>(loc, newOp->getResult(0));
-    rewriter.create<OverwriteTensorOp>(loc, tensor, op->getOperand(0));
+    createOverwriteTensorContents(rewriter, loc, tensor, op->getOperand(0));
     rewriter.replaceOp(op, op->getOperand(0));
     return success();
   }
@@ -180,7 +198,8 @@ public:
     Operation *newOp = rewriter.createOperation(state);
     auto tensor =
         rewriter.create<CopyToValueTensorOp>(op->getLoc(), newOp->getResult(0));
-    rewriter.create<OverwriteTensorOp>(op->getLoc(), tensor, op->getOperand(0));
+    createOverwriteTensorContents(rewriter, op->getLoc(), tensor,
+                                  op->getOperand(0));
     rewriter.replaceOp(op, op->getOperand(0));
 
     return success();
@@ -2105,7 +2105,44 @@ void optimize(FuncOp func, TypeAnalyzer &analyzer) {
         if (isSafeToRefineOperandInPlace(use, refinedType)) {
           use->set(newTypedValue);
           continue;
+        } else if (auto overwriteTensorContents =
+                       dyn_cast<OverwriteTensorContentsOp>(
+                           use->getOwner())) {
+          // `OverwriteTensorContentsOp` has special handling here because
+          // it requires that both of its operands always have the same
+          // shape and dtype.
+          //
+          // WARNING: In order to simplify the implementation, the type
+          // used for both operands is the type of the overwritten tensor.
+          // A better way of doing this would be to join the two operand
+          // types to create the most specific type possible and use that
+          // for both arguments, allowing static sizes to always propagate.
+          const unsigned overwriterOperandIndex = 0;
+          const unsigned overwrittenOperandIndex = 1;
+          unsigned operandNumber = use->getOperandNumber();
+          if (operandNumber != overwrittenOperandIndex)
+            continue;
+
+          Location loc = overwriteTensorContents.getLoc();
+          Value overwriterTensor = overwriteTensorContents.value();
+          Type overwriterTensorType = overwriterTensor.getType();
+          Type overwrittenTensorType = newTypedValue.getType()
+                                           .dyn_cast<NonValueTensorType>()
+                                           .getWithValueSemantics();
+          if (overwriterTensorType == overwrittenTensorType)
+            continue;
+
+          {
+            OpBuilder::InsertionGuard guard(b);
+            b.setInsertionPoint(overwriteTensorContents);
+            Value castedOverwriterTensor = b.create<TensorStaticInfoCastOp>(
+                loc, overwrittenTensorType, overwriterTensor);
+            overwriteTensorContents.setOperand(overwriterOperandIndex,
+                                               castedOverwriterTensor);
+          }
+          continue;
         }
+
         // If needed, create a value of the original type to appease users
         // that cannot accept the new type.
         if (!oldTypedValue) {
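This special handling is exercised by the new refine-types.mlir tests at the end of this diff: when the overwritten tensor's type is refined, the value operand is routed through a `torch.tensor_static_info_cast` so the two operand types stay in sync. A condensed sketch of that effect (illustrative names and types; see the full tests below):

// Input (types erased, so the invariant is trivially satisfied):
//   torch.overwrite.tensor.contents %value overwrites %dest : !torch.vtensor, !torch.tensor
// After RefineTypes refines %dest to !torch.tensor<[2],f32> and %value to !torch.vtensor<[?],f32>,
// an inserted cast keeps the operand types matching:
%cast = torch.tensor_static_info_cast %value : !torch.vtensor<[?],f32> to !torch.vtensor<[2],f32>
torch.overwrite.tensor.contents %cast overwrites %dest : !torch.vtensor<[2],f32>, !torch.tensor<[2],f32>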
@@ -169,3 +169,13 @@ builtin.func @torch.prim.ListConstruct() {
   torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<!torch.tensor>
   return
 }
+
+// -----
+
+builtin.func @torch.overwrite.tensor.contents(%arg0: !torch.vtensor<[1],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[1],f32> {
+  %0 = torch.copy.to_tensor %arg0 : !torch.tensor<[1],f32>
+  // expected-error@+1 {{'torch.overwrite.tensor.contents' op failed to verify that overwritten tensor type is corresponding !torch.tensor of value tensor type}}
+  torch.overwrite.tensor.contents %arg1 overwrites %0 : !torch.vtensor<[?],f32>, !torch.tensor<[1],f32>
+  %1 = torch.copy.to_vtensor %0 : !torch.vtensor<[1],f32>
+  return %1 : !torch.vtensor<[1],f32>
+}
@@ -17,7 +17,7 @@ func @torch.copy.tensor$basic(%arg0: !torch.vtensor) -> (!torch.vtensor, !torch.
 func @one_mutation_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
   %0 = torch.copy.to_tensor %arg0 : !torch.tensor
   %equal_to_arg0 = torch.copy.to_vtensor %0 : !torch.vtensor
-  torch.overwrite.tensor %arg1 overwrites %0 : !torch.vtensor, !torch.tensor
+  torch.overwrite.tensor.contents %arg1 overwrites %0 : !torch.vtensor, !torch.tensor
   %equal_to_arg1 = torch.copy.to_vtensor %0 : !torch.vtensor
   return %equal_to_arg0, %equal_to_arg1 : !torch.vtensor, !torch.vtensor
 }
@@ -34,12 +34,12 @@ func @multiple_mutations_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor
   %equal_to_arg0 = torch.copy.to_vtensor %tensor : !torch.vtensor
 
   // Overwrite with %arg1
-  torch.overwrite.tensor %arg1 overwrites %tensor : !torch.vtensor, !torch.tensor
+  torch.overwrite.tensor.contents %arg1 overwrites %tensor : !torch.vtensor, !torch.tensor
   %equal_to_arg1 = torch.copy.to_vtensor %tensor : !torch.vtensor
   %equal_to_arg1_again = torch.copy.to_vtensor %tensor : !torch.vtensor
 
   // Overwrite with %arg2
-  torch.overwrite.tensor %arg2 overwrites %tensor : !torch.vtensor, !torch.tensor
+  torch.overwrite.tensor.contents %arg2 overwrites %tensor : !torch.vtensor, !torch.tensor
   %equal_to_arg2 = torch.copy.to_vtensor %tensor : !torch.vtensor
 
   return %equal_to_arg0, %equal_to_arg1, %equal_to_arg1_again, %equal_to_arg2 : !torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.vtensor
@@ -52,7 +52,7 @@ func @multiple_mutations_in_a_block(%arg0: !torch.vtensor, %arg1: !torch.vtensor
 // CHECK: return %[[RESULT]] : !torch.vtensor
 func @mutation_followed_by_view_like_ops(%value_t: !torch.vtensor, %overwriter: !torch.vtensor, %int_list: !torch.list<!torch.int>) -> !torch.vtensor {
   %t = torch.copy.to_tensor %value_t : !torch.tensor
-  torch.overwrite.tensor %overwriter overwrites %t : !torch.vtensor, !torch.tensor
+  torch.overwrite.tensor.contents %overwriter overwrites %t : !torch.vtensor, !torch.tensor
   %view = torch.aten.view %t, %int_list : !torch.tensor, !torch.list<!torch.int> -> !torch.tensor
   %result = torch.aten.permute %view, %int_list : !torch.tensor, !torch.list<!torch.int> -> !torch.tensor
   %value_result = torch.copy.to_vtensor %result : !torch.vtensor
@@ -60,10 +60,10 @@ func @mutation_followed_by_view_like_ops(%value_t: !torch.vtensor, %overwriter:
 }
 
 // CHECK-LABEL: func @unmodeled_mutation(
-// CHECK: torch.overwrite.tensor
+// CHECK: torch.overwrite.tensor.contents
 func @unmodeled_mutation(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
   %0 = torch.copy.to_tensor %arg0 : !torch.tensor
-  torch.overwrite.tensor %arg1 overwrites %0 : !torch.vtensor, !torch.tensor
+  torch.overwrite.tensor.contents %arg1 overwrites %0 : !torch.vtensor, !torch.tensor
   "some.op"(%0) : (!torch.tensor) -> ()
   %result = torch.copy.to_vtensor %0 : !torch.vtensor
   return %result : !torch.vtensor
@@ -76,7 +76,7 @@ func @unimplemented_control_flow(%arg0: !torch.vtensor, %arg1: !torch.vtensor, %
   %tensor = torch.copy.to_tensor %arg0 : !torch.tensor
   %equal_to_arg0 = torch.copy.to_vtensor %tensor : !torch.vtensor
   torch.prim.If %cond -> () {
-    torch.overwrite.tensor %arg1 overwrites %tensor : !torch.vtensor, !torch.tensor
+    torch.overwrite.tensor.contents %arg1 overwrites %tensor : !torch.vtensor, !torch.tensor
     torch.prim.If.yield
   } else {
     torch.prim.If.yield
|
@ -95,7 +95,7 @@ func @convert_to_value_semantic_tensors_optional(%t: !torch.tensor,
|
|||
// being applied in sequence.
|
||||
// CHECK: %[[ARRAY_RESULT:.*]] = torch.copy.to_tensor %[[TENSOR_RESULT]] : !torch.tensor<[2,2],f32>
|
||||
// CHECK: %[[TENSOR_AGAIN:.*]] = torch.copy.to_vtensor %[[ARRAY_RESULT]] : !torch.vtensor<[2,2],f32>
|
||||
// CHECK: torch.overwrite.tensor %[[TENSOR_AGAIN]] overwrites %[[ARG0]] : !torch.vtensor<[2,2],f32>, !torch.tensor<[2,2],f32>
|
||||
// CHECK: torch.overwrite.tensor.contents %[[TENSOR_AGAIN]] overwrites %[[ARG0]] : !torch.vtensor<[2,2],f32>, !torch.tensor<[2,2],f32>
|
||||
// CHECK: return %[[ARG0]], %[[ARG0]] : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>
|
||||
func @reduce_trailing_underscore_inplace_variant(%arg0: !torch.tensor<[2,2],f32>, %arg1: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) {
|
||||
%c1 = torch.constant.int 1
|
||||
|
@@ -138,7 +138,7 @@ func @convert_to_value_semantic_tensors_optional_list(%self: !torch.tensor<[5],f
 // CHECK-SAME: !torch.vtensor, !torch.float, !torch.float, !torch.none -> !torch.vtensor
 // CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
 // CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
-// CHECK: torch.overwrite.tensor %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
+// CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
 // CHECK: return %[[T]] : !torch.tensor
 func @torch.aten.uniform_(%t: !torch.tensor, %min: !torch.float, %max: !torch.float, %generator: !torch.none) -> !torch.tensor {
   %ret = torch.aten.uniform_ %t, %min, %max, %generator: !torch.tensor, !torch.float, !torch.float, !torch.none -> !torch.tensor
@@ -153,7 +153,7 @@ func @torch.aten.uniform_(%t: !torch.tensor, %min: !torch.float, %max: !torch.fl
 // CHECK: %[[VRET:.*]] = torch.pseudo.aten.bernoulli.float %[[T_VTENSOR]], %[[P]], %[[GENERATOR]] : !torch.vtensor, !torch.float, !torch.none -> !torch.vtensor
 // CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
 // CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
-// CHECK: torch.overwrite.tensor %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
+// CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
 // CHECK: return %[[T]] : !torch.tensor
 func @torch.aten.bernoulli_.float(%t: !torch.tensor) -> !torch.tensor {
   %generator = torch.constant.none
@@ -169,7 +169,7 @@ func @torch.aten.bernoulli_.float(%t: !torch.tensor) -> !torch.tensor {
 // CHECK: %[[VRET:.*]] = torch.pseudo.aten.fill.Scalar %[[T_VTENSOR]], %[[VALUE]] : !torch.vtensor, !torch.int -> !torch.vtensor
 // CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
 // CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
-// CHECK: torch.overwrite.tensor %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
+// CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
 // CHECK: return %[[T]] : !torch.tensor
 func @torch.aten.fill_.Scalar(%t: !torch.tensor) -> !torch.tensor {
   %value = torch.constant.int 1
|
|
@@ -1157,3 +1157,35 @@ func @torch.aten.BinaryBroadcasting(%arg0: !torch.vtensor<[5,4,3,3,1],f32>, %arg
   %0 = torch.aten.add.Tensor %arg0, %arg1, %arg2: !torch.vtensor<[5,4,3,3,1],f32>, !torch.vtensor<[?,3,1,2],f32>, !torch.int -> !torch.tensor
   return %0 : !torch.tensor
 }
+
+// -----
+// CHECK-LABEL: func @torch.overwrite.tensor.contents$dynamic_overwrites_static(
+// CHECK-SAME: %[[STATIC:.*]]: !torch.vtensor<[2],f32>,
+// CHECK-SAME: %[[DYNAMIC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
+// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[DYNAMIC_COPY:.*]] : !torch.vtensor<[?],f32> to !torch.vtensor<[2],f32>
+// CHECK: torch.overwrite.tensor.contents %[[CAST]] overwrites %[[STATIC_COPY:.*]] : !torch.vtensor<[2],f32>, !torch.tensor<[2],f32>
+func @torch.overwrite.tensor.contents$dynamic_overwrites_static(%static: !torch.vtensor<[2],f32>, %dynamic: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
+  %static_no_type = torch.tensor_static_info_cast %static : !torch.vtensor<[2],f32> to !torch.vtensor
+  %static_copy = torch.copy.to_tensor %static_no_type : !torch.tensor
+  %dynamic_no_type = torch.tensor_static_info_cast %dynamic : !torch.vtensor<[?],f32> to !torch.vtensor
+  torch.overwrite.tensor.contents %dynamic_no_type overwrites %static_copy : !torch.vtensor, !torch.tensor
+  %static_value_copy = torch.copy.to_vtensor %static_copy : !torch.vtensor
+  %result = torch.tensor_static_info_cast %static_value_copy : !torch.vtensor to !torch.vtensor<[2],f32>
+  return %result : !torch.vtensor<[2],f32>
+}
+
+// -----
+// CHECK-LABEL: func @torch.overwrite.tensor.contents$static_overwrites_dynamic(
+// CHECK-SAME: %[[STATIC:.*]]: !torch.vtensor<[2],f32>,
+// CHECK-SAME: %[[DYNAMIC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
+// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[STATIC_COPY:.*]] : !torch.vtensor<[2],f32> to !torch.vtensor<[?],f32>
+// CHECK: torch.overwrite.tensor.contents %[[CAST]] overwrites %[[DYNAMIC_COPY:.*]] : !torch.vtensor<[?],f32>, !torch.tensor<[?],f32>
+func @torch.overwrite.tensor.contents$static_overwrites_dynamic(%static: !torch.vtensor<[2],f32>, %dynamic: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
+  %static_no_type = torch.tensor_static_info_cast %static : !torch.vtensor<[2],f32> to !torch.vtensor
+  %dynamic_no_type = torch.tensor_static_info_cast %dynamic : !torch.vtensor<[?],f32> to !torch.vtensor
+  %dynamic_copy = torch.copy.to_tensor %dynamic_no_type : !torch.tensor
+  torch.overwrite.tensor.contents %static_no_type overwrites %dynamic_copy : !torch.vtensor, !torch.tensor
+  %dynamic_value_copy = torch.copy.to_vtensor %dynamic_copy : !torch.vtensor
+  %result = torch.tensor_static_info_cast %dynamic_value_copy : !torch.vtensor to !torch.vtensor<[?],f32>
+  return %result : !torch.vtensor<[?],f32>
+}