mirror of https://github.com/llvm/torch-mlir
Use `register_buffer` to make `Add_Module` test work on lazy tensor (#2332)
Doing `module.to('lazy')` only moves the module member tensors to the device if they are created with `self.register_buffer` or `self.register_parameter`. Since the `self.tensor` tensor in `Add_Module` test is currently not created using the `self.register_*` methods, it is not being moved from CPU to lazy device, which is causing the test to fail on LTC backend. This commit uses `self.register_buffer` to fix the test on LTC backend. This commit also seems to fix the test for torchdynamo.pull/2294/head
parent
ef11a77315
commit
4a96e716c0
|
@ -33,7 +33,6 @@ from .xfail_sets import (
|
|||
STABLEHLO_PASS_SET,
|
||||
TOSA_PASS_SET,
|
||||
LTC_XFAIL_SET,
|
||||
LTC_CRASHING_SET,
|
||||
TORCHDYNAMO_XFAIL_SET,
|
||||
TORCHDYNAMO_CRASHING_SET
|
||||
)
|
||||
|
@ -114,7 +113,7 @@ def main():
|
|||
elif args.config == "lazy_tensor_core":
|
||||
config = LazyTensorCoreTestConfig()
|
||||
xfail_set = LTC_XFAIL_SET
|
||||
crashing_set = LTC_CRASHING_SET
|
||||
crashing_set = set()
|
||||
elif args.config == "torchdynamo":
|
||||
config = TorchDynamoTestConfig(RefBackendLinalgOnTensorsBackend())
|
||||
xfail_set = TORCHDYNAMO_XFAIL_SET
|
||||
|
|
|
@ -307,9 +307,6 @@ TORCHDYNAMO_CRASHING_SET = {
|
|||
"TransposeIntModule_basic",
|
||||
"TransposeIntNegDimsModule_basic",
|
||||
"IndexPutImpl2DNoneIndexStaticModule_basic",
|
||||
|
||||
# See https://github.com/llvm/torch-mlir/issues/2178
|
||||
"Add_Module_basic"
|
||||
}
|
||||
|
||||
STABLEHLO_PASS_SET = {
|
||||
|
@ -1169,11 +1166,6 @@ if torch_version_for_comparison() < version.parse("2.1.0.dev"):
|
|||
"ReshapeCollapseModule_basic",
|
||||
}
|
||||
|
||||
LTC_CRASHING_SET = {
|
||||
# https://github.com/llvm/torch-mlir/issues/2186
|
||||
"Add_Module_basic"
|
||||
}
|
||||
|
||||
LTC_XFAIL_SET = {
|
||||
"_Convolution2DAllFalseModule_basic",
|
||||
"_Convolution2DBenchmarkModule_basic",
|
||||
|
|
|
@ -4267,7 +4267,7 @@ class Add_Module(torch.nn.Module):
|
|||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.tensor = torch.ones(2, 3)
|
||||
self.register_buffer('tensor', torch.ones(2, 3))
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
|
|
Loading…
Reference in New Issue