mirror of https://github.com/llvm/torch-mlir
Small bug fixes in eager mode (#691)
parent
1960ba76fb
commit
3e999beaea
|
@ -194,6 +194,9 @@ def get_func_op_with_name(module: ir.Module, name: str) -> Optional[FuncOp]:
|
|||
name_attr = ir.StringAttr.get(name)
|
||||
for op in module.body.operations:
|
||||
if isinstance(op, FuncOp) and op.name == name_attr:
|
||||
# Add name of torch op as debug_module_name so that
|
||||
# run_pipeline_with_repro_report can use it.
|
||||
module.operation.attributes["torch.debug_module_name"] = name_attr
|
||||
return op
|
||||
|
||||
return None
|
||||
|
|
|
@ -105,6 +105,13 @@ def build_script_function(
|
|||
inp.setDebugName(arg.name)
|
||||
# If arg is a constant, inline (at the top of the graph).
|
||||
else:
|
||||
if val == []:
|
||||
# Some ops have empty list default values for args
|
||||
# (such as aten::max_pool2d_with_indices with int[2] stride=[]
|
||||
# but graph.insertConstant doesnt' recognize [] as an empty list IValue.
|
||||
# This might be an upstream bug but there doesn't seem to be a way to
|
||||
# build a prim::ListConstruct list that's empty.
|
||||
val = None
|
||||
inp = graph.insertConstant(val)
|
||||
inp.node().moveBefore(node)
|
||||
|
||||
|
@ -219,7 +226,7 @@ def try_torch_mlir_eager(op, args, kwargs, backend):
|
|||
else:
|
||||
raise RuntimeError(f"op {op} has no name")
|
||||
|
||||
if op_name == "detach":
|
||||
if "detach" in op_name:
|
||||
# We don't handle detach as it only pertains to autograd graph construction, which is handled by pytorch.
|
||||
raise UnsupportedByTorchMlirEagerMode("detaching")
|
||||
|
||||
|
|
|
@ -41,7 +41,9 @@ class TorchMLIRTensor(torch.Tensor):
|
|||
dtype=elem.dtype,
|
||||
layout=elem.layout,
|
||||
device=elem.device,
|
||||
requires_grad=kwargs.get("requires_grad", False) or elem.requires_grad,
|
||||
# Only float tensors can have gradients.
|
||||
requires_grad=elem.dtype in {torch.float, torch.float32, torch.float64}
|
||||
and (kwargs.get("requires_grad", False) or elem.requires_grad),
|
||||
)
|
||||
r.elem = elem.detach() if r.requires_grad else elem
|
||||
return r
|
||||
|
|
Loading…
Reference in New Issue