mirror of https://github.com/llvm/torch-mlir
[LINALG] Fix name conflict of `self` keyword.
- The `self` name is being used as a keyword argument to the `torch.ops.aten.nll_loss_backward` function call, which produces name-conflict error with the python keyword `self` which is pointer to the current object. - This commit fixes this issue by replacing the keyword argument by positional argument. Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>pull/569/head
parent
3dc7847348
commit
dcef4751f9
|
@ -75,13 +75,13 @@ class NllLossModule_backward(torch.nn.Module):
|
|||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output=grad_output,
|
||||
self=input,
|
||||
target=target,
|
||||
weight=None,
|
||||
reduction=0,
|
||||
ignore_index=10,
|
||||
total_weight=total_weight)
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=None,
|
||||
reduction=0,
|
||||
ignore_index=10,
|
||||
total_weight=total_weight)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_backward())
|
||||
|
@ -104,8 +104,8 @@ class NllLossModule_backward_ignore_index(torch.nn.Module):
|
|||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output=grad_output,
|
||||
self=input,
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=None,
|
||||
reduction=0,
|
||||
|
|
Loading…
Reference in New Issue