Make `torch.copy.tensor` canonicalization a bit smarter.

This removes most of the trivial cases that MaximizeValueSemantics needs
to handle, making it easier to see the nontrivial cases.
pull/233/head
Sean Silva 2021-06-17 16:29:20 -07:00
parent 40369c54dc
commit 78d2cc0818
4 changed files with 56 additions and 10 deletions

View File

@ -33,6 +33,25 @@ def MmModule_chained(module, tu: TestUtils):
# ==============================================================================
# A subgraph with multiple mm ops.
class MmDagModule(torch.nn.Module):
    """A module whose forward graph is a small DAG of `torch.mm` ops.

    `lhs` feeds both the inner and the outer `mm`, so the use-def graph is a
    DAG rather than a straight chain (contrast with the chained-mm test above).
    NOTE(review): this fragment came from a diff view that stripped leading
    whitespace; indentation here is reconstructed conventionally.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,  # placeholder for the `self` argument.
        # presumably (shape, dtype, value-semantics flag) per torch-mlir's
        # annotate_args convention — TODO confirm against annotations.py.
        ([4, 4], torch.float32, True),
        ([4, 4], torch.float32, True),
    ])
    def forward(self, lhs, rhs):
        # `lhs` is deliberately used twice: as the left operand of both mm's.
        return torch.mm(lhs, torch.mm(lhs, rhs))
@register_test_case(module_factory=lambda: MmDagModule())
def MmDagModule_basic(module, tu: TestUtils):
    # Drive the DAG-of-mm module with random 4x4 inputs, matching the static
    # [4, 4] float32 shapes declared in the module's @annotate_args.
    module.forward(tu.rand(4, 4), tu.rand(4, 4))
# ==============================================================================
class TanhModule(torch.nn.Module):
def __init__(self):
super().__init__()

View File

@ -922,7 +922,9 @@ def Torch_TensorStaticInfoCastOp : Torch_Op<"tensor_static_info_cast", [
}];
}
def Torch_CopyTensorOp : Torch_Op<"copy.tensor", []> {
def Torch_CopyTensorOp : Torch_Op<"copy.tensor", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
]> {
let summary = "Makes a copy of a tensor.";
let description = [{
Changes to the original tensor will not be reflected in the copy.

View File

@ -528,26 +528,41 @@ OpFoldResult CopyTensorOp::fold(ArrayRef<Attribute> operands) {
// NOTE(review): this span is a diff rendered without +/- markers, so removed
// and added lines are interleaved below; it is NOT valid C++ as displayed.
// The old guard (`hasOneUse`, lines with `&&`) and the new guard (inverted
// `||` early-return plus the `llvm::all_of` user scan) are both present.
// Consult the actual commit for the authoritative post-change text.
void CopyTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
// --- pre-change comment (removed lines): ---
// y = torch.copy.tensor(hasOneUse@torch.copy.tensor(x)) -> x
// Only safe when `y` and `x` have value semantics.
// --- post-change comment (added lines): ---
// y = torch.copy.tensor(torch.copy.tensor(x)) -> x
// Only safe when `y` and `x` have value semantics, and
// all users of the intermediate tensor op treat the tensor as if it
// had value semantics (even if it is a NonValueTensorType).
patterns.add(+[](CopyTensorOp op, PatternRewriter &rewriter) {
// The operand must itself be produced by another copy.tensor op.
auto otherCopy = op.getOperand().getDefiningOp<CopyTensorOp>();
if (!otherCopy)
return failure();
// Old guard (removed): required value semantics at both ends AND a single
// use of the intermediate tensor.
if (otherCopy.getOperand().getType().isa<ValueTensorType>() &&
op.getResult().getType().isa<ValueTensorType>() &&
op.getOperand().hasOneUse()) {
// New guard (added): value semantics at both ends, checked via early
// return; the single-use restriction is relaxed below.
if (!otherCopy.getOperand().getType().isa<ValueTensorType>() ||
!op.getResult().getType().isa<ValueTensorType>())
return failure();
// TODO: Use a proper interface here.
// MemoryEffectOpInterface is not powerful enough because it cannot model
// aliasing. We don't just care that the user is readonly -- we care also
// whether it creates an alias. Basically, we care if the user "treats the
// tensor as if it has value semantics".
// For now, just hardcode the important case of multiple CopyTensorOp users.
if (llvm::all_of(op.getOperand().getUsers(),
[](Operation *op) { return isa<CopyTensorOp>(op); })) {
rewriter.replaceOp(op, {otherCopy.getOperand()});
// NOTE(review): the TODO + eraseOp below look like the *removed* side of
// the hunk (valid only under the old hasOneUse guard; with multiple copy
// users, `otherCopy` is not dead) — confirm against the commit.
// TODO: Implement MemoryEffectOpInterface to handle the value/non-value
// cases precisely. In this case, we specifically know that `otherCopy`
// is dead so eagerly clean it up.
rewriter.eraseOp(otherCopy);
return success();
}
return failure();
});
}
// Implements MemoryEffectsOpInterface for torch.copy.tensor.
//
// Effects are only attached for the non-value (mutable) tensor sides of the
// copy; when both operand and result are ValueTensorType, no effects are
// emitted, so the op is treated as effect-free in the pure value-semantic
// case.
void CopyTensorOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>>
        &effects) {
  // Producing a non-value tensor result is modeled as an Allocate on the
  // result value (the copy conceptually creates fresh mutable storage).
  if (getResult().getType().isa<NonValueTensorType>())
    effects.emplace_back(MemoryEffects::Allocate::get(), getResult());
  // Copying from a non-value tensor operand observes its current contents,
  // so it is modeled as a Read of the operand.
  if (getOperand().getType().isa<NonValueTensorType>())
    effects.emplace_back(MemoryEffects::Read::get(), getOperand());
}
//===----------------------------------------------------------------------===//
// ToBuiltinTensorOp
//===----------------------------------------------------------------------===//

View File

@ -179,3 +179,13 @@ func @torch.prim.If$erase_dead_branch(%arg0: !torch.int) -> !torch.int {
}
return %0 : !torch.int
}
// The intermediate !torch.tensor %0 is used only by copy.tensor ops (which
// treat it as if it had value semantics), so both round-trip copies are
// expected to canonicalize away, returning %arg0 twice directly.
// CHECK-LABEL: func @torch.copy.tensor$untouched_nonval(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
// CHECK-NEXT: return %[[ARG]], %[[ARG]] : !torch.vtensor, !torch.vtensor
func @torch.copy.tensor$untouched_nonval(%arg0: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
  %0 = torch.copy.tensor %arg0 : !torch.vtensor -> !torch.tensor
  %1 = torch.copy.tensor %0 : !torch.tensor -> !torch.vtensor
  %2 = torch.copy.tensor %0 : !torch.tensor -> !torch.vtensor
  return %1, %2 : !torch.vtensor, !torch.vtensor
}