mirror of https://github.com/llvm/torch-mlir
Make `torch.copy.tensor` canonicalization a bit smarter.

This removes most of the trivial cases that MaximizeValueSemantics needs to
handle, making it easier to see the nontrivial cases.

pull/233/head
parent 40369c54dc
commit 78d2cc0818
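The old pattern only folded y = copy(copy(x)) when the intermediate copy had a
single use; the new one also handles the multi-use case, as long as every user
of the intermediate non-value tensor is itself a torch.copy.tensor. A minimal
IR sketch of the rewrite, mirroring the regression test at the bottom of this
diff:

    // Before canonicalization:
    %0 = torch.copy.tensor %arg0 : !torch.vtensor -> !torch.tensor
    %1 = torch.copy.tensor %0 : !torch.tensor -> !torch.vtensor
    %2 = torch.copy.tensor %0 : !torch.tensor -> !torch.vtensor
    // After canonicalization, %1 and %2 are both replaced by %arg0 and the
    // dead copies are erased.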
@@ -33,6 +33,25 @@ def MmModule_chained(module, tu: TestUtils):
 
 # ==============================================================================
 
+# A subgraph with multiple mm ops.
+class MmDagModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+    @export
+    @annotate_args([
+        None,
+        ([4, 4], torch.float32, True),
+        ([4, 4], torch.float32, True),
+    ])
+    def forward(self, lhs, rhs):
+        return torch.mm(lhs, torch.mm(lhs, rhs))
+
+@register_test_case(module_factory=lambda: MmDagModule())
+def MmDagModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 4), tu.rand(4, 4))
+
+# ==============================================================================
+
 class TanhModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
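For reference, MmDagModule.forward chains two matmuls, so the imported IR has
two mm ops where the result of the inner one feeds the outer one and lhs is
used twice. A hedged sketch of that IR (the torch.aten.mm spelling and type
syntax are assumptions, not taken from this diff):

    %0 = torch.aten.mm %lhs, %rhs : !torch.vtensor<[4,4],f32>, !torch.vtensor<[4,4],f32> -> !torch.vtensor<[4,4],f32>
    %1 = torch.aten.mm %lhs, %0 : !torch.vtensor<[4,4],f32>, !torch.vtensor<[4,4],f32> -> !torch.vtensor<[4,4],f32>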
@@ -922,7 +922,9 @@ def Torch_TensorStaticInfoCastOp : Torch_Op<"tensor_static_info_cast", [
   }];
 }
 
-def Torch_CopyTensorOp : Torch_Op<"copy.tensor", []> {
+def Torch_CopyTensorOp : Torch_Op<"copy.tensor", [
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  ]> {
   let summary = "Makes a copy of a tensor.";
   let description = [{
     Changes to the original tensor will not be reflected in the copy.
@@ -528,26 +528,41 @@ OpFoldResult CopyTensorOp::fold(ArrayRef<Attribute> operands) {
 
 void CopyTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
-  // y = torch.copy.tensor(hasOneUse@torch.copy.tensor(x)) -> x
-  // Only safe when `y` and `x` have value semantics.
+  // y = torch.copy.tensor(torch.copy.tensor(x)) -> x
+  // Only safe when `y` and `x` have value semantics, and
+  // all users of the intermediate tensor op treat the tensor as if it
+  // had value semantics (even if it is a NonValueTensorType).
   patterns.add(+[](CopyTensorOp op, PatternRewriter &rewriter) {
     auto otherCopy = op.getOperand().getDefiningOp<CopyTensorOp>();
     if (!otherCopy)
       return failure();
-    if (otherCopy.getOperand().getType().isa<ValueTensorType>() &&
-        op.getResult().getType().isa<ValueTensorType>() &&
-        op.getOperand().hasOneUse()) {
+    if (!otherCopy.getOperand().getType().isa<ValueTensorType>() ||
+        !op.getResult().getType().isa<ValueTensorType>())
+      return failure();
+    // TODO: Use a proper interface here.
+    // MemoryEffectOpInterface is not powerful enough because it cannot model
+    // aliasing. We don't just care that the user is readonly -- we care also
+    // whether it creates an alias. Basically, we care if the user "treats the
+    // tensor as if it has value semantics".
+    // For now, just hardcode the important case of multiple CopyTensorOp users.
+    if (llvm::all_of(op.getOperand().getUsers(),
+                     [](Operation *op) { return isa<CopyTensorOp>(op); })) {
       rewriter.replaceOp(op, {otherCopy.getOperand()});
+      // TODO: Implement MemoryEffectOpInterface to handle the value/non-value
+      // cases precisely. In this case, we specifically know that `otherCopy`
+      // is dead so eagerly clean it up.
+      rewriter.eraseOp(otherCopy);
       return success();
     }
     return failure();
   });
 }
 
+void CopyTensorOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>>
+        &effects) {
+  if (getResult().getType().isa<NonValueTensorType>())
+    effects.emplace_back(MemoryEffects::Allocate::get(), getResult());
+  if (getOperand().getType().isa<NonValueTensorType>())
+    effects.emplace_back(MemoryEffects::Read::get(), getOperand());
+}
+
 //===----------------------------------------------------------------------===//
 // ToBuiltinTensorOp
 //===----------------------------------------------------------------------===//
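To illustrate the getEffects implementation above with the types used in this
patch's test: a copy whose result is a NonValueTensorType reports an Allocate
on its result, and a copy whose operand is a NonValueTensorType reports a Read
on that operand. For example:

    // Reports Allocate on %0: the result is a non-value tensor.
    %0 = torch.copy.tensor %arg0 : !torch.vtensor -> !torch.tensor
    // Reports Read on %0: the operand is a non-value tensor. No effect is
    // reported for the value-semantic result %1.
    %1 = torch.copy.tensor %0 : !torch.tensor -> !torch.vtensor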
@@ -179,3 +179,13 @@ func @torch.prim.If$erase_dead_branch(%arg0: !torch.int) -> !torch.int {
   }
   return %0 : !torch.int
 }
+
+// CHECK-LABEL:   func @torch.copy.tensor$untouched_nonval(
+// CHECK-SAME:        %[[ARG:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
+// CHECK-NEXT:      return %[[ARG]], %[[ARG]] : !torch.vtensor, !torch.vtensor
+func @torch.copy.tensor$untouched_nonval(%arg0: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
+  %0 = torch.copy.tensor %arg0 : !torch.vtensor -> !torch.tensor
+  %1 = torch.copy.tensor %0 : !torch.tensor -> !torch.vtensor
+  %2 = torch.copy.tensor %0 : !torch.tensor -> !torch.vtensor
+  return %1, %2 : !torch.vtensor, !torch.vtensor
+}
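For contrast, a sketch of a case the new pattern must not fold: if the
intermediate !torch.tensor has any user that is not itself a torch.copy.tensor,
the llvm::all_of check fails and the copies stay in place. The non-copy user
below is written in MLIR generic form with a purely hypothetical op name:

    %0 = torch.copy.tensor %arg0 : !torch.vtensor -> !torch.tensor
    %1 = torch.copy.tensor %0 : !torch.tensor -> !torch.vtensor
    // Hypothetical non-copy user of %0; its presence blocks folding %1 to %arg0.
    "some.dialect.mutate"(%0) : (!torch.tensor) -> ()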