Skip to content

Commit

Permalink
Fix typeguard
Browse files Browse the repository at this point in the history
  • Loading branch information
caneff committed Oct 17, 2023
1 parent 7e94cd6 commit 7b9b666
Showing 1 changed file with 49 additions and 34 deletions.
83 changes: 49 additions & 34 deletions strictly_typed_pandas/typeguard.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,82 @@
import inspect
import typeguard

from typing import Any, Tuple, Union
from strictly_typed_pandas import DataSet, IndexedDataSet


def check_dataset(argname: str, value, expected_type, memo: typeguard.TypeCheckMemo) -> None:
schema_expected = expected_type.__args__[0]
def check_dataset(value: Any, origin_type: Any, args: Tuple[Any, ...], memo: typeguard.TypeCheckMemo) -> None:
schema_expected = args[0]
if not isinstance(value, DataSet):
msg = "Type of {argname} must be a DataSet[{schema_expected}]; got {class_observed} instead"
raise TypeError(
msg.format(
argname=argname,
msg = "Type must be a DataSet[{schema_expected}]; got {class_observed} instead".format(
schema_expected=typeguard.qualified_name(schema_expected),
class_observed=typeguard.qualified_name(value)
)
)
if memo.config.typecheck_fail_callback:
memo.config.typecheck_fail_callback(typeguard.TypeCheckError(msg), memo)
else:
raise TypeError(msg)

schema_observed = value.__orig_class__.__args__[0]
if schema_observed != schema_expected:
msg = "Type of {argname} must be a DataSet[{schema_expected}]; got DataSet[{schema_observed}] instead"
raise TypeError(
msg.format(
argname=argname,
msg = "Type must be a DataSet[{schema_expected}]; got DataSet[{schema_observed}] instead".format(
schema_expected=typeguard.qualified_name(schema_expected),
schema_observed=typeguard.qualified_name(schema_observed)
)
)
if memo.config.typecheck_fail_callback:
memo.config.typecheck_fail_callback(typeguard.TypeCheckError(msg), memo)
else:
raise TypeError(msg)


def check_indexed_dataset(argname: str, value, expected_type, memo: typeguard.TypeCheckMemo):
schema_index_expected = expected_type.__args__[0]
schema_data_expected = expected_type.__args__[1]
def check_indexed_dataset(value: Any, origin_type: Any, args: Tuple[Any, ...], memo: typeguard.TypeCheckMemo) -> None:
schema_index_expected = args[0]
schema_data_expected = args[1]
if not isinstance(value, IndexedDataSet):
msg = (
"Type of {argname} must be a IndexedDataSet[{schema_index_expected},{schema_data_expected}];" +
"Type must be a IndexedDataSet[{schema_index_expected},{schema_data_expected}];" +
"got {class_observed} instead"
)
raise TypeError(
msg.format(
argname=argname,
).format(
schema_index_expected=typeguard.qualified_name(schema_index_expected),
schema_data_expected=typeguard.qualified_name(schema_data_expected),
class_observed=typeguard.qualified_name(value)
)
)

if memo.config.typecheck_fail_callback:
memo.config.typecheck_fail_callback(typeguard.TypeCheckError(msg), memo)
else:
raise TypeError(msg)

schema_index_observed = value.__orig_class__.__args__[0]
schema_data_observed = value.__orig_class__.__args__[1]
if schema_index_observed != schema_index_expected or schema_data_observed != schema_data_expected:
msg = (
"Type of {argname} must be a IndexedDataSet[{schema_index_expected},{schema_data_expected}];" +
"Type must be a IndexedDataSet[{schema_index_expected},{schema_data_expected}];" +
"got IndexedDataSet[{schema_index_observed},{schema_data_observed}] instead"
).format(
schema_index_expected=typeguard.qualified_name(schema_index_expected),
schema_data_expected=typeguard.qualified_name(schema_data_expected),
schema_index_observed=typeguard.qualified_name(schema_index_observed),
schema_data_observed=typeguard.qualified_name(schema_data_observed)
)
raise TypeError(
msg.format(
argname=argname,
schema_index_expected=typeguard.qualified_name(schema_index_expected),
schema_data_expected=typeguard.qualified_name(schema_data_expected),
schema_index_observed=typeguard.qualified_name(schema_index_observed),
schema_data_observed=typeguard.qualified_name(schema_data_observed)
)
)
if memo.config.typecheck_fail_callback:
memo.config.typecheck_fail_callback(typeguard.TypeCheckError(msg), memo)
else:
raise TypeError(msg)


def check_dataset_lookup(origin_type: Any,
args: Tuple[Any, ...], extras: Tuple[Any, ...]) -> Union[typeguard.TypeCheckerCallable, None]:

if not inspect.isclass(origin_type):
return None

if issubclass(origin_type, DataSet):
return check_dataset
if issubclass(origin_type, IndexedDataSet):
return check_indexed_dataset

return None


typeguard._checkers.origin_type_checkers[DataSet] = check_dataset
typeguard._checkers.origin_type_checkers[IndexedDataSet] = check_indexed_dataset
typeguard.checker_lookup_functions.append(check_dataset_lookup)
typechecked = typeguard.typechecked

0 comments on commit 7b9b666

Please sign in to comment.