Skip to content

Commit

Permalink
Merge pull request #3 from AllenInstitute/fix-core-typing
Browse files Browse the repository at this point in the history
Fix core typing
  • Loading branch information
njmei authored May 16, 2024
2 parents a4501de + a05f5c4 commit 8a43297
Show file tree
Hide file tree
Showing 12 changed files with 37 additions and 40 deletions.
5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,7 @@ no_site_packages = false

# Untyped definitions and calls
# https://mypy.readthedocs.io/en/stable/config_file.html#untyped-definitions-and-calls
# TODO: enable and fix errors
check_untyped_defs = false
check_untyped_defs = true

# Miscellaneous strictness flags
# https://mypy.readthedocs.io/en/latest/config_file.html#miscellaneous-strictness-flags
Expand All @@ -170,7 +169,7 @@ show_absolute_path = false

# None and Optional handling
# https://mypy.readthedocs.io/en/latest/config_file.html#none-and-optional-handling
strict_optional = false
strict_optional = true

# Miscellaneous
# https://mypy.readthedocs.io/en/latest/config_file.html#miscellaneous
Expand Down
2 changes: 1 addition & 1 deletion src/aibs_informatics_core/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def __post_init__(self, *args, **kwargs):
"""Default __post_init__ method. Safe parent __post_init__ method calls"""

try:
post_init = super().__post_init__ # type: ignore[attr]
post_init = super().__post_init__ # type: ignore[misc]
except AttributeError:
pass
else:
Expand Down
8 changes: 4 additions & 4 deletions src/aibs_informatics_core/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def load_env_type__from_env(cls) -> EnvType:
return EnvType(env_type)

@classmethod
def load_env_label__from_env(cls) -> str:
def load_env_label__from_env(cls) -> Optional[str]:
return get_env_var(ENV_LABEL_KEY, ENV_LABEL_KEY_ALIAS, LABEL_KEY, LABEL_KEY_ALIAS)


Expand All @@ -197,7 +197,7 @@ def env_base(self) -> EnvBase:
try:
return self._env_base
except AttributeError:
self.env_base = None
self.env_base = None # type: ignore[assignment]
return self._env_base

@env_base.setter
Expand All @@ -213,8 +213,8 @@ def env_base(self, env_base: Optional[EnvBase] = None):


