Skip to content

Commit

Permalink
fully support CatchAll
Browse files Browse the repository at this point in the history
  • Loading branch information
rnag committed Dec 11, 2024
1 parent dde2d83 commit 9a7725a
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 9 deletions.
24 changes: 15 additions & 9 deletions dataclass_wizard/v1/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,7 +1016,7 @@ def load_func_for_dataclass(
fn_gen.add_line('re_raise(e, cls, o, fields, field, v1)')


req_field_and_var = []
vars_for_fields = []

if cls_init_fields:

Expand Down Expand Up @@ -1058,8 +1058,8 @@ def load_func_for_dataclass(

else:
# TODO confirm this is ok
# req_field_and_var.append(f'{name}={var}')
req_field_and_var.append(var)
# vars_for_fields.append(f'{name}={var}')
vars_for_fields.append(var)

fn_gen.add_line(f_assign)
with fn_gen.if_(f'{val} is not MISSING'):
Expand All @@ -1075,18 +1075,24 @@ def load_func_for_dataclass(
# add an alias for the tag key, so we don't capture it
field_to_alias['...'] = meta.tag_key

# TODO for Auto
_locals['aliases'] = set(field_to_alias.values())
if 'f2k' in _locals:
# If this is the case, then `AUTO` key transform mode is enabled
# line = 'extra_keys = o.keys() - f2k.values()'
aliases_var = 'f2k.values()'

else:
aliases_var = 'aliases'
_locals['aliases'] = set(field_to_alias.values())

catch_all_def = '{k: o[k] for k in o if k not in aliases}'
catch_all_def = f'{{k: o[k] for k in o if k not in {aliases_var}}}'

if catch_all_field.endswith('?'): # Default value
with fn_gen.if_('len(o) != i'):
fn_gen.add_line(f'init_kwargs[{catch_all_field_stripped!r}] = {catch_all_def}')
else:
var = f'__{catch_all_field_stripped}'
fn_gen.add_line(f'{var} = {{}} if len(o) == i else {catch_all_def}')
req_field_and_var.insert(catch_all_idx, var)
vars_for_fields.insert(catch_all_idx, var)

elif should_warn or should_raise:
if expect_tag_as_unknown_key:
Expand Down Expand Up @@ -1117,8 +1123,8 @@ def load_func_for_dataclass(
# we raise them here.

if has_defaults:
req_field_and_var.append('**init_kwargs')
init_parts = ', '.join(req_field_and_var)
vars_for_fields.append('**init_kwargs')
init_parts = ', '.join(vars_for_fields)
with fn_gen.try_():
fn_gen.add_line(f"return cls({init_parts})")
with fn_gen.except_(UnboundLocalError):
Expand Down
25 changes: 25 additions & 0 deletions tests/unit/v1/test_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -2252,6 +2252,31 @@ class _(JSONWizard.Meta):
assert new_data.extra_data is False


def test_catch_all_with_auto_key_case():
"""'Catch All' with `auto` key case."""

@dataclass
class Options(JSONWizard):
class _(JSONWizard.Meta):
v1 = True
debug_enabled = True
v1_key_case = 'Auto'

my_extras: CatchAll
email: str

opt = Options.from_dict({
'Email': '[email protected]',
'token': '<PASSWORD>',
})
assert opt == Options(my_extras={'token': '<PASSWORD>'}, email='[email protected]')

opt = Options.from_dict({
'Email': '[email protected]',
})
assert opt == Options(my_extras={}, email='[email protected]')


@pytest.mark.xfail(reason='TODO add support in v1')
def test_from_dict_with_nested_object_key_path():
"""
Expand Down

0 comments on commit 9a7725a

Please sign in to comment.