Skip to content

Commit

Permalink
Fix(argparse): Make this work with arbitrary nested subcommands
Browse files Browse the repository at this point in the history
  • Loading branch information
fkglr committed Jan 4, 2024
1 parent 0f2dfbb commit 84ee28e
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 34 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
# SPDX-FileCopyrightText: Hayden Richards
#
# SPDX-License-Identifier: MIT

"""Namespaces Utility Functions for Declarative Typed Argument Parsing.
The `namespaces` module contains a utility function used for recursively
Expand All @@ -22,7 +18,7 @@ def to_dict(namespace: argparse.Namespace) -> Dict[str, Any]:
Dict[str, Any]: Nested dictionary generated from namespace.
"""
# Get Dictionary from Namespace Vars
dictionary = vars(namespace)
dictionary = dict(vars(namespace))

# Loop Through Dictionary
for key, value in dictionary.items():
Expand Down
45 changes: 16 additions & 29 deletions vendor/pydantic-argparse/pydantic_argparse/utils/nesting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
# SPDX-FileCopyrightText: Hayden Richards
#
# SPDX-License-Identifier: MIT

"""Utilities to help with parsing arbitrarily nested `pydantic` models."""

from argparse import Namespace
Expand All @@ -24,39 +20,37 @@ def __init__(
) -> None:
self.model = model
self.args = to_dict(namespace)
self.subcommand = False
self.schema: Dict[str, Any] = self._get_nested_model_fields(self.model)
self.schema: Dict[str, Any] = self._get_nested_model_fields(self.model, namespace)
self.schema = self._remove_null_leaves(self.schema)

if self.subcommand:
# if there are subcommands, they should only be in the topmost
# level, and the way that the unnesting works is
# that it will populate all subcommands,
# so we need to remove the subcommands that were
# not passed at cli
def _get_nested_model_fields(self, model: ModelT, namespace: Namespace, parent: Optional[Tuple] = None):
def contains_subcommand(namespace: Namespace, subcommand: str):
for name, obj in vars(namespace).items():
if isinstance(obj, Namespace):
if name == subcommand:
return True
elif contains_subcommand(obj, subcommand):
return True

# the command should be the very first argument
# after executable/file name
command = list(self.args.keys())[0]
self.schema = self._unset_subcommands(self.schema, command)
return False

def _get_nested_model_fields(self, model: ModelT, parent: Optional[Tuple] = None):
model_fields: Dict[str, Any] = dict()

for field in PydanticField.parse_model(model):
key = field.name

if field.is_a(BaseModel):
if field.is_subcommand():
self.subcommand = True
if not contains_subcommand(namespace, key):
continue

new_parent = (*parent, key) if parent is not None else (key,)
parent = (*parent, key) if parent is not None else (key,)

# recursively build nestes pydantic models in dict,
# which matches the actual schema the nested
# schema pydantic will be expecting
model_fields[key] = self._get_nested_model_fields(
field.model_type, new_parent
field.model_type, namespace, parent
)
else:
# start with all leaves as None unless key is in top level
Expand All @@ -68,12 +62,8 @@ def _get_nested_model_fields(self, model: ModelT, parent: Optional[Tuple] = None
# check full path first
# TODO: this may not be needed depending on how nested namespaces work
# since the arg groups are not nested -- just flattened
full_path = (*parent, key)
value = get_path(self.args, full_path, value)

if value is None:
short_path = (parent[0], key)
value = get_path(self.args, short_path, value)
path = (*parent, key)
value = get_path(self.args, path, value)

model_fields[key] = value

Expand All @@ -89,9 +79,6 @@ def _remove_null_leaves(self, schema: Dict[str, Any]):
# the schema
return remap(schema, visit=lambda p, k, v: v is not None)

def _unset_subcommands(self, schema: Dict[str, Any], command: str):
return {key: value for key, value in schema.items() if key == command}

def validate(self):
"""Return an instance of the `pydantic` modeled validated with data passed from the command line."""
return self.model.model_validate(self.schema)

0 comments on commit 84ee28e

Please sign in to comment.