Fix https://github.com/llvm/torch-mlir/issues/1618 by stripping `requires_grad` from results of view ops. (#1624)

pull/1597/head
Maksim Levental 2022-11-21 19:15:53 -06:00 committed by GitHub
parent 22307a1427
commit ed901094c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 0 deletions

View File

@ -131,6 +131,8 @@ class TorchMLIRTensor(torch.Tensor):
if UNSUPPORTED_OPS.match(op_name):
raise UnsupportedByTorchMlirEagerMode(op_name)
requires_grad = requires_grad and "view" not in op_name
if not hasattr(func, "_schema"):
raise RuntimeError(f"op {func} has no schema.")