diff --git a/python/torch_mlir/eager_mode/torch_mlir_tensor.py b/python/torch_mlir/eager_mode/torch_mlir_tensor.py index 0bb593222..aae1e6581 100644 --- a/python/torch_mlir/eager_mode/torch_mlir_tensor.py +++ b/python/torch_mlir/eager_mode/torch_mlir_tensor.py @@ -131,6 +131,8 @@ class TorchMLIRTensor(torch.Tensor): if UNSUPPORTED_OPS.match(op_name): raise UnsupportedByTorchMlirEagerMode(op_name) + requires_grad = requires_grad and "view" not in op_name + if not hasattr(func, "_schema"): raise RuntimeError(f"op {func} has no schema.")