Skip to content

Commit

Permalink
Add ModelInfo type hints
Browse files Browse the repository at this point in the history
Per @pearsonca's request, fixed GH-396. Cannot type hint the `ModelInfo`
objects in `gempyor.initial_conditions/seeding` due to circular imports.
  • Loading branch information
TimothyWillard authored and pearsonca committed Nov 13, 2024
1 parent b1cda6e commit 7a57dc5
Show file tree
Hide file tree
Showing 11 changed files with 90 additions and 115 deletions.
6 changes: 4 additions & 2 deletions flepimop/gempyor_pkg/src/gempyor/NPI/MultiPeriodModifier.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import pandas as pd
import numpy as np
import pandas as pd

from . import helpers
from .base import NPIBase
from ..model_info import ModelInfo


class MultiPeriodModifier(NPIBase):
def __init__(
self,
*,
npi_config,
modinf,
modinf: ModelInfo,
modifiers_library,
subpops,
loaded_df=None,
Expand Down
7 changes: 4 additions & 3 deletions flepimop/gempyor_pkg/src/gempyor/NPI/SinglePeriodModifier.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import pandas as pd
import numpy as np
from . import helpers
import pandas as pd

from . import helpers
from .base import NPIBase
from ..model_info import ModelInfo


class SinglePeriodModifier(NPIBase):
def __init__(
self,
*,
npi_config,
modinf,
modinf: ModelInfo,
modifiers_library,
subpops,
loaded_df=None,
Expand Down
6 changes: 4 additions & 2 deletions flepimop/gempyor_pkg/src/gempyor/NPI/StackedModifier.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import collections
import os
import warnings

import confuse
import pandas as pd
import os

from .base import NPIBase
from ..model_info import ModelInfo


debug_print = False

Expand All @@ -19,7 +21,7 @@ def __init__(
self,
*,
npi_config,
modinf,
modinf: ModelInfo,
modifiers_library,
subpops,
loaded_df=None,
Expand Down
7 changes: 4 additions & 3 deletions flepimop/gempyor_pkg/src/gempyor/NPI/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import abc
import pyarrow as pa

from ..model_info import ModelInfo


class NPIBase(abc.ABC):
Expand Down Expand Up @@ -27,7 +28,7 @@ def getReductionDF(self):
def execute(
*,
npi_config,
modinf,
modinf: ModelInfo,
modifiers_library,
subpops,
loaded_df=None,
Expand All @@ -36,7 +37,7 @@ def execute(
):
"""
npi_config: config of the Modifier we are building, usually a StackedModifiers that will call other NPI
modinf: the ModelInfor class, to inform ti and tf
modinf: the ModelInfo class, to inform ti and tf
modifiers_library: a config bit that contains the other modifiers that could be called by this Modifier. Note
that the confuse library's config resolution mechanism makes slicing the configuration object expensive;
instead give the preloaded settings from .get()
Expand Down
48 changes: 12 additions & 36 deletions flepimop/gempyor_pkg/src/gempyor/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,53 +9,29 @@
# function terminated successfully


from . import seir, model_info
from . import outcomes, file_paths
from .utils import config, Timer, read_df, as_list
import numpy as np
from concurrent.futures import ProcessPoolExecutor

# Logger configuration
import copy
import logging
import os
import multiprocessing as mp
import os

import numba as nb
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import xarray as xr
import numba as nb

from . import seir, model_info
from . import outcomes, file_paths
from .utils import config, Timer, read_df, as_list


logging.basicConfig(level=os.environ.get("FLEPI_LOGLEVEL", "INFO").upper())
logger = logging.getLogger()
handler = logging.StreamHandler()
# '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
formatter = logging.Formatter(
" %(name)s :: %(levelname)-8s :: %(message)s"
# "%(asctime)s [%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s"
)

formatter = logging.Formatter(" %(name)s :: %(levelname)-8s :: %(message)s")
handler.setFormatter(formatter)
# logger.addHandler(handler)

from . import seir, model_info
from . import outcomes
from .utils import config, Timer, read_df
import numpy as np
from concurrent.futures import ProcessPoolExecutor

# Logger configuration
import logging
import os
import multiprocessing as mp
import pandas as pd
import pyarrow.parquet as pq
import xarray as xr
import numba as nb
import copy
import matplotlib.pyplot as plt
import seaborn as sns
import confuse
from . import inference_parameter, logloss, statistics

# TODO: should be able to draw e.g. from an initial condition folder, but keep the draw as a blob
# so it is saved by emcee, so I can build a posterior
Expand All @@ -66,7 +42,7 @@ def simulation_atomic(
*,
snpi_df_in,
hnpi_df_in,
modinf,
modinf: model_info.ModelInfo,
p_draw,
unique_strings,
transition_array,
Expand Down Expand Up @@ -147,7 +123,7 @@ def simulation_atomic(
return outcomes_df


def get_static_arguments(modinf):
def get_static_arguments(modinf: model_info.ModelInfo):
"""
Get the static arguments for the log likelihood function, these are the same for all walkers
"""
Expand Down
1 change: 1 addition & 0 deletions flepimop/gempyor_pkg/src/gempyor/model_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
)
from .utils import read_df, write_df


logger = logging.getLogger(__name__)


Expand Down
24 changes: 14 additions & 10 deletions flepimop/gempyor_pkg/src/gempyor/outcomes.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
import itertools
import time, random
import logging
import time

from numba import jit
import xarray as xr
import numpy as np
import pandas as pd
import pyarrow as pa
import tqdm.contrib.concurrent
import xarray as xr

from .utils import config, Timer, read_df
import pyarrow as pa
import pandas as pd
from . import NPI, model_info


import logging

logger = logging.getLogger(__name__)


def run_parallel_outcomes(modinf, *, sim_id2write, nslots=1, n_jobs=1):
def run_parallel_outcomes(
modinf: model_info.ModelInfo, *, sim_id2write, nslots=1, n_jobs=1
):
start = time.monotonic()

sim_id2writes = np.arange(sim_id2write, sim_id2write + modinf.nslots)
Expand Down Expand Up @@ -300,7 +302,9 @@ def read_parameters_from_config(modinf: model_info.ModelInfo):
return parameters


def postprocess_and_write(sim_id, modinf, outcomes_df, hpar, npi, write=True):
def postprocess_and_write(
sim_id, modinf: model_info.ModelInfo, outcomes_df, hpar, npi, write=True
):
if write:
modinf.write_simID(ftype="hosp", sim_id=sim_id, df=outcomes_df)
modinf.write_simID(ftype="hpar", sim_id=sim_id, df=hpar)
Expand Down Expand Up @@ -337,15 +341,15 @@ def dataframe_from_array(data, subpops, dates, comp_name):
return df


def read_seir_sim(modinf, sim_id):
def read_seir_sim(modinf: model_info.ModelInfo, sim_id):
seir_df = modinf.read_simID(ftype="seir", sim_id=sim_id)

return seir_df


def compute_all_multioutcomes(
*,
modinf,
modinf: model_info.ModelInfo,
sim_id2write,
parameters,
loaded_values=None,
Expand Down
41 changes: 6 additions & 35 deletions flepimop/gempyor_pkg/src/gempyor/postprocess_inference.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,10 @@
import gempyor
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from pathlib import Path
import copy

# import seaborn as sns
import matplotlib._color_data as mcd
import pyarrow.parquet as pq
import subprocess
import dask.dataframe as dd
import matplotlib.dates as mdates
import matplotlib.cbook as cbook
from matplotlib.backends.backend_pdf import PdfPages
from gempyor.utils import config, as_list
import os
import multiprocessing as mp
import pandas as pd
import pyarrow.parquet as pq
import xarray as xr
from gempyor import (
config,
model_info,
outcomes,
seir,
inference_parameter,
logloss,
inference,
)
from gempyor.inference import GempyorInference
import tqdm
import os
from multiprocessing import cpu_count
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tqdm

from .model_info import ModelInfo


def find_walkers_to_sample(inferpar, sampler_output, nsamples, nwalker, nthin):
Expand Down Expand Up @@ -145,7 +116,7 @@ def plot_single_chain(frompt, ax, chain, label, gt=None):
plt.close(fig)


def plot_fit(modinf, loss):
def plot_fit(modinf: ModelInfo, loss):
subpop_names = modinf.subpop_struct.subpop_names
fig, axes = plt.subplots(
len(subpop_names),
Expand Down
32 changes: 21 additions & 11 deletions flepimop/gempyor_pkg/src/gempyor/seir.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
import itertools
import logging
import time

import numpy as np
import pandas as pd
import scipy
import tqdm.contrib.concurrent
import xarray as xr

from . import NPI, model_info, steps_rk4
from .utils import Timer, print_disk_diagnosis, read_df
import logging
from . import NPI, steps_rk4
from .model_info import ModelInfo
from .utils import Timer, read_df


logger = logging.getLogger(__name__)


def build_step_source_arg(
modinf,
modinf: ModelInfo,
parsed_parameters,
transition_array,
proportion_array,
Expand Down Expand Up @@ -118,7 +121,7 @@ def build_step_source_arg(


def steps_SEIR(
modinf,
modinf: ModelInfo,
parsed_parameters,
transition_array,
proportion_array,
Expand Down Expand Up @@ -215,7 +218,14 @@ def steps_SEIR(
return states


def build_npi_SEIR(modinf, load_ID, sim_id2load, config, bypass_DF=None, bypass_FN=None):
def build_npi_SEIR(
modinf: ModelInfo,
load_ID,
sim_id2load,
config,
bypass_DF=None,
bypass_FN=None,
):
with Timer("SEIR.NPI"):
loaded_df = None
if bypass_DF is not None:
Expand Down Expand Up @@ -257,7 +267,7 @@ def build_npi_SEIR(modinf, load_ID, sim_id2load, config, bypass_DF=None, bypass_

def onerun_SEIR(
sim_id2write: int,
modinf: model_info.ModelInfo,
modinf: ModelInfo,
load_ID: bool = False,
sim_id2load: int = None,
config=None,
Expand Down Expand Up @@ -335,7 +345,7 @@ def onerun_SEIR(
return out_df


def run_parallel_SEIR(modinf, config, *, n_jobs=1):
def run_parallel_SEIR(modinf: ModelInfo, config, *, n_jobs=1):
start = time.monotonic()
sim_ids = np.arange(1, modinf.nslots + 1)

Expand Down Expand Up @@ -364,7 +374,7 @@ def run_parallel_SEIR(modinf, config, *, n_jobs=1):
)


def states2Df(modinf, states):
def states2Df(modinf: ModelInfo, states):
# Tidyup data for R, to save it:
#
# Write output to .snpi.*, .spar.*, and .seir.* files
Expand Down Expand Up @@ -428,7 +438,7 @@ def states2Df(modinf, states):
return out_df


def write_spar_snpi(sim_id, modinf, p_draw, npi):
def write_spar_snpi(sim_id: int, modinf: ModelInfo, p_draw, npi):
# NPIs
if npi is not None:
modinf.write_simID(ftype="snpi", sim_id=sim_id, df=npi.getReductionDF())
Expand All @@ -438,7 +448,7 @@ def write_spar_snpi(sim_id, modinf, p_draw, npi):
)


def write_seir(sim_id, modinf, states):
def write_seir(sim_id, modinf: ModelInfo, states):
# print_disk_diagnosis()
out_df = states2Df(modinf, states)
modinf.write_simID(ftype="seir", sim_id=sim_id, df=out_df)
Expand Down
Loading

0 comments on commit 7a57dc5

Please sign in to comment.