diff --git a/python/torch_mlir_e2e_test/test_suite/index_put.py b/python/torch_mlir_e2e_test/test_suite/index_put.py index 6b0b650bb..838f0496d 100644 --- a/python/torch_mlir_e2e_test/test_suite/index_put.py +++ b/python/torch_mlir_e2e_test/test_suite/index_put.py @@ -133,12 +133,9 @@ class IndexPutImpl1DFloatAccumulateModule(torch.nn.Module): ([-1], torch.float32, True), ]) def forward(self, input, index, value): - # Since the input is updated in-place, we pass input.clone() in place - # of input to avoid wrong results. - return torch.ops.aten._index_put_impl_(input.clone(), (index, ), - value, - accumulate=True, - unsafe=False) + return torch.ops.aten._index_put_impl_(input, (index,), value, + accumulate=True, + unsafe=False) @register_test_case( @@ -214,12 +211,9 @@ class IndexPutImpl1DIntAccumulateModule(torch.nn.Module): ([-1], torch.int64, True), ]) def forward(self, input, index, value): - # Since the input is updated in-place, we pass input.clone() in place - # of input to avoid wrong results. - return torch.ops.aten._index_put_impl_(input.clone(), (index, ), - value, - accumulate=True, - unsafe=False) + return torch.ops.aten._index_put_impl_(input, (index,), value, + accumulate=True, + unsafe=False) @register_test_case(module_factory=lambda: IndexPutImpl1DIntAccumulateModule()) diff --git a/python/torch_mlir_e2e_test/torchscript/framework.py b/python/torch_mlir_e2e_test/torchscript/framework.py index 4dd3bdea9..135969df2 100644 --- a/python/torch_mlir_e2e_test/torchscript/framework.py +++ b/python/torch_mlir_e2e_test/torchscript/framework.py @@ -58,6 +58,24 @@ Trace = List[TraceItem] # this type. CompiledArtifact = TypeVar('CompiledArtifact') +# Clone all the tensor values. 
+def clone_torch_script_value(v: TorchScriptValue): + if isinstance(v, torch.Tensor): + return v.clone() + if isinstance(v, tuple): + return tuple(clone_torch_script_value(field) for field in v) + if isinstance(v, list): + return [clone_torch_script_value(item) for item in v] + if isinstance(v, dict): + return { + clone_torch_script_value(key): clone_torch_script_value(val) + for key, val in v.items() + } + if isinstance(v, float) or isinstance(v, int) or isinstance(v, str): + return v + assert False, "unhandled cloning of TorchScriptValue value type" + + class TestConfig(abc.ABC): """The interface implemented by backends to run tests. @@ -208,16 +226,20 @@ class _Tracer: The inputs and outputs of each call are recorded in a Trace. Recursive property accesses are also traced. """ + def __init__(self, wrapped, property_base_path: List[str], trace: Trace): self.__wrapped__ = wrapped self.__trace__ = trace self.__property_base_path__ = property_base_path def __call__(self, *args, **kwargs): + # Clone the inputs to capture the original tensor values. This is + # needed because in-place mutation might happen to the input tensors. + inputs = [clone_torch_script_value(arg) for arg in args] output = self.__wrapped__(*args, **kwargs) self.__trace__.append( TraceItem(symbol=".".join(self.__property_base_path__), - inputs=args, + inputs=inputs, output=output)) return output