[FxImporter] Fix kwarg operands in fx importer (#3166)

Remove the `kwarg_only` limitation, for example
```
torch.add(x, 3.0, alpha=2)
```
compiled to
```
%0 = torch.aten.add.Scalar %arg0, %float3.000000e00, %int1
```
fix to
```
%0 = torch.aten.add.Scalar %arg0, %float3.000000e00, %int2
```
pull/3176/head
penguin_wwy 2024-04-17 04:17:05 +08:00 committed by GitHub
parent 7a1ad0d7c0
commit 398aeeec87
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 8 deletions

View File

@ -433,14 +433,11 @@ FX_IMPORT_XFAIL_SET = {
"ConvolutionBackwardModule2D_basic",
"DivFloatModule_basic",
"DivIntModule_basic",
"ElementwiseAddScalarFloatModule_basic",
"ElementwiseAddScalarInt8Module_basic",
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
"ElementwiseDequantizePerChannelModule_basic",
"ElementwiseDequantizePerTensorModule_basic",
"ElementwiseQuantizePerTensorModule_basic",
"ElementwiseQuantizePerTensorUIntModule_basic",
"ElementwiseSubScalarIntModule_basic",
"ElementwiseToDtypeI64ToUI8Module_basic",
"EqIntModule_basic",
"FakeQuantizePerTensorAffineDynamicShapeModule_basic",

View File

@ -1450,16 +1450,16 @@ class GraphNodeImporter:
# Unroll operands from formal parameters, args and kwargs.
operands = []
for i, parameter in enumerate(schema.arguments):
if parameter.kwarg_only and parameter.name in node.kwargs:
if i < len(node.args):
operands.append(
self._import_argument(loc, node.args[i], parameter.type)
)
elif parameter.name in node.kwargs:
operands.append(
self._import_argument(
loc, node.kwargs[parameter.name], parameter.type
)
)
elif i < len(node.args):
operands.append(
self._import_argument(loc, node.args[i], parameter.type)
)
else:
operands.append(
self._import_default_value(