From 7a57dc5bb8d072f36fc6553e829922c9930a6b31 Mon Sep 17 00:00:00 2001
From: Timothy Willard <9395586+TimothyWillard@users.noreply.github.com>
Date: Wed, 13 Nov 2024 09:01:37 -0500
Subject: [PATCH] Add `ModelInfo` type hints

Per @pearsonca's request; fixes GH-396. The `ModelInfo` objects in
`gempyor.initial_conditions` and `gempyor.seeding` cannot be type hinted
due to circular imports.
---
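Note for reviewers: the circular-import limitation mentioned above can
usually be worked around with the `typing.TYPE_CHECKING` guard, which makes
the name visible to static type checkers without importing it at runtime.
A minimal sketch of that pattern, not part of this patch (the function name
below is hypothetical):

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Only evaluated by static type checkers, so no runtime import cycle.
        from gempyor.model_info import ModelInfo


    def draw_initial_conditions(modinf: ModelInfo) -> None:
        # With `from __future__ import annotations`, annotations are never
        # evaluated at runtime, so the guarded import above is sufficient.
        ...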
 .../src/gempyor/NPI/MultiPeriodModifier.py    |  6 ++-
 .../src/gempyor/NPI/SinglePeriodModifier.py   |  7 +--
 .../src/gempyor/NPI/StackedModifier.py        |  6 ++-
 flepimop/gempyor_pkg/src/gempyor/NPI/base.py  |  7 +--
 flepimop/gempyor_pkg/src/gempyor/inference.py | 48 +++++--------
 .../gempyor_pkg/src/gempyor/model_info.py     |  1 +
 flepimop/gempyor_pkg/src/gempyor/outcomes.py  | 24 ++++++----
 .../src/gempyor/postprocess_inference.py      | 41 +++-------
 flepimop/gempyor_pkg/src/gempyor/seir.py      | 32 ++++++-----
 .../src/gempyor/subpopulation_structure.py    | 13 +++--
 flepimop/gempyor_pkg/src/gempyor/utils.py     | 20 ++++----
 11 files changed, 90 insertions(+), 115 deletions(-)

diff --git a/flepimop/gempyor_pkg/src/gempyor/NPI/MultiPeriodModifier.py b/flepimop/gempyor_pkg/src/gempyor/NPI/MultiPeriodModifier.py
index 84414e7ff..199ff90d1 100644
--- a/flepimop/gempyor_pkg/src/gempyor/NPI/MultiPeriodModifier.py
+++ b/flepimop/gempyor_pkg/src/gempyor/NPI/MultiPeriodModifier.py
@@ -1,7 +1,9 @@
-import pandas as pd
 import numpy as np
+import pandas as pd
+
 from . import helpers
 from .base import NPIBase
+from ..model_info import ModelInfo
 
 
 class MultiPeriodModifier(NPIBase):
@@ -9,7 +11,7 @@ def __init__(
         self,
         *,
         npi_config,
-        modinf,
+        modinf: ModelInfo,
         modifiers_library,
         subpops,
         loaded_df=None,
diff --git a/flepimop/gempyor_pkg/src/gempyor/NPI/SinglePeriodModifier.py b/flepimop/gempyor_pkg/src/gempyor/NPI/SinglePeriodModifier.py
index cdda3c4b9..22ab2310b 100644
--- a/flepimop/gempyor_pkg/src/gempyor/NPI/SinglePeriodModifier.py
+++ b/flepimop/gempyor_pkg/src/gempyor/NPI/SinglePeriodModifier.py
@@ -1,8 +1,9 @@
-import pandas as pd
 import numpy as np
-from . import helpers
+import pandas as pd
+from . import helpers
 from .base import NPIBase
+from ..model_info import ModelInfo
 
 
 class SinglePeriodModifier(NPIBase):
@@ -10,7 +11,7 @@ def __init__(
         self,
         *,
         npi_config,
-        modinf,
+        modinf: ModelInfo,
         modifiers_library,
         subpops,
         loaded_df=None,
diff --git a/flepimop/gempyor_pkg/src/gempyor/NPI/StackedModifier.py b/flepimop/gempyor_pkg/src/gempyor/NPI/StackedModifier.py
index 6cf178735..91df48cac 100644
--- a/flepimop/gempyor_pkg/src/gempyor/NPI/StackedModifier.py
+++ b/flepimop/gempyor_pkg/src/gempyor/NPI/StackedModifier.py
@@ -1,11 +1,13 @@
 import collections
+import os
 import warnings
 
 import confuse
 import pandas as pd
-import os
 
 from .base import NPIBase
+from ..model_info import ModelInfo
+
 
 debug_print = False
@@ -19,7 +21,7 @@ def __init__(
         self,
         *,
         npi_config,
-        modinf,
+        modinf: ModelInfo,
         modifiers_library,
         subpops,
         loaded_df=None,
diff --git a/flepimop/gempyor_pkg/src/gempyor/NPI/base.py b/flepimop/gempyor_pkg/src/gempyor/NPI/base.py
index 8fd4ed6d6..0f8a08863 100644
--- a/flepimop/gempyor_pkg/src/gempyor/NPI/base.py
+++ b/flepimop/gempyor_pkg/src/gempyor/NPI/base.py
@@ -1,5 +1,6 @@
 import abc
-import pyarrow as pa
+
+from ..model_info import ModelInfo
 
 
 class NPIBase(abc.ABC):
@@ -27,7 +28,7 @@ def getReductionDF(self):
     def execute(
         *,
         npi_config,
-        modinf,
+        modinf: ModelInfo,
         modifiers_library,
         subpops,
         loaded_df=None,
@@ -36,7 +37,7 @@ def execute(
     ):
         """
         npi_config: config of the Modifier we are building, usually a StackedModifiers that will call other NPI
-        modinf: the ModelInfor class, to inform ti and tf
+        modinf: the ModelInfo class, to inform ti and tf
         modifiers_library: a config bit that contains the other modifiers that could be called by this Modifier. Note that the confuse library's config resolution mechanism makes slicing the configuration object expensive; instead give the preloaded settings from .get()
diff --git a/flepimop/gempyor_pkg/src/gempyor/inference.py b/flepimop/gempyor_pkg/src/gempyor/inference.py
index 823ba52b3..fc8b6c66d 100644
--- a/flepimop/gempyor_pkg/src/gempyor/inference.py
+++ b/flepimop/gempyor_pkg/src/gempyor/inference.py
@@ -9,53 +9,29 @@
 # function terminated successfully
 
 
-from . import seir, model_info
-from . import outcomes, file_paths
-from .utils import config, Timer, read_df, as_list
-import numpy as np
 from concurrent.futures import ProcessPoolExecutor
-
-# Logger configuration
+import copy
 import logging
-import os
 import multiprocessing as mp
+import os
+
+import numba as nb
+import numpy as np
 import pandas as pd
 import pyarrow.parquet as pq
 import xarray as xr
-import numba as nb
+
+from . import seir, model_info
+from . import outcomes, file_paths
+from .utils import config, Timer, read_df, as_list
 
 logging.basicConfig(level=os.environ.get("FLEPI_LOGLEVEL", "INFO").upper())
 logger = logging.getLogger()
 handler = logging.StreamHandler()
-# '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
-formatter = logging.Formatter(
-    " %(name)s :: %(levelname)-8s :: %(message)s"
-    # "%(asctime)s [%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s"
-)
-
+formatter = logging.Formatter(" %(name)s :: %(levelname)-8s :: %(message)s")
 handler.setFormatter(formatter)
-# logger.addHandler(handler)
-from . import seir, model_info
-from . import outcomes
-from .utils import config, Timer, read_df
-import numpy as np
-from concurrent.futures import ProcessPoolExecutor
-
-# Logger configuration
-import logging
-import os
-import multiprocessing as mp
-import pandas as pd
-import pyarrow.parquet as pq
-import xarray as xr
-import numba as nb
-import copy
-import matplotlib.pyplot as plt
-import seaborn as sns
-import confuse
-from . import inference_parameter, logloss, statistics
 
 # TODO: should be able to draw e.g from an initial condition folder buuut keep the draw as a blob
 # so it is saved by emcee, so I can build a posterio
@@ -66,7 +42,7 @@ def simulation_atomic(
     *,
     snpi_df_in,
     hnpi_df_in,
-    modinf,
+    modinf: model_info.ModelInfo,
     p_draw,
     unique_strings,
     transition_array,
@@ -147,7 +123,7 @@ def simulation_atomic(
     return outcomes_df
 
 
-def get_static_arguments(modinf):
+def get_static_arguments(modinf: model_info.ModelInfo):
     """
     Get the static arguments for the log likelihood function, these are the same for all walkers
     """
diff --git a/flepimop/gempyor_pkg/src/gempyor/model_info.py b/flepimop/gempyor_pkg/src/gempyor/model_info.py
index 35502aadf..e1ee32a9d 100644
--- a/flepimop/gempyor_pkg/src/gempyor/model_info.py
+++ b/flepimop/gempyor_pkg/src/gempyor/model_info.py
@@ -10,6 +10,7 @@
 )
 from .utils import read_df, write_df
 
+
 logger = logging.getLogger(__name__)
diff --git a/flepimop/gempyor_pkg/src/gempyor/outcomes.py b/flepimop/gempyor_pkg/src/gempyor/outcomes.py
index e5f09c1d2..4c95d0fb7 100644
--- a/flepimop/gempyor_pkg/src/gempyor/outcomes.py
+++ b/flepimop/gempyor_pkg/src/gempyor/outcomes.py
@@ -1,22 +1,24 @@
 import itertools
-import time, random
+import logging
+import time
+
 from numba import jit
-import xarray as xr
 import numpy as np
 import pandas as pd
+import pyarrow as pa
 import tqdm.contrib.concurrent
+import xarray as xr
+
 from .utils import config, Timer, read_df
-import pyarrow as pa
-import pandas as pd
 from . import NPI, model_info
-import logging
-
 
 logger = logging.getLogger(__name__)
 
 
-def run_parallel_outcomes(modinf, *, sim_id2write, nslots=1, n_jobs=1):
+def run_parallel_outcomes(
+    modinf: model_info.ModelInfo, *, sim_id2write, nslots=1, n_jobs=1
+):
     start = time.monotonic()
 
     sim_id2writes = np.arange(sim_id2write, sim_id2write + modinf.nslots)
@@ -300,7 +302,9 @@ def read_parameters_from_config(modinf: model_info.ModelInfo):
     return parameters
 
 
-def postprocess_and_write(sim_id, modinf, outcomes_df, hpar, npi, write=True):
+def postprocess_and_write(
+    sim_id, modinf: model_info.ModelInfo, outcomes_df, hpar, npi, write=True
+):
     if write:
         modinf.write_simID(ftype="hosp", sim_id=sim_id, df=outcomes_df)
         modinf.write_simID(ftype="hpar", sim_id=sim_id, df=hpar)
@@ -337,7 +341,7 @@ def dataframe_from_array(data, subpops, dates, comp_name):
     return df
 
 
-def read_seir_sim(modinf, sim_id):
+def read_seir_sim(modinf: model_info.ModelInfo, sim_id):
     seir_df = modinf.read_simID(ftype="seir", sim_id=sim_id)
 
     return seir_df
@@ -345,7 +349,7 @@ def read_seir_sim(modinf, sim_id):
 
 def compute_all_multioutcomes(
     *,
-    modinf,
+    modinf: model_info.ModelInfo,
     sim_id2write,
     parameters,
     loaded_values=None,
diff --git a/flepimop/gempyor_pkg/src/gempyor/postprocess_inference.py b/flepimop/gempyor_pkg/src/gempyor/postprocess_inference.py
index a98c74e7f..8e79fd90f 100644
--- a/flepimop/gempyor_pkg/src/gempyor/postprocess_inference.py
+++ b/flepimop/gempyor_pkg/src/gempyor/postprocess_inference.py
@@ -1,39 +1,10 @@
-import gempyor
-import numpy as np
-import pandas as pd
-import matplotlib.pyplot as plt
-
-from pathlib import Path
-import copy
-
-# import seaborn as sns
-import matplotlib._color_data as mcd
-import pyarrow.parquet as pq
-import subprocess
-import dask.dataframe as dd
-import matplotlib.dates as mdates
-import matplotlib.cbook as cbook
 from matplotlib.backends.backend_pdf import PdfPages
-from gempyor.utils import config, as_list
-import os
-import multiprocessing as mp
-import pandas as pd
-import pyarrow.parquet as pq
-import xarray as xr
-from gempyor import (
-    config,
-    model_info,
-    outcomes,
-    seir,
-    inference_parameter,
-    logloss,
-    inference,
-)
-from gempyor.inference import GempyorInference
-import tqdm
-import os
-from multiprocessing import cpu_count
+import matplotlib.pyplot as plt
+import numpy as np
 import seaborn as sns
+import tqdm
+
+from .model_info import ModelInfo
 
 
 def find_walkers_to_sample(inferpar, sampler_output, nsamples, nwalker, nthin):
@@ -145,7 +116,7 @@ def plot_single_chain(frompt, ax, chain, label, gt=None):
     plt.close(fig)
 
 
-def plot_fit(modinf, loss):
+def plot_fit(modinf: ModelInfo, loss):
     subpop_names = modinf.subpop_struct.subpop_names
     fig, axes = plt.subplots(
         len(subpop_names),
diff --git a/flepimop/gempyor_pkg/src/gempyor/seir.py b/flepimop/gempyor_pkg/src/gempyor/seir.py
index 5ea236c98..60c48bf24 100644
--- a/flepimop/gempyor_pkg/src/gempyor/seir.py
+++ b/flepimop/gempyor_pkg/src/gempyor/seir.py
@@ -1,20 +1,23 @@
 import itertools
+import logging
 import time
+
 import numpy as np
 import pandas as pd
 import scipy
 import tqdm.contrib.concurrent
 import xarray as xr
 
-from . import NPI, model_info, steps_rk4
-from .utils import Timer, print_disk_diagnosis, read_df
-import logging
+from . import NPI, steps_rk4
+from .model_info import ModelInfo
+from .utils import Timer, read_df
+
 
 logger = logging.getLogger(__name__)
 
 
 def build_step_source_arg(
-    modinf,
+    modinf: ModelInfo,
     parsed_parameters,
     transition_array,
     proportion_array,
@@ -118,7 +121,7 @@ def build_step_source_arg(
 
 
 def steps_SEIR(
-    modinf,
+    modinf: ModelInfo,
     parsed_parameters,
     transition_array,
     proportion_array,
@@ -215,7 +218,14 @@ def steps_SEIR(
     return states
 
 
-def build_npi_SEIR(modinf, load_ID, sim_id2load, config, bypass_DF=None, bypass_FN=None):
+def build_npi_SEIR(
+    modinf: ModelInfo,
+    load_ID,
+    sim_id2load,
+    config,
+    bypass_DF=None,
+    bypass_FN=None,
+):
     with Timer("SEIR.NPI"):
         loaded_df = None
         if bypass_DF is not None:
@@ -257,7 +267,7 @@ def build_npi_SEIR(modinf, load_ID, sim_id2load, config, bypass_DF=None, bypass_
 
 def onerun_SEIR(
     sim_id2write: int,
-    modinf: model_info.ModelInfo,
+    modinf: ModelInfo,
     load_ID: bool = False,
     sim_id2load: int = None,
     config=None,
@@ -335,7 +345,7 @@ def onerun_SEIR(
     return out_df
 
 
-def run_parallel_SEIR(modinf, config, *, n_jobs=1):
+def run_parallel_SEIR(modinf: ModelInfo, config, *, n_jobs=1):
     start = time.monotonic()
     sim_ids = np.arange(1, modinf.nslots + 1)
 
@@ -364,7 +374,7 @@ def run_parallel_SEIR(modinf, config, *, n_jobs=1):
     )
 
 
-def states2Df(modinf, states):
+def states2Df(modinf: ModelInfo, states):
     # Tidyup data for R, to save it:
     #
     # Write output to .snpi.*, .spar.*, and .seir.* files
@@ -428,7 +438,7 @@ def states2Df(modinf, states):
     return out_df
 
 
-def write_spar_snpi(sim_id, modinf, p_draw, npi):
+def write_spar_snpi(sim_id: int, modinf: ModelInfo, p_draw, npi):
     # NPIs
     if npi is not None:
         modinf.write_simID(ftype="snpi", sim_id=sim_id, df=npi.getReductionDF())
@@ -438,7 +448,7 @@ def write_spar_snpi(sim_id: int, modinf: ModelInfo, p_draw, npi):
     )
 
 
-def write_seir(sim_id, modinf, states):
+def write_seir(sim_id, modinf: ModelInfo, states):
     # print_disk_diagnosis()
     out_df = states2Df(modinf, states)
     modinf.write_simID(ftype="seir", sim_id=sim_id, df=out_df)
diff --git a/flepimop/gempyor_pkg/src/gempyor/subpopulation_structure.py b/flepimop/gempyor_pkg/src/gempyor/subpopulation_structure.py
index 7c46a9960..d4b729ddb 100644
--- a/flepimop/gempyor_pkg/src/gempyor/subpopulation_structure.py
+++ b/flepimop/gempyor_pkg/src/gempyor/subpopulation_structure.py
@@ -1,9 +1,10 @@
+import logging
 import pathlib
+
+import confuse
 import numpy as np
 import pandas as pd
 import scipy.sparse
-from .utils import read_df, write_df
-import logging, pathlib
 
 logger = logging.getLogger(__name__)
@@ -13,7 +14,13 @@
 
 
 class SubpopulationStructure:
-    def __init__(self, *, setup_name, subpop_config, path_prefix=pathlib.Path(".")):
+    def __init__(
+        self,
+        *,
+        setup_name: str,
+        subpop_config: confuse.Subview,
+        path_prefix=pathlib.Path("."),
+    ):
         """Important attributes:
         - self.setup_name: Name of the setup
         - self.data: DataFrame with subpopulations and populations
diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py
index 377dcdee5..a240be71f 100644
--- a/flepimop/gempyor_pkg/src/gempyor/utils.py
+++ b/flepimop/gempyor_pkg/src/gempyor/utils.py
@@ -18,7 +18,7 @@
 import scipy.stats
 import sympy.parsing.sympy_parser
 
-from gempyor import file_paths
+from . import file_paths
 
 logger = logging.getLogger(__name__)
@@ -141,7 +141,7 @@ def command_safe_run(
     return sr.returncode, stdout, stderr
 
 
-def add_method(cls):
+def add_method(cls: Any):
     """
     A function which adds a function to a class.
@@ -269,7 +269,7 @@ def wrapper(*args, **kwargs):
 
     return inner
 
 
-def as_list(thing: any) -> list[any]:
+def as_list(thing: Any) -> list[Any]:
     """
     Returns argument passed as a list.
@@ -864,13 +864,13 @@ def get_filetype_for_resume(
 
 
 def create_resume_file_names_map(
-    resume_discard_seeding,
-    flepi_block_index,
-    resume_run_index,
-    flepi_prefix,
-    flepi_slot_index,
-    flepi_run_index,
-    last_job_output,
+    resume_discard_seeding: str,
+    flepi_block_index: str,
+    resume_run_index: str,
+    flepi_prefix: str,
+    flepi_slot_index: str,
+    flepi_run_index: str,
+    last_job_output: str,
 ) -> dict[str, str]:
     """
     Generates a mapping of input file names to output file names for a resume process based on