Skip to content

Commit

Permalink
Merge pull request #635 from ncoop57/type_casting
Browse files Browse the repository at this point in the history
Add support for type casting in the typed decorator
  • Loading branch information
jph00 authored Oct 9, 2024
2 parents 77afd07 + 9474b9b commit f4a05cf
Show file tree
Hide file tree
Showing 3 changed files with 407 additions and 89 deletions.
9 changes: 9 additions & 0 deletions fastcore/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,17 @@
'fastcore.basics.stop': ('basics.html#stop', 'fastcore/basics.py'),
'fastcore.basics.store_attr': ('basics.html#store_attr', 'fastcore/basics.py'),
'fastcore.basics.str2bool': ('basics.html#str2bool', 'fastcore/basics.py'),
'fastcore.basics.str2date': ('basics.html#str2date', 'fastcore/basics.py'),
'fastcore.basics.str2float': ('basics.html#str2float', 'fastcore/basics.py'),
'fastcore.basics.str2int': ('basics.html#str2int', 'fastcore/basics.py'),
'fastcore.basics.str2list': ('basics.html#str2list', 'fastcore/basics.py'),
'fastcore.basics.str_enum': ('basics.html#str_enum', 'fastcore/basics.py'),
'fastcore.basics.strcat': ('basics.html#strcat', 'fastcore/basics.py'),
'fastcore.basics.to_bool': ('basics.html#to_bool', 'fastcore/basics.py'),
'fastcore.basics.to_date': ('basics.html#to_date', 'fastcore/basics.py'),
'fastcore.basics.to_float': ('basics.html#to_float', 'fastcore/basics.py'),
'fastcore.basics.to_int': ('basics.html#to_int', 'fastcore/basics.py'),
'fastcore.basics.to_list': ('basics.html#to_list', 'fastcore/basics.py'),
'fastcore.basics.tonull': ('basics.html#tonull', 'fastcore/basics.py'),
'fastcore.basics.true': ('basics.html#true', 'fastcore/basics.py'),
'fastcore.basics.try_attrs': ('basics.html#try_attrs', 'fastcore/basics.py'),
Expand Down
135 changes: 97 additions & 38 deletions fastcore/basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,29 @@

