Add `ReadOnly` trait to `copy.to_vtensor` (#2179)

Before inlining a global slot, the users of the global slot are
checked to see if they are `ReadOnly` or `MemoryEffectFree` to make
sure that the global slot is not being mutated. Because the op
`copy.to_vtensor` currently does not have the `ReadOnly` trait, if a
global slot is passed to `copy.to_vtensor`, the pass
`InlineGlobalSlots` will fail.

The op `copy.to_vtensor` is `ReadOnly`, since it does not modify the
contents of the input tensor; it simply makes a new copy. This commit
adds the trait, along with an e2e test that exercises the case of a
global slot being passed to `copy.to_vtensor`.
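
For context, the failing pattern can be reproduced outside the e2e harness. The sketch below is hypothetical and not part of this commit: the class name, the `torch_mlir.compile` call, and the `output_type` string are assumptions based on the Python API available around this change; the comments simply restate the mechanism described above.

# Hypothetical repro sketch (not part of this commit).
import torch
import torch_mlir

class AddInPlace(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A module attribute like this becomes a torch.global_slot when the
        # object graph is globalized during lowering.
        self.tensor = torch.ones(2, 3)

    def forward(self, x):
        # The in-place add reads the global slot; its lowering routes the
        # slot's value through copy.to_vtensor.
        return torch.ops.aten.add_(x, self.tensor)

# Before this change, InlineGlobalSlots would fail here because
# copy.to_vtensor lacked the ReadOnly trait; with the trait added, every
# user of the slot is recognized as non-mutating and inlining proceeds.
compiled = torch_mlir.compile(AddInPlace(), torch.ones(2, 3),
                              output_type="linalg-on-tensors")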
Ramiro Leal-Cavazos 2023-05-30 14:40:36 -07:00 committed by GitHub
parent db3f2e3fde
commit 479b2175ef
4 changed files with 34 additions and 1 deletion


@@ -32,6 +32,7 @@ from .xfail_sets import (
     STABLEHLO_PASS_SET,
     TOSA_PASS_SET,
     LTC_XFAIL_SET,
+    LTC_CRASHING_SET,
     TORCHDYNAMO_XFAIL_SET,
     TORCHDYNAMO_CRASHING_SET
 )
@@ -108,7 +109,7 @@ def main():
     elif args.config == "lazy_tensor_core":
         config = LazyTensorCoreTestConfig()
         xfail_set = LTC_XFAIL_SET
-        crashing_set = set()
+        crashing_set = LTC_CRASHING_SET
     elif args.config == "torchdynamo":
         config = TorchDynamoTestConfig(RefBackendLinalgOnTensorsBackend())
         xfail_set = TORCHDYNAMO_XFAIL_SET
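
Note the distinction the harness draws here: tests in an xfail set are run and expected to fail, while tests in a crashing set are not attempted at all, so a hard crash cannot take down the rest of the run. This hunk switches the `lazy_tensor_core` config from an empty crashing set to the newly introduced `LTC_CRASHING_SET`.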


@@ -303,6 +303,9 @@ TORCHDYNAMO_CRASHING_SET = {
     "ToCopyModule_basic",
     "TransposeIntModule_basic",
     "TransposeIntNegDimsModule_basic",
+    # See https://github.com/llvm/torch-mlir/issues/2178
+    "Add_Module_basic"
 }

 STABLEHLO_PASS_SET = {
@@ -1068,6 +1071,11 @@ TOSA_PASS_SET = {
     "ChunkListUnpackUneven_Module_basic",
 }

+LTC_CRASHING_SET = {
+    # https://github.com/llvm/torch-mlir/issues/2186
+    "Add_Module_basic"
+}
+
 LTC_XFAIL_SET = {
     "_Convolution2DAllFalseModule_basic",
     "_Convolution2DBenchmarkModule_basic",


@@ -1020,6 +1020,7 @@ def Torch_CopyToNonValueTensorOp : Torch_Op<"copy.to_tensor", [
 }

 def Torch_CopyToValueTensorOp : Torch_Op<"copy.to_vtensor", [
+    ReadOnly,
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
     DeclareOpInterfaceMethods<InferTypeOpInterface>,
     TypesMatchWith<"operand is corresponding !torch.tensor",
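
`ReadOnly` is the Torch dialect's op trait for ops that never mutate the contents of their tensor operands; adding it to the trait list above is what allows `InlineGlobalSlots` to treat uses by `copy.to_vtensor` as safe.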


@@ -3929,3 +3929,26 @@ class AtenComplexViewModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: AtenComplexViewModule())
 def AtenComplexViewModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(5,2))
+
+
+# ==============================================================================
+
+
+class Add_Module(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.tensor = torch.ones(2, 3)
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.add_(x, self.tensor)
+
+
+@register_test_case(module_factory=lambda: Add_Module())
+def Add_Module_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3))
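
Assuming the harness's usual command-line options (an assumption here; check `python -m e2e_testing.main --help`), the new case can be run in isolation with something like `python -m e2e_testing.main --filter Add_Module_basic`, while `--config lazy_tensor_core` skips it via the `LTC_CRASHING_SET` entry above.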