[dynamo] Verify the default value is passed by kwargs (#2998)

pull/3158/merge
penguin_wwy 2024-04-28 02:18:33 +08:00 committed by GitHub
parent f173a06fa7
commit 4fbe77a051
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 11 deletions

View File

@ -252,18 +252,14 @@ TORCHDYNAMO_XFAIL_SET = {
# ERROR: Exception: Unsupported: missing default value for argument 0 in schema for aten.div.Tensor_mode
"ElementwiseAtenFloorDivideScalarNegativeModule_basic",
"ElementwiseAtenFloorDivideScalarModule_basic",
"ElementwiseDivTensorRoundingModeFloorModule_basic",
"ElementwiseDivTensorRoundingModeTruncModule_basic",
"ElementwiseDivTensorRoundingModeFloorStaticModule_basic",
"ElementwiseDivTensorRoundingModeTruncStaticModule_basic",
"ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic",
"ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic",
"ElementwiseDivScalarRoundingModeFloorModule_basic",
"ElementwiseDivScalarRoundingModeTruncModule_basic",
"ElementwiseDivScalarRoundingModeFloorStaticModule_basic",
"ElementwiseDivScalarRoundingModeTruncStaticModule_basic",
"ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic",
"ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic",
# ERROR: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int'
"AdaptiveAvgPool1dStaticLargerOutput_basic",
"AdaptiveAvgPool1dGeneralDynamic_basic",
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
@ -276,10 +272,6 @@ TORCHDYNAMO_XFAIL_SET = {
"TensorFloatModule_basic",
"TensorIntModule_basic",
# ERROR: Exception: Unsupported: missing default value for argument 0 in schema for aten.randn.generator
"RandnGeneratorF64Module_basic",
"RandnGeneratorModule_basic",
# START tests failing due to: complex floating point ops
# END tests failing due to: complex floating point ops
@ -343,8 +335,10 @@ TORCHDYNAMO_XFAIL_SET = {
"IntImplicitModule_basic",
# Others
"ExponentialModule_basic",
"GridSamplerBasic1_basic",
"GridSamplerBasic2_basic",
"GridSamplerBasic3_basic",
"FakeQuantizePerTensorAffineModule_basic",
"FakeQuantizePerTensorAffineDynamicShapeModule_basic",
"FakeQuantizePerTensorAffineRoundToEvenModule_basic",

View File

@ -78,7 +78,7 @@ def _verify_fx_graph_conforms_to_subset(g: torch.fx.Graph):
assert len(node.args) < len(node.target._schema.arguments)
for i, argument in enumerate(
node.target._schema.arguments[len(node.args):]):
if not argument.has_default_value():
if not argument.has_default_value() and argument.name not in node.kwargs:
raise Exception(
f"Unsupported: missing default value for argument {i} in schema for {node.target}"
)