# %% auto 0
__all__ = ['defaults', 'null', 'num_methods', 'rnum_methods', 'inum_methods', 'arg0', 'arg1', 'arg2', 'arg3', 'arg4', 'Self',
'ifnone', 'maybe_attr', 'basic_repr', 'BasicRepr', 'is_array', 'listify', 'tuplify', 'true', 'NullType',
'tonull', 'get_class', 'mk_class', 'wrap_class', 'ignore_exceptions', 'exec_local', 'risinstance',
'ver2tuple', 'Inf', 'in_', 'ret_true', 'ret_false', 'stop', 'gen', 'chunked', 'otherwise', 'custom_dir',
'AttrDict', 'AttrDictDefault', 'NS', 'get_annotations_ex', 'eval_type', 'type_hints', 'annotations',
'anno_ret', 'signature_ex', 'union2tuple', 'argnames', 'with_cast', 'store_attr', 'attrdict', 'properties',
'camel2words', 'camel2snake', 'snake2camel', 'class2attr', 'getcallable', 'getattrs', 'hasattrs', 'setattrs',
'try_attrs', 'GetAttrBase', 'GetAttr', 'delegate_attr', 'ShowPrint', 'Int', 'Str', 'Float', 'partition',
'flatten', 'concat', 'strcat', 'detuplify', 'replicate', 'setify', 'merge', 'range_of', 'groupby',
'last_index', 'filter_dict', 'filter_keys', 'filter_values', 'cycle', 'zip_cycle', 'sorted_ex', 'not_',
'argwhere', 'filter_ex', 'renumerate', 'first', 'only', 'nested_attr', 'nested_setdefault',
'type_map', 'ifnone', 'maybe_attr', 'basic_repr', 'BasicRepr', 'is_array', 'listify', 'tuplify', 'true',
'NullType', 'tonull', 'get_class', 'mk_class', 'wrap_class', 'ignore_exceptions', 'exec_local',
'risinstance', 'ver2tuple', 'Inf', 'in_', 'ret_true', 'ret_false', 'stop', 'gen', 'chunked', 'otherwise',
'custom_dir', 'AttrDict', 'AttrDictDefault', 'NS', 'get_annotations_ex', 'eval_type', 'type_hints',
'annotations', 'anno_ret', 'signature_ex', 'union2tuple', 'argnames', 'with_cast', 'store_attr', 'attrdict',
'properties', 'camel2words', 'camel2snake', 'snake2camel', 'class2attr', 'getcallable', 'getattrs',
'hasattrs', 'setattrs', 'try_attrs', 'GetAttrBase', 'GetAttr', 'delegate_attr', 'ShowPrint', 'Int', 'Str',
'Float', 'partition', 'flatten', 'concat', 'strcat', 'detuplify', 'replicate', 'setify', 'merge', 'range_of',
'groupby', 'last_index', 'filter_dict', 'filter_keys', 'filter_values', 'cycle', 'zip_cycle', 'sorted_ex',
'not_', 'argwhere', 'filter_ex', 'renumerate', 'first', 'only', 'nested_attr', 'nested_setdefault',
'nested_callable', 'nested_idx', 'set_nested_idx', 'val2idx', 'uniqueify', 'loop_first_last', 'loop_first',
'loop_last', 'first_match', 'last_match', 'fastuple', 'bind', 'mapt', 'map_ex', 'compose', 'maps',
'partialler', 'instantiate', 'using_attr', 'copy_func', 'patch_to', 'patch', 'patch_property', 'compile_re',
'ImportEnum', 'StrEnum', 'str_enum', 'ValEnum', 'Stateful', 'NotStr', 'PrettyString', 'even_mults',
'num_cpus', 'add_props', 'typed', 'exec_new', 'exec_import', 'str2bool', 'lt', 'gt', 'le', 'ge', 'eq', 'ne',
'num_cpus', 'add_props', 'str2bool', 'str2int', 'str2float', 'str2list', 'str2date', 'to_bool', 'to_int',
'to_float', 'to_list', 'to_date', 'typed', 'exec_new', 'exec_import', 'lt', 'gt', 'le', 'ge', 'eq', 'ne',
'add', 'sub', 'mul', 'truediv', 'is_', 'is_not', 'mod']

# %% ../nbs/01_basics.ipynb
from .imports import *
import builtins,types,typing
import pprint
import ast,builtins,pprint,types,typing
from copy import copy
from datetime import date
try: from types import UnionType
except ImportError: UnionType = None

Expand Down Expand Up @@ -1141,21 +1142,89 @@ def add_props(f, g=None, n=2):
def _typeerr(arg, val, typ): return TypeError(f"{arg}=={val} not {typ}")

# %% ../nbs/01_basics.ipynb
def typed(f):
"Decorator to check param and return types at runtime"
names = f.__code__.co_varnames
anno = annotations(f)
ret = anno.pop('return',None)
def _f(*args,**kwargs):
kw = {**kwargs}
if len(anno) > 0:
for i,arg in enumerate(args): kw[names[i]] = arg
for k,v in kw.items():
if k in anno and not isinstance(v,anno[k]): raise _typeerr(k, v, anno[k])
res = f(*args,**kwargs)
if ret is not None and not isinstance(res,ret): raise _typeerr("return", res, ret)
return res
return functools.update_wrapper(_f, f)
def str2bool(s):
"Case-insensitive convert string `s` too a bool (`y`,`yes`,`t`,`true`,`on`,`1`->`True`)"
if not isinstance(s,str): return bool(s)
if not s: return False
s = s.lower()
if s in ('y', 'yes', 't', 'true', 'on', '1'): return True
elif s in ('n', 'no', 'f', 'false', 'off', '0'): return False
else: raise _typeerr('s', s, 'bool')