class EnvBaseEnumMixins:
def prefix_with(self, env_base: EnvBase = None) -> str:
env_base = get_env_base(env_base)
def prefix_with(self, env_base: Optional[EnvBase] = None) -> str:
env_base: EnvBase = get_env_base(env_base)
if isinstance(self, Enum):
return env_base.prefixed(self.value)
else:
Expand Down
2 changes: 1 addition & 1 deletion src/aibs_informatics_core/executors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def write_output__file(cls, output: JSON, local_path: Path) -> None:
"""

if local_path.is_dir() or (
local_path.parent.exists and not os.access(local_path.parent, os.W_OK)
local_path.parent.exists() and not os.access(local_path.parent, os.W_OK)
):
raise ValueError(
f"local path specified {local_path} cannot be written to. " f"Must be a file "
Expand Down
10 changes: 7 additions & 3 deletions src/aibs_informatics_core/models/api/http_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,14 @@ def merged_params(self) -> Dict[str, JSON]:
return request_json

@classmethod
def from_stringified_route_params(cls, parameters: Dict[str, str]) -> Dict[str, JSON]:
def from_stringified_route_params(
cls, parameters: Optional[Dict[str, str]]
) -> Dict[str, JSON]:
evaluated_params = dict()

input_params = parameters or dict()
# literal_eval as much as we can to python primitives, our schema will take care of rest.
for k, v in parameters.items():
for k, v in input_params.items():
try:
evaluated_params[k] = ast.literal_eval(urllib.parse.unquote(v))
except Exception:
Expand Down Expand Up @@ -94,7 +98,7 @@ def to_stringified_query_params(cls, parameters: Optional[Dict[str, JSON]]) -> D
return {QUERY_PARAMS_KEY: urlsafe_b64encode(parameters_str.encode()).decode()}

@classmethod
def to_stringified_request_body(cls, parameters: Optional[Dict[str, JSON]]) -> str:
def to_stringified_request_body(cls, parameters: Optional[Dict[str, JSON]]) -> Optional[str]:
if parameters is None or len(parameters) == 0:
return None
return json.dumps(parameters, sort_keys=True)
Expand Down
6 changes: 0 additions & 6 deletions src/aibs_informatics_core/models/base/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,6 @@ def from_path(cls: Type[T], path: Path, **kwargs) -> T:
def to_path(self, path: Path, **kwargs):
... # pragma: no cover

def copy(self: T, **kwargs) -> T:
... # pragma: no cover


# --------------------------------------------------------------
# BaseModel ABC
Expand Down Expand Up @@ -130,9 +127,6 @@ def from_path(cls: Type[M], path: Path, **kwargs) -> M:
def to_path(self, path: Path, **kwargs):
path.write_text(self.to_json(**kwargs))

def copy(self: M, **kwargs) -> M:
return self.from_dict(self.to_dict(**kwargs), **kwargs)


# --------------------------------------------------------------
# DataClassModel
Expand Down
8 changes: 7 additions & 1 deletion src/aibs_informatics_core/models/db/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,13 @@ def __new__(cls, *values):
obj._sort_key_name = values[2]
obj._index_name = values[3]
obj._attributes = values[4] if len(values) > 4 else None
obj._all_values = tuple(values)
obj._all_values = (
obj._value_,
obj._key_name,
obj._sort_key_name,
obj._index_name,
obj._attributes,
)
return obj

@classmethod
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, FrozenSet, Iterable, List, Optional
from typing import Dict, FrozenSet, Iterable, List, Optional, Union

from aibs_informatics_core.models.aws.s3 import S3URI
from aibs_informatics_core.models.base import (
Expand Down Expand Up @@ -105,7 +105,7 @@ def from_pairs(cls, *pairs: ParamPair) -> List[ParamSetPair]:
List[ParamSetPair]: A list of ParamSetPairs
"""
# group pairs only by outputs
output_set_pairs: Dict[str, ParamSetPair] = {}
output_set_pairs: Dict[Union[str, None], ParamSetPair] = {}
for pair in pairs:
if pair.output not in output_set_pairs:
output_set_pairs[pair.output] = ParamSetPair(
Expand Down Expand Up @@ -210,7 +210,7 @@ def from_pairs(cls, *pairs: JobParamPair) -> List[JobParamSetPair]:
List[JobParamSetPair]: A list of JobParamSetPairs
"""
# group pairs only by outputs
output_set_pairs: Dict[ResolvableJobParam, JobParamSetPair] = {}
output_set_pairs: Dict[Union[ResolvableJobParam, None], JobParamSetPair] = {}
for pair in pairs:
if pair.output not in output_set_pairs:
output_set_pairs[pair.output] = JobParamSetPair(
Expand Down
13 changes: 7 additions & 6 deletions src/aibs_informatics_core/models/demand_execution/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,12 @@ def _validate_param_pairs(self):

# Validate no duplicate output sets
seen = set()
duplicate_output_sets = [
s.outputs
for s in self.param_set_pairs
if s.outputs in seen or (s.outputs and seen.add(s.outputs))
]
duplicate_output_sets = []
for s in self.param_set_pairs:
if s.outputs in seen:
duplicate_output_sets.append(s.outputs)
if s.outputs:
seen.add(s.outputs)
if len(duplicate_output_sets) > 0:
raise ValidationError(
"Duplicate output set(s) in input_output_map: " f"{duplicate_output_sets}"
Expand Down Expand Up @@ -234,7 +235,7 @@ def param_set_pairs(self) -> List[ParamSetPair]:
param_set_pairs: List[ParamSetPair] = []
param_pairs: List[ParamPair] = []
if self.param_pair_overrides:
seen_outputs: Set[str] = set()
seen_outputs: Set[Union[str, None]] = set()
for pair in self.param_pair_overrides:
if isinstance(pair, ParamSetPair):
seen_outputs.update(pair.outputs)
Expand Down
1 change: 1 addition & 0 deletions src/aibs_informatics_core/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def f_retry(*args, **kwargs):
[callback(ex) for callback in applicable_callbacks]
):
raise ex
assert logger is not None
logger.warning("%s, Retrying in %d seconds..." % (str(ex), mdelay))
time.sleep(mdelay)
mtries -= 1
Expand Down
8 changes: 4 additions & 4 deletions src/aibs_informatics_core/utils/os_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,8 +406,8 @@ def env_var_overrides(*env_vars: EnvVarItemType):
try:
yield
finally:
for key, value in original_env_vars.items():
if value is None:
del os.environ[key]
for k, v in original_env_vars.items():
if v is None:
del os.environ[k]
else:
os.environ[key] = value
os.environ[k] = v
8 changes: 0 additions & 8 deletions test/aibs_informatics_core/models/base/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,6 @@ def test__BaseModel__to_json__from_json():
assert new_model.a_str == "I'm a string!"


def test__BaseModel__copy():
model = SimpleBaseModel(a_str="I'm a string!", a_int=42)
new_model = model.copy()
assert new_model.a_int == 42
assert new_model.a_str == "I'm a string!"
assert new_model is not model


# ----------------------------------------------------------
# DataClassModel tests
# ----------------------------------------------------------
Expand Down

0 comments on commit 8a43297

Please sign in to comment.