mirror of https://github.com/llvm/torch-mlir
Use `register_buffer` to make `Add_Module` test work on lazy tensor (#2332)
Doing `module.to('lazy')` only moves the module member tensors to the device if they are created with `self.register_buffer` or `self.register_parameter`. Since the `self.tensor` tensor in `Add_Module` test is currently not created using the `self.register_*` methods, it is not being moved from CPU to lazy device, which is causing the test to fail on LTC backend. This commit uses `self.register_buffer` to fix the test on LTC backend. This commit also seems to fix the test for torchdynamo.pull/2294/head
parent
ef11a77315
commit
4a96e716c0
|
@ -33,7 +33,6 @@ from .xfail_sets import (
|
||||||
STABLEHLO_PASS_SET,
|
STABLEHLO_PASS_SET,
|
||||||
TOSA_PASS_SET,
|
TOSA_PASS_SET,
|
||||||
LTC_XFAIL_SET,
|
LTC_XFAIL_SET,
|
||||||
LTC_CRASHING_SET,
|
|
||||||
TORCHDYNAMO_XFAIL_SET,
|
TORCHDYNAMO_XFAIL_SET,
|
||||||
TORCHDYNAMO_CRASHING_SET
|
TORCHDYNAMO_CRASHING_SET
|
||||||
)
|
)
|
||||||
|
@ -114,7 +113,7 @@ def main():
|
||||||
elif args.config == "lazy_tensor_core":
|
elif args.config == "lazy_tensor_core":
|
||||||
config = LazyTensorCoreTestConfig()
|
config = LazyTensorCoreTestConfig()
|
||||||
xfail_set = LTC_XFAIL_SET
|
xfail_set = LTC_XFAIL_SET
|
||||||
crashing_set = LTC_CRASHING_SET
|
crashing_set = set()
|
||||||
elif args.config == "torchdynamo":
|
elif args.config == "torchdynamo":
|
||||||
config = TorchDynamoTestConfig(RefBackendLinalgOnTensorsBackend())
|
config = TorchDynamoTestConfig(RefBackendLinalgOnTensorsBackend())
|
||||||
xfail_set = TORCHDYNAMO_XFAIL_SET
|
xfail_set = TORCHDYNAMO_XFAIL_SET
|
||||||
|
|
|
@ -307,9 +307,6 @@ TORCHDYNAMO_CRASHING_SET = {
|
||||||
"TransposeIntModule_basic",
|
"TransposeIntModule_basic",
|
||||||
"TransposeIntNegDimsModule_basic",
|
"TransposeIntNegDimsModule_basic",
|
||||||
"IndexPutImpl2DNoneIndexStaticModule_basic",
|
"IndexPutImpl2DNoneIndexStaticModule_basic",
|
||||||
|
|
||||||
# See https://github.com/llvm/torch-mlir/issues/2178
|
|
||||||
"Add_Module_basic"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
STABLEHLO_PASS_SET = {
|
STABLEHLO_PASS_SET = {
|
||||||
|
@ -1169,11 +1166,6 @@ if torch_version_for_comparison() < version.parse("2.1.0.dev"):
|
||||||
"ReshapeCollapseModule_basic",
|
"ReshapeCollapseModule_basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
LTC_CRASHING_SET = {
|
|
||||||
# https://github.com/llvm/torch-mlir/issues/2186
|
|
||||||
"Add_Module_basic"
|
|
||||||
}
|
|
||||||
|
|
||||||
LTC_XFAIL_SET = {
|
LTC_XFAIL_SET = {
|
||||||
"_Convolution2DAllFalseModule_basic",
|
"_Convolution2DAllFalseModule_basic",
|
||||||
"_Convolution2DBenchmarkModule_basic",
|
"_Convolution2DBenchmarkModule_basic",
|
||||||
|
|
|
@ -4267,7 +4267,7 @@ class Add_Module(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.tensor = torch.ones(2, 3)
|
self.register_buffer('tensor', torch.ones(2, 3))
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args([
|
@annotate_args([
|
||||||
|
|
Loading…
Reference in New Issue