# %% ../nbs/01_basics.ipynb
def str2int(s) -> int:
"Convert `s` to an `int`"
s = s.lower().strip()
if s in ('', 'none'): return 0
if s == 'on': return 1
if s == 'off': return 0
return int(s)

# %% ../nbs/01_basics.ipynb
def str2float(s:str):
"Convert `s` to a float"
s = s.lower().strip()
if not s: return 0.0
return float(s)

# %% ../nbs/01_basics.ipynb
def str2list(s:str):
"Convert `s` to a list"
s = s.strip()
if not s: return []
if s[0] != '[': s = '['+s + ']'
return ast.literal_eval(s)

# %% ../nbs/01_basics.ipynb
def str2date(s:str)->date:
"`date.fromisoformat` with empty string handling"
return date.fromisoformat(s) if s else None

# %% ../nbs/01_basics.ipynb
def to_bool(arg): return str2bool(arg) if isinstance(arg, str) else bool(arg)
def to_int(arg): return str2int(arg) if isinstance(arg, str) else int(arg)
def to_float(arg): return str2float(arg) if isinstance(arg, str) else float(arg)
def to_list(arg): return str2list(arg) if isinstance(arg,str) else listify(arg)
def to_date(arg):
if isinstance(arg, str): return str2date(arg)
raise _typeerr('arg', arg, 'date')

type_map = {int: to_int, float: to_float, str: str, bool: to_bool, date: to_date}

# %% ../nbs/01_basics.ipynb
def typed(_func=None, *, cast=False):
"Decorator to check param and return types at runtime, with optional casting"
def decorator(f):
names = f.__code__.co_varnames
anno = annotations(f)
ret = anno.pop('return', None)
def _f(*args, **kwargs):
kw = {**kwargs}
if len(anno) > 0:
for i,arg in enumerate(args): kw[names[i]] = arg
for k,v in kw.items():
if k in anno:
expected_types = tuplify(union2tuple(anno[k]))
if isinstance(v, expected_types): continue
elif cast:
expected_types = listify(filter(lambda x: x is not NoneType, expected_types))
assert not len(expected_types) > 1, "Cannot cast with union types."
# Grab the first type that is not None
expected_type = expected_types[0]
try: kw[k] = type_map.get(expected_type, expected_type)(v)
except (ValueError, TypeError) as e: raise _typeerr(k, v, expected_type) from e
else: raise _typeerr(k, v, expected_types)
res = f(**kw)
if ret is not None:
if isinstance(res, ret): return res
elif cast:
try: res = type_map.get(ret, ret)(res)
except (ValueError, TypeError) as e: raise _typeerr("return", res, ret) from e
else: raise _typeerr("return", res, ret)
return res
return functools.update_wrapper(_f, f)
if _func is None: return decorator # Decorator was called with arguments
else: return decorator(_func) # Decorator was called without arguments

# %% ../nbs/01_basics.ipynb
def exec_new(code):
Expand All @@ -1170,13 +1239,3 @@ def exec_import(mod, sym):
"Import `sym` from `mod` in a new environment"
# pref = '' if __name__=='__main__' or mod[0]=='.' else '.'
return exec_new(f'from {mod} import {sym}')

# %% ../nbs/01_basics.ipynb
def str2bool(s):
"Case-insensitive convert string `s` too a bool (`y`,`yes`,`t`,`true`,`on`,`1`->`True`)"
if not isinstance(s,str): return bool(s)
if not s: return False
s = s.lower()
if s in ('y', 'yes', 't', 'true', 'on', '1'): return True
elif s in ('n', 'no', 'f', 'false', 'off', '0'): return False
else: raise ValueError()
Loading

0 comments on commit f4a05cf

Please sign in to comment.