mirror of https://github.com/llvm/torch-mlir
Make `torch.copy.tensor` canonicalization a bit smarter.
This removes most of the trivial cases that MaximizeValueSemantics needs to handle, making it easier to see the nontrivial cases.

pull/233/head
parent 40369c54dc
commit 78d2cc0818
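The nontrivial case this commit enables can be seen in the new canonicalize.mlir test at the bottom of this diff: an intermediate non-value tensor whose users are all copies can now be bypassed entirely. A sketch of the rewrite:

    // Before: %0 is a non-value tensor, but every one of its users is a
    // torch.copy.tensor, so it is safe to look through it.
    %0 = torch.copy.tensor %arg0 : !torch.vtensor -> !torch.tensor
    %1 = torch.copy.tensor %0 : !torch.tensor -> !torch.vtensor
    %2 = torch.copy.tensor %0 : !torch.tensor -> !torch.vtensor
    // After canonicalization, %1 and %2 both fold to %arg0:
    return %arg0, %arg0 : !torch.vtensor, !torch.vtensor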
@@ -33,6 +33,25 @@ def MmModule_chained(module, tu: TestUtils):
 
 # ==============================================================================
 
+# A subgraph with multiple mm ops.
+class MmDagModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+    @export
+    @annotate_args([
+        None,
+        ([4, 4], torch.float32, True),
+        ([4, 4], torch.float32, True),
+    ])
+    def forward(self, lhs, rhs):
+        return torch.mm(lhs, torch.mm(lhs, rhs))
+
+@register_test_case(module_factory=lambda: MmDagModule())
+def MmDagModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 4), tu.rand(4, 4))
+
+# ==============================================================================
+
 class TanhModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
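The test uses a DAG rather than a chain because `lhs` is read by both mm ops. After import, the corresponding non-value tensor then typically has more than one torch.copy.tensor user, which the old hasOneUse-guarded pattern refused to touch. A rough, hypothetical sketch of that shape (not the importer's literal output; the names are illustrative):

    // %lhs_t is a !torch.tensor read by two copies, one per mm operand use.
    %lhs_v1 = torch.copy.tensor %lhs_t : !torch.tensor -> !torch.vtensor
    %lhs_v2 = torch.copy.tensor %lhs_t : !torch.tensor -> !torch.vtensor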
@@ -922,7 +922,9 @@ def Torch_TensorStaticInfoCastOp : Torch_Op<"tensor_static_info_cast", [
   }];
 }
 
-def Torch_CopyTensorOp : Torch_Op<"copy.tensor", []> {
+def Torch_CopyTensorOp : Torch_Op<"copy.tensor", [
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  ]> {
   let summary = "Makes a copy of a tensor.";
   let description = [{
     Changes to the original tensor will not be reflected in the copy.
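Per the op description, the result is an independent tensor: later mutations of the source are not observed through the copy. A minimal usage sketch (types as in the tests below; %mutable is assumed to be some !torch.tensor already in scope):

    // %val is a value-semantic snapshot of %mutable at this program point.
    %val = torch.copy.tensor %mutable : !torch.tensor -> !torch.vtensor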
@@ -528,26 +528,41 @@ OpFoldResult CopyTensorOp::fold(ArrayRef<Attribute> operands) {
 
 void CopyTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
-  // y = torch.copy.tensor(hasOneUse@torch.copy.tensor(x)) -> x
-  // Only safe when `y` and `x` have value semantics.
+  // y = torch.copy.tensor(torch.copy.tensor(x)) -> x
+  // Only safe when `y` and `x` have value semantics, and
+  // all users of the intermediate tensor op treat the tensor as if it
+  // had value semantics (even if it is a NonValueTensorType).
   patterns.add(+[](CopyTensorOp op, PatternRewriter &rewriter) {
     auto otherCopy = op.getOperand().getDefiningOp<CopyTensorOp>();
     if (!otherCopy)
       return failure();
-    if (otherCopy.getOperand().getType().isa<ValueTensorType>() &&
-        op.getResult().getType().isa<ValueTensorType>() &&
-        op.getOperand().hasOneUse()) {
+    if (!otherCopy.getOperand().getType().isa<ValueTensorType>() ||
+        !op.getResult().getType().isa<ValueTensorType>())
+      return failure();
+    // TODO: Use a proper interface here.
+    // MemoryEffectOpInterface is not powerful enough because it cannot model
+    // aliasing. We don't just care that the user is readonly -- we care also
+    // whether it creates an alias. Basically, we care if the user "treats the
+    // tensor as if it has value semantics".
+    // For now, just hardcode the important case of multiple CopyTensorOp users.
+    if (llvm::all_of(op.getOperand().getUsers(),
+                     [](Operation *op) { return isa<CopyTensorOp>(op); })) {
       rewriter.replaceOp(op, {otherCopy.getOperand()});
+      // TODO: Implement MemoryEffectOpInterface to handle the value/non-value
+      // cases precisely. In this case, we specifically know that `otherCopy`
+      // is dead so eagerly clean it up.
+      rewriter.eraseOp(otherCopy);
       return success();
     }
     return failure();
   });
 }
 
+void CopyTensorOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>>
+        &effects) {
+  if (getResult().getType().isa<NonValueTensorType>())
+    effects.emplace_back(MemoryEffects::Allocate::get(), getResult());
+  if (getOperand().getType().isa<NonValueTensorType>())
+    effects.emplace_back(MemoryEffects::Read::get(), getOperand());
+}
+
 //===----------------------------------------------------------------------===//
 // ToBuiltinTensorOp
 //===----------------------------------------------------------------------===//
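The new getEffects hook reports effects only on the non-value side; it is the C++ half of the DeclareOpInterfaceMethods<MemoryEffectsOpInterface> declaration added in the .td file above. A minimal sketch of what each case declares (the direct value-to-value copy is assumed verifier-legal here):

    // Value -> value: no effects reported; the copy is effect-free.
    %0 = torch.copy.tensor %arg0 : !torch.vtensor -> !torch.vtensor
    // Non-value result: reports MemoryEffects::Allocate on %1.
    %1 = torch.copy.tensor %arg0 : !torch.vtensor -> !torch.tensor
    // Non-value operand: reports MemoryEffects::Read on the operand %1.
    %2 = torch.copy.tensor %1 : !torch.tensor -> !torch.vtensor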
@@ -179,3 +179,13 @@ func @torch.prim.If$erase_dead_branch(%arg0: !torch.int) -> !torch.int {
   }
   return %0 : !torch.int
 }
+
+// CHECK-LABEL:   func @torch.copy.tensor$untouched_nonval(
+// CHECK-SAME:        %[[ARG:.*]]: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
+// CHECK-NEXT:      return %[[ARG]], %[[ARG]] : !torch.vtensor, !torch.vtensor
+func @torch.copy.tensor$untouched_nonval(%arg0: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) {
+  %0 = torch.copy.tensor %arg0 : !torch.vtensor -> !torch.tensor
+  %1 = torch.copy.tensor %0 : !torch.tensor -> !torch.vtensor
+  %2 = torch.copy.tensor %0 : !torch.tensor -> !torch.vtensor
+  return %1, %2 : !torch.vtensor, !torch.vtensor
+}
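This test covers the positive case only. The llvm::all_of guard in the C++ pattern means the rewrite must not fire when any user of the intermediate tensor is not itself a copy; a hypothetical counterexample, with "some.user" standing in for any non-copy op:

    %0 = torch.copy.tensor %arg0 : !torch.vtensor -> !torch.tensor
    %1 = torch.copy.tensor %0 : !torch.tensor -> !torch.vtensor
    // %0 escapes to a non-copy user that could alias or mutate it, so
    // rewriting %1 to %arg0 would be unsound; the pattern bails out.
    "some.user"(%0) : (!torch.tensor) -> ()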