mirror of https://github.com/llvm/torch-mlir
Fix https://github.com/llvm/torch-mlir/issues/1618 by stripping `requires_grad` from results of view ops. (#1624)
parent
22307a1427
commit
ed901094c1
|
@ -131,6 +131,8 @@ class TorchMLIRTensor(torch.Tensor):
|
|||
if UNSUPPORTED_OPS.match(op_name):
|
||||
raise UnsupportedByTorchMlirEagerMode(op_name)
|
||||
|
||||
requires_grad = requires_grad and "view" not in op_name
|
||||
|
||||
if not hasattr(func, "_schema"):
|
||||
raise RuntimeError(f"op {func} has no schema.")
|
||||
|
||||
|
|
Loading…
Reference in New Issue