Use `register_buffer` to make `Add_Module` test work on lazy tensor (#2332)

Doing `module.to('lazy')` only moves the module member tensors to the
device if they are created with `self.register_buffer` or
`self.register_parameter`. Since the `self.tensor` tensor in
`Add_Module` test is currently not created using the `self.register_*`
methods, it is not being moved from CPU to lazy device, which is
causing the test to fail on LTC backend. This commit uses
`self.register_buffer` to fix the test on LTC backend.

This commit also seems to fix the test for torchdynamo.
pull/2294/head
Ramiro Leal-Cavazos 2023-07-24 09:07:13 -07:00 committed by GitHub
parent ef11a77315
commit 4a96e716c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 2 additions and 11 deletions

View File

@ -33,7 +33,6 @@ from .xfail_sets import (
STABLEHLO_PASS_SET,
TOSA_PASS_SET,
LTC_XFAIL_SET,
LTC_CRASHING_SET,
TORCHDYNAMO_XFAIL_SET,
TORCHDYNAMO_CRASHING_SET
)
@ -114,7 +113,7 @@ def main():
elif args.config == "lazy_tensor_core":
config = LazyTensorCoreTestConfig()
xfail_set = LTC_XFAIL_SET
crashing_set = LTC_CRASHING_SET
crashing_set = set()
elif args.config == "torchdynamo":
config = TorchDynamoTestConfig(RefBackendLinalgOnTensorsBackend())
xfail_set = TORCHDYNAMO_XFAIL_SET

View File

@ -307,9 +307,6 @@ TORCHDYNAMO_CRASHING_SET = {
"TransposeIntModule_basic",
"TransposeIntNegDimsModule_basic",
"IndexPutImpl2DNoneIndexStaticModule_basic",
# See https://github.com/llvm/torch-mlir/issues/2178
"Add_Module_basic"
}
STABLEHLO_PASS_SET = {
@ -1169,11 +1166,6 @@ if torch_version_for_comparison() < version.parse("2.1.0.dev"):
"ReshapeCollapseModule_basic",
}
LTC_CRASHING_SET = {
# https://github.com/llvm/torch-mlir/issues/2186
"Add_Module_basic"
}
LTC_XFAIL_SET = {
"_Convolution2DAllFalseModule_basic",
"_Convolution2DBenchmarkModule_basic",

View File

@ -4267,7 +4267,7 @@ class Add_Module(torch.nn.Module):
def __init__(self):
super().__init__()
self.tensor = torch.ones(2, 3)
self.register_buffer('tensor', torch.ones(2, 3))
@export
@annotate_args([