Fix the input tensors inplace update issue for e2e tests

Fix the inplace update tensor issue we had
where the torchscript execution would update the input value inplace
resulting the actual test not being able to see the original input
value.
pull/797/head
Yi Zhang 2022-03-23 11:34:02 -04:00
parent 44c7b181d3
commit 7be9783f16
2 changed files with 29 additions and 13 deletions

View File

@ -133,10 +133,7 @@ class IndexPutImpl1DFloatAccumulateModule(torch.nn.Module):
([-1], torch.float32, True), ([-1], torch.float32, True),
]) ])
def forward(self, input, index, value): def forward(self, input, index, value):
# Since the input is updated in-place, we pass input.clone() in place return torch.ops.aten._index_put_impl_(input, (index,), value,
# of input to avoid wrong results.
return torch.ops.aten._index_put_impl_(input.clone(), (index, ),
value,
accumulate=True, accumulate=True,
unsafe=False) unsafe=False)
@ -214,10 +211,7 @@ class IndexPutImpl1DIntAccumulateModule(torch.nn.Module):
([-1], torch.int64, True), ([-1], torch.int64, True),
]) ])
def forward(self, input, index, value): def forward(self, input, index, value):
# Since the input is updated in-place, we pass input.clone() in place return torch.ops.aten._index_put_impl_(input, (index,), value,
# of input to avoid wrong results.
return torch.ops.aten._index_put_impl_(input.clone(), (index, ),
value,
accumulate=True, accumulate=True,
unsafe=False) unsafe=False)

View File

@ -58,6 +58,24 @@ Trace = List[TraceItem]
# this type. # this type.
CompiledArtifact = TypeVar('CompiledArtifact') CompiledArtifact = TypeVar('CompiledArtifact')
# Clone all the tensor values.
def clone_torch_script_value(v: TorchScriptValue):
if isinstance(v, torch.Tensor):
return v.clone()
if isinstance(v, tuple):
return tuple(clone_torch_script_value(field) for field in v)
if isinstance(v, list):
return [clone_torch_script_value(item) for item in v]
if isinstance(v, dict):
return {
clone_torch_script_value(key): clone_torch_script_value(val)
for key, val in v.items()
}
if isinstance(v, float) or isinstance(v, int) or isinstance(v, str):
return v
assert False, "unhandled cloning of TorchScriptValue value type"
class TestConfig(abc.ABC): class TestConfig(abc.ABC):
"""The interface implemented by backends to run tests. """The interface implemented by backends to run tests.
@ -208,16 +226,20 @@ class _Tracer:
The inputs and outputs of each call are recorded in a Trace. Recursive The inputs and outputs of each call are recorded in a Trace. Recursive
property accesses are also traced. property accesses are also traced.
""" """
def __init__(self, wrapped, property_base_path: List[str], trace: Trace): def __init__(self, wrapped, property_base_path: List[str], trace: Trace):
self.__wrapped__ = wrapped self.__wrapped__ = wrapped
self.__trace__ = trace self.__trace__ = trace
self.__property_base_path__ = property_base_path self.__property_base_path__ = property_base_path
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
# Clone the inputs to capture the original tensors values. This is
# needed because inplace mutation might happen to the input tensors.
inputs = [clone_torch_script_value(arg) for arg in args]
output = self.__wrapped__(*args, **kwargs) output = self.__wrapped__(*args, **kwargs)
self.__trace__.append( self.__trace__.append(
TraceItem(symbol=".".join(self.__property_base_path__), TraceItem(symbol=".".join(self.__property_base_path__),
inputs=args, inputs=inputs,
output=output)) output=output))
return output return output