From 84ee28e6d2477ffba0a6c6317b76c376cab8867b Mon Sep 17 00:00:00 2001 From: fabian Date: Thu, 4 Jan 2024 20:23:23 +0100 Subject: [PATCH] Fix(argparse): Make this work with arbitrary nested subcommands --- .../pydantic_argparse/utils/namespaces.py | 6 +-- .../pydantic_argparse/utils/nesting.py | 45 +++++++------------ 2 files changed, 17 insertions(+), 34 deletions(-) diff --git a/vendor/pydantic-argparse/pydantic_argparse/utils/namespaces.py b/vendor/pydantic-argparse/pydantic_argparse/utils/namespaces.py index 923eae565..19c390429 100644 --- a/vendor/pydantic-argparse/pydantic_argparse/utils/namespaces.py +++ b/vendor/pydantic-argparse/pydantic_argparse/utils/namespaces.py @@ -1,7 +1,3 @@ -# SPDX-FileCopyrightText: Hayden Richards -# -# SPDX-License-Identifier: MIT - """Namespaces Utility Functions for Declarative Typed Argument Parsing. The `namespaces` module contains a utility function used for recursively @@ -22,7 +18,7 @@ def to_dict(namespace: argparse.Namespace) -> Dict[str, Any]: Dict[str, Any]: Nested dictionary generated from namespace. """ # Get Dictionary from Namespace Vars - dictionary = vars(namespace) + dictionary = dict(vars(namespace)) # Loop Through Dictionary for key, value in dictionary.items(): diff --git a/vendor/pydantic-argparse/pydantic_argparse/utils/nesting.py b/vendor/pydantic-argparse/pydantic_argparse/utils/nesting.py index 6b18ea923..0bef55acf 100644 --- a/vendor/pydantic-argparse/pydantic_argparse/utils/nesting.py +++ b/vendor/pydantic-argparse/pydantic_argparse/utils/nesting.py @@ -1,7 +1,3 @@ -# SPDX-FileCopyrightText: Hayden Richards -# -# SPDX-License-Identifier: MIT - """Utilities to help with parsing arbitrarily nested `pydantic` models.""" from argparse import Namespace @@ -24,23 +20,20 @@ def __init__( ) -> None: self.model = model self.args = to_dict(namespace) - self.subcommand = False - self.schema: Dict[str, Any] = self._get_nested_model_fields(self.model) + self.schema: Dict[str, Any] = self._get_nested_model_fields(self.model, namespace) self.schema = self._remove_null_leaves(self.schema) - if self.subcommand: - # if there are subcommands, they should only be in the topmost - # level, and the way that the unnesting works is - # that it will populate all subcommands, - # so we need to remove the subcommands that were - # not passed at cli + def _get_nested_model_fields(self, model: ModelT, namespace: Namespace, parent: Optional[Tuple] = None): + def contains_subcommand(namespace: Namespace, subcommand: str): + for name, obj in vars(namespace).items(): + if isinstance(obj, Namespace): + if name == subcommand: + return True + elif contains_subcommand(obj, subcommand): + return True - # the command should be the very first argument - # after executable/file name - command = list(self.args.keys())[0] - self.schema = self._unset_subcommands(self.schema, command) + return False - def _get_nested_model_fields(self, model: ModelT, parent: Optional[Tuple] = None): model_fields: Dict[str, Any] = dict() for field in PydanticField.parse_model(model): @@ -48,15 +41,16 @@ def _get_nested_model_fields(self, model: ModelT, parent: Optional[Tuple] = None if field.is_a(BaseModel): if field.is_subcommand(): - self.subcommand = True + if not contains_subcommand(namespace, key): + continue - new_parent = (*parent, key) if parent is not None else (key,) + parent = (*parent, key) if parent is not None else (key,) # recursively build nestes pydantic models in dict, # which matches the actual schema the nested # schema pydantic will be expecting model_fields[key] = self._get_nested_model_fields( - field.model_type, new_parent + field.model_type, namespace, parent ) else: # start with all leaves as None unless key is in top level @@ -68,12 +62,8 @@ def _get_nested_model_fields(self, model: ModelT, parent: Optional[Tuple] = None # check full path first # TODO: this may not be needed depending on how nested namespaces work # since the arg groups are not nested -- just flattened - full_path = (*parent, key) - value = get_path(self.args, full_path, value) - - if value is None: - short_path = (parent[0], key) - value = get_path(self.args, short_path, value) + path = (*parent, key) + value = get_path(self.args, path, value) model_fields[key] = value @@ -89,9 +79,6 @@ def _remove_null_leaves(self, schema: Dict[str, Any]): # the schema return remap(schema, visit=lambda p, k, v: v is not None) - def _unset_subcommands(self, schema: Dict[str, Any], command: str): - return {key: value for key, value in schema.items() if key == command} - def validate(self): """Return an instance of the `pydantic` modeled validated with data passed from the command line.""" return self.model.model_validate(self.schema)