diff --git a/frontends/pytorch/e2e_testing/torchscript/basic.py b/frontends/pytorch/e2e_testing/torchscript/basic.py index fc1a53650..acb243f82 100644 --- a/frontends/pytorch/e2e_testing/torchscript/basic.py +++ b/frontends/pytorch/e2e_testing/torchscript/basic.py @@ -33,6 +33,25 @@ def MmModule_chained(module, tu: TestUtils): # ============================================================================== +# A subgraph with multiple mm ops. +class MmDagModule(torch.nn.Module): + def __init__(self): + super().__init__() + @export + @annotate_args([ + None, + ([4, 4], torch.float32, True), + ([4, 4], torch.float32, True), + ]) + def forward(self, lhs, rhs): + return torch.mm(lhs, torch.mm(lhs, rhs)) + +@register_test_case(module_factory=lambda: MmDagModule()) +def MmDagModule_basic(module, tu: TestUtils): + module.forward(tu.rand(4, 4), tu.rand(4, 4)) + +# ============================================================================== + class TanhModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/include/npcomp/Dialect/Torch/IR/TorchOps.td b/include/npcomp/Dialect/Torch/IR/TorchOps.td index 20dea3ea6..521d5049e 100644 --- a/include/npcomp/Dialect/Torch/IR/TorchOps.td +++ b/include/npcomp/Dialect/Torch/IR/TorchOps.td @@ -922,7 +922,9 @@ def Torch_TensorStaticInfoCastOp : Torch_Op<"tensor_static_info_cast", [ }]; } -def Torch_CopyTensorOp : Torch_Op<"copy.tensor", []> { +def Torch_CopyTensorOp : Torch_Op<"copy.tensor", [ + DeclareOpInterfaceMethods, + ]> { let summary = "Makes a copy of a tensor."; let description = [{ Changes to the original tensor will not be reflected in the copy. diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index ead9d1e5f..a50b91fdc 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -528,26 +528,41 @@ OpFoldResult CopyTensorOp::fold(ArrayRef operands) { void CopyTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { - // y = torch.copy.tensor(hasOneUse@torch.copy.tensor(x)) -> x - // Only safe when `y` and `x` have value semantics. + // y = torch.copy.tensor(torch.copy.tensor(x)) -> x + // Only safe when `y` and `x` have value semantics, and + // all users of the intermediate tensor op treat the tensor as if it + // had value semantics (even if it is a NonValueTensorType). patterns.add(+[](CopyTensorOp op, PatternRewriter &rewriter) { auto otherCopy = op.getOperand().getDefiningOp(); if (!otherCopy) return failure(); - if (otherCopy.getOperand().getType().isa() && - op.getResult().getType().isa() && - op.getOperand().hasOneUse()) { + if (!otherCopy.getOperand().getType().isa() || + !op.getResult().getType().isa()) + return failure(); + // TODO: Use a proper interface here. + // MemoryEffectOpInterface is not powerful enough because it cannot model + // aliasing. We don't just care that the user is readonly -- we care also + // whether it creates an alias. Basically, we care if the user "treats the + // tensor as if it has value semantics". + // For now, just hardcode the important case of multiple CopyTensorOp users. + if (llvm::all_of(op.getOperand().getUsers(), + [](Operation *op) { return isa(op); })) { rewriter.replaceOp(op, {otherCopy.getOperand()}); - // TODO: Implement MemoryEffectOpInterface to handle the value/non-value - // cases precisely. In this case, we specifically know that `otherCopy` - // is dead so eagerly clean it up. - rewriter.eraseOp(otherCopy); return success(); } return failure(); }); } +void CopyTensorOp::getEffects( + SmallVectorImpl> + &effects) { + if (getResult().getType().isa()) + effects.emplace_back(MemoryEffects::Allocate::get(), getResult()); + if (getOperand().getType().isa()) + effects.emplace_back(MemoryEffects::Read::get(), getOperand()); +} + //===----------------------------------------------------------------------===// // ToBuiltinTensorOp //===----------------------------------------------------------------------===// diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 0d3616212..febc522ea 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -179,3 +179,13 @@ func @torch.prim.If$erase_dead_branch(%arg0: !torch.int) -> !torch.int { } return %0 : !torch.int } + +// CHECK-LABEL: func @torch.copy.tensor$untouched_nonval( +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) { +// CHECK-NEXT: return %[[ARG]], %[[ARG]] : !torch.vtensor, !torch.vtensor +func @torch.copy.tensor$untouched_nonval(%arg0: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) { + %0 = torch.copy.tensor %arg0 : !torch.vtensor -> !torch.tensor + %1 = torch.copy.tensor %0 : !torch.tensor -> !torch.vtensor + %2 = torch.copy.tensor %0 : !torch.tensor -> !torch.vtensor + return %1, %2 : !torch.vtensor, !torch.vtensor +}