Skip to content

Commit

Permalink
introduce model_info type annotation
Browse files Browse the repository at this point in the history
  • Loading branch information
pearsonca committed Nov 12, 2024
1 parent b1cda6e commit 73c7234
Show file tree
Hide file tree
Showing 12 changed files with 53 additions and 40 deletions.
3 changes: 2 additions & 1 deletion flepimop/gempyor_pkg/src/gempyor/NPI/MultiPeriodModifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
from . import helpers
from .base import NPIBase

from gempyor import model_info

class MultiPeriodModifier(NPIBase):
def __init__(
self,
*,
npi_config,
modinf,
modinf : model_info.ModelInfo,
modifiers_library,
subpops,
loaded_df=None,
Expand Down
3 changes: 2 additions & 1 deletion flepimop/gempyor_pkg/src/gempyor/NPI/SinglePeriodModifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@

from .base import NPIBase

from gempyor import model_info

class SinglePeriodModifier(NPIBase):
def __init__(
self,
*,
npi_config,
modinf,
modinf : model_info.ModelInfo,
modifiers_library,
subpops,
loaded_df=None,
Expand Down
4 changes: 3 additions & 1 deletion flepimop/gempyor_pkg/src/gempyor/NPI/StackedModifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from .base import NPIBase

from gempyor import model_info

debug_print = False

"Cap on # of reduction metadata entries to store in memory"
Expand All @@ -19,7 +21,7 @@ def __init__(
self,
*,
npi_config,
modinf,
modinf : model_info.ModelInfo,
modifiers_library,
subpops,
loaded_df=None,
Expand Down
5 changes: 3 additions & 2 deletions flepimop/gempyor_pkg/src/gempyor/NPI/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc
import pyarrow as pa

import gempyor.model_info as model_info

class NPIBase(abc.ABC):
__plugins__ = {}
Expand All @@ -27,7 +28,7 @@ def getReductionDF(self):
def execute(
*,
npi_config,
modinf,
modinf : model_info.ModelInfo,
modifiers_library,
subpops,
loaded_df=None,
Expand All @@ -36,7 +37,7 @@ def execute(
):
"""
npi_config: config of the Modifier we are building, usually a StackedModifiers that will call other NPI
modinf: the ModelInfor class, to inform ti and tf
modinf: the ModelInfo class, to inform ti and tf
modifiers_library: a config bit that contains the other modifiers that could be called by this Modifier. Note
that the confuse library's config resolution mechanism makes slicing the configuration object expensive;
instead give the preloaded settings from .get()
Expand Down
4 changes: 2 additions & 2 deletions flepimop/gempyor_pkg/src/gempyor/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def simulation_atomic(
*,
snpi_df_in,
hnpi_df_in,
modinf,
modinf : model_info.ModelInfo,
p_draw,
unique_strings,
transition_array,
Expand Down Expand Up @@ -147,7 +147,7 @@ def simulation_atomic(
return outcomes_df


def get_static_arguments(modinf):
def get_static_arguments(modinf : model_info.ModelInfo):
"""
Get the static arguments for the log likelihood function, these are the same for all walkers
"""
Expand Down
16 changes: 9 additions & 7 deletions flepimop/gempyor_pkg/src/gempyor/initial_conditions.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from typing import Dict
import warnings
import os

import numpy as np
import pandas as pd
from numba.typed import Dict
import confuse
import logging

from . import model_info
from .simulation_component import SimulationComponent
from . import utils
from .utils import read_df
import warnings
import os

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -64,7 +66,7 @@ def __init__(
bool
)

def get_from_config(self, sim_id: int, modinf) -> np.ndarray:
def get_from_config(self, sim_id: int, modinf : model_info.ModelInfo) -> np.ndarray:
method = "Default"
if (
self.initial_conditions_config is not None
Expand Down Expand Up @@ -127,14 +129,14 @@ def get_from_config(self, sim_id: int, modinf) -> np.ndarray:

return y0

def get_from_file(self, sim_id: int, modinf) -> np.ndarray:
def get_from_file(self, sim_id: int, modinf : model_info.ModelInfo) -> np.ndarray:
return self.get_from_config(sim_id=sim_id, modinf=modinf)


# TODO: rename config to initial_conditions_config as it shadows the global config


def check_population(y0, modinf, ignore_population_checks=False):
def check_population(y0, modinf : model_info.ModelInfo, ignore_population_checks : bool = False):
# check that the inputed values sums to the subpop population:
error = False
for pl_idx, pl in enumerate(modinf.subpop_struct.subpop_names):
Expand All @@ -161,7 +163,7 @@ def check_population(y0, modinf, ignore_population_checks=False):


def read_initial_condition_from_tidydataframe(
ic_df, modinf, allow_missing_subpops, allow_missing_compartments, proportional_ic=False
ic_df, modinf : model_info.ModelInfo, allow_missing_subpops, allow_missing_compartments, proportional_ic=False
):
rests = [] # Places to allocate the rest of the population
y0 = np.zeros((modinf.compartments.compartments.shape[0], modinf.nsubpops))
Expand Down Expand Up @@ -233,7 +235,7 @@ def read_initial_condition_from_tidydataframe(


def read_initial_condition_from_seir_output(
ic_df, modinf, allow_missing_subpops, allow_missing_compartments
ic_df, modinf : model_info.ModelInfo, allow_missing_subpops, allow_missing_compartments
):
"""
Read the initial conditions from the SEIR output.
Expand Down
8 changes: 4 additions & 4 deletions flepimop/gempyor_pkg/src/gempyor/outcomes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
logger = logging.getLogger(__name__)


def run_parallel_outcomes(modinf, *, sim_id2write, nslots=1, n_jobs=1):
def run_parallel_outcomes(modinf : model_info.ModelInfo, *, sim_id2write, nslots=1, n_jobs=1):
start = time.monotonic()

sim_id2writes = np.arange(sim_id2write, sim_id2write + modinf.nslots)
Expand Down Expand Up @@ -300,7 +300,7 @@ def read_parameters_from_config(modinf: model_info.ModelInfo):
return parameters


def postprocess_and_write(sim_id, modinf, outcomes_df, hpar, npi, write=True):
def postprocess_and_write(sim_id, modinf : model_info.ModelInfo, outcomes_df, hpar, npi, write=True):
if write:
modinf.write_simID(ftype="hosp", sim_id=sim_id, df=outcomes_df)
modinf.write_simID(ftype="hpar", sim_id=sim_id, df=hpar)
Expand Down Expand Up @@ -337,15 +337,15 @@ def dataframe_from_array(data, subpops, dates, comp_name):
return df


def read_seir_sim(modinf, sim_id):
def read_seir_sim(modinf : model_info.ModelInfo, sim_id):
seir_df = modinf.read_simID(ftype="seir", sim_id=sim_id)

return seir_df


def compute_all_multioutcomes(
*,
modinf,
modinf : model_info.ModelInfo,
sim_id2write,
parameters,
loaded_values=None,
Expand Down
2 changes: 1 addition & 1 deletion flepimop/gempyor_pkg/src/gempyor/postprocess_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def plot_single_chain(frompt, ax, chain, label, gt=None):
plt.close(fig)


def plot_fit(modinf, loss):
def plot_fit(modinf : model_info.ModelInfo, loss):
subpop_names = modinf.subpop_struct.subpop_names
fig, axes = plt.subplots(
len(subpop_names),
Expand Down
9 changes: 6 additions & 3 deletions flepimop/gempyor_pkg/src/gempyor/seeding.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
import pandas as pd
import confuse
import logging

from .simulation_component import SimulationComponent
from . import utils
from . import model_info

import numba as nb
import os

Expand All @@ -15,7 +18,7 @@
## TODO: ideally here path_prefix should not be used and all files loaded from modinf


def _DataFrame2NumbaDict(df, amounts, modinf) -> nb.typed.Dict:
def _DataFrame2NumbaDict(df, amounts, modinf : model_info.ModelInfo) -> nb.typed.Dict:
if not df["date"].is_monotonic_increasing:
raise ValueError(
"_DataFrame2NumbaDict got an unsorted dataframe, exposing itself to non-sense"
Expand Down Expand Up @@ -96,7 +99,7 @@ def __init__(self, config: confuse.ConfigView, path_prefix: str = "."):
self.seeding_config = config
self.path_prefix = path_prefix

def get_from_config(self, sim_id: int, modinf) -> nb.typed.Dict:
def get_from_config(self, sim_id: int, modinf : model_info.ModelInfo) -> nb.typed.Dict:
method = "NoSeeding"
if self.seeding_config is not None and "method" in self.seeding_config.keys():
method = self.seeding_config["method"].as_str()
Expand Down Expand Up @@ -165,7 +168,7 @@ def get_from_config(self, sim_id: int, modinf) -> nb.typed.Dict:

return _DataFrame2NumbaDict(df=seeding, amounts=amounts, modinf=modinf)

def get_from_file(self, sim_id: int, modinf) -> nb.typed.Dict:
def get_from_file(self, sim_id: int, modinf : model_info.ModelInfo) -> nb.typed.Dict:
"""only difference with draw seeding is that the sim_id is now sim_id2load"""
return self.get_from_config(sim_id=sim_id, modinf=modinf)

Expand Down
14 changes: 7 additions & 7 deletions flepimop/gempyor_pkg/src/gempyor/seir.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


def build_step_source_arg(
modinf,
modinf : model_info.ModelInfo,
parsed_parameters,
transition_array,
proportion_array,
Expand Down Expand Up @@ -118,7 +118,7 @@ def build_step_source_arg(


def steps_SEIR(
modinf,
modinf : model_info.ModelInfo,
parsed_parameters,
transition_array,
proportion_array,
Expand Down Expand Up @@ -215,7 +215,7 @@ def steps_SEIR(
return states


def build_npi_SEIR(modinf, load_ID, sim_id2load, config, bypass_DF=None, bypass_FN=None):
def build_npi_SEIR(modinf : model_info.ModelInfo, load_ID, sim_id2load, config, bypass_DF=None, bypass_FN=None):
with Timer("SEIR.NPI"):
loaded_df = None
if bypass_DF is not None:
Expand Down Expand Up @@ -335,7 +335,7 @@ def onerun_SEIR(
return out_df


def run_parallel_SEIR(modinf, config, *, n_jobs=1):
def run_parallel_SEIR(modinf : model_info.ModelInfo, config, *, n_jobs=1):
start = time.monotonic()
sim_ids = np.arange(1, modinf.nslots + 1)

Expand Down Expand Up @@ -364,7 +364,7 @@ def run_parallel_SEIR(modinf, config, *, n_jobs=1):
)


def states2Df(modinf, states):
def states2Df(modinf : model_info.ModelInfo, states):
# Tidyup data for R, to save it:
#
# Write output to .snpi.*, .spar.*, and .seir.* files
Expand Down Expand Up @@ -428,7 +428,7 @@ def states2Df(modinf, states):
return out_df


def write_spar_snpi(sim_id, modinf, p_draw, npi):
def write_spar_snpi(sim_id : int, modinf : model_info.ModelInfo, p_draw, npi):
# NPIs
if npi is not None:
modinf.write_simID(ftype="snpi", sim_id=sim_id, df=npi.getReductionDF())
Expand All @@ -438,7 +438,7 @@ def write_spar_snpi(sim_id, modinf, p_draw, npi):
)


def write_seir(sim_id, modinf, states):
def write_seir(sim_id, modinf : model_info.ModelInfo, states):
# print_disk_diagnosis()
out_df = states2Df(modinf, states)
modinf.write_simID(ftype="seir", sim_id=sim_id, df=out_df)
Expand Down
7 changes: 5 additions & 2 deletions flepimop/gempyor_pkg/src/gempyor/subpopulation_structure.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import pathlib
import numpy as np
import pandas as pd
import logging, pathlib

import scipy.sparse
import confuse

from .utils import read_df, write_df
import logging, pathlib


logger = logging.getLogger(__name__)
Expand All @@ -13,7 +16,7 @@


class SubpopulationStructure:
def __init__(self, *, setup_name, subpop_config, path_prefix=pathlib.Path(".")):
def __init__(self, *, setup_name : str, subpop_config : confuse.Subview, path_prefix=pathlib.Path(".")):
"""Important attributes:
- self.setup_name: Name of the setup
- self.data: DataFrame with subpopulations and populations
Expand Down
18 changes: 9 additions & 9 deletions flepimop/gempyor_pkg/src/gempyor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def command_safe_run(
return sr.returncode, stdout, stderr


def add_method(cls):
def add_method(cls : Any):
"""
A function which adds a function to a class.
Expand Down Expand Up @@ -269,7 +269,7 @@ def wrapper(*args, **kwargs):
return inner


def as_list(thing: any) -> list[any]:
def as_list(thing: Any) -> list[Any]:
"""
Returns argument passed as a list.
Expand Down Expand Up @@ -864,13 +864,13 @@ def get_filetype_for_resume(


def create_resume_file_names_map(
resume_discard_seeding,
flepi_block_index,
resume_run_index,
flepi_prefix,
flepi_slot_index,
flepi_run_index,
last_job_output,
resume_discard_seeding : str,
flepi_block_index : str,
resume_run_index : str,
flepi_prefix : str,
flepi_slot_index : str,
flepi_run_index : str,
last_job_output : str,
) -> dict[str, str]:
"""
Generates a mapping of input file names to output file names for a resume process based on
Expand Down

0 comments on commit 73c7234

Please sign in to comment.