Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

introduce model_info type annotation #396

Merged
merged 5 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions flepimop/gempyor_pkg/src/gempyor/NPI/MultiPeriodModifier.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import pandas as pd
import datetime

import numpy as np
import pandas as pd

from . import helpers
from .base import NPIBase

Expand All @@ -9,7 +12,8 @@ def __init__(
self,
*,
npi_config,
modinf,
modinf_ti: datetime.date,
modinf_tf: datetime.date,
modifiers_library,
subpops,
loaded_df=None,
Expand All @@ -27,8 +31,8 @@ def __init__(
)

self.sanitize = sanitize
self.start_date = modinf.ti
self.end_date = modinf.tf
self.start_date = modinf_ti
self.end_date = modinf_tf

self.subpops = subpops

Expand Down
13 changes: 8 additions & 5 deletions flepimop/gempyor_pkg/src/gempyor/NPI/SinglePeriodModifier.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import pandas as pd
import datetime

import numpy as np
from . import helpers
import pandas as pd

from . import helpers
from .base import NPIBase


Expand All @@ -10,7 +12,8 @@ def __init__(
self,
*,
npi_config,
modinf,
modinf_ti: datetime.date,
modinf_tf: datetime.date,
modifiers_library,
subpops,
loaded_df=None,
Expand All @@ -26,8 +29,8 @@ def __init__(
)
)

self.start_date = modinf.ti
self.end_date = modinf.tf
self.start_date = modinf_ti
self.end_date = modinf_tf

self.pnames_overlap_operation_sum = pnames_overlap_operation_sum
self.pnames_overlap_operation_reductionprod = pnames_overlap_operation_reductionprod
Expand Down
13 changes: 8 additions & 5 deletions flepimop/gempyor_pkg/src/gempyor/NPI/StackedModifier.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import collections
import os
import warnings
import datetime

import confuse
import pandas as pd
import os

from .base import NPIBase

Expand All @@ -19,7 +20,8 @@ def __init__(
self,
*,
npi_config,
modinf,
modinf_ti: datetime.date,
modinf_tf: datetime.date,
modifiers_library,
subpops,
loaded_df=None,
Expand All @@ -28,8 +30,8 @@ def __init__(
):
super().__init__(name=npi_config.name)

self.start_date = modinf.ti
self.end_date = modinf.tf
self.start_date = modinf_ti
self.end_date = modinf_tf

self.pnames_overlap_operation_sum = pnames_overlap_operation_sum
self.pnames_overlap_operation_reductionprod = pnames_overlap_operation_reductionprod
Expand Down Expand Up @@ -61,7 +63,8 @@ def __init__(

sub_npi = NPIBase.execute(
npi_config=scenario_npi_config,
modinf=modinf,
modinf_ti=modinf_ti,
modinf_tf=modinf_tf,
modifiers_library=modifiers_library,
subpops=subpops,
loaded_df=loaded_df,
Expand Down
10 changes: 6 additions & 4 deletions flepimop/gempyor_pkg/src/gempyor/NPI/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import abc
import pyarrow as pa
import datetime


class NPIBase(abc.ABC):
Expand Down Expand Up @@ -27,7 +27,8 @@ def getReductionDF(self):
def execute(
*,
npi_config,
modinf,
modinf_ti: datetime.date,
modinf_tf: datetime.date,
modifiers_library,
subpops,
loaded_df=None,
Expand All @@ -36,7 +37,7 @@ def execute(
):
"""
npi_config: config of the Modifier we are building, usually a StackedModifier that will call other NPIs
modinf: the ModelInfor class, to inform ti and tf
modinf_ti, modinf_tf: start and end dates (datetime.date) for the modifier, i.e. the ModelInfo's ti and tf
pearsonca marked this conversation as resolved.
Show resolved Hide resolved
modifiers_library: a config bit that contains the other modifiers that could be called by this Modifier. Note
that the confuse library's config resolution mechanism makes slicing the configuration object expensive;
instead give the preloaded settings from .get()
Expand All @@ -45,7 +46,8 @@ def execute(
npi_class = NPIBase.__plugins__[method]
return npi_class(
npi_config=npi_config,
modinf=modinf,
modinf_ti=modinf_ti,
modinf_tf=modinf_tf,
modifiers_library=modifiers_library,
subpops=subpops,
loaded_df=loaded_df,
Expand Down
3 changes: 2 additions & 1 deletion flepimop/gempyor_pkg/src/gempyor/dev/dev_seir.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@

npi = NPI.NPIBase.execute(
npi_config=modinf.npi_config_seir,
modinf=modinf,
modinf_ti=modinf.ti,
modinf_tf=modinf.tf,
modifiers_library=modinf.seir_modifiers_library,
subpops=modinf.subpop_struct.subpop_names,
pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"],
Expand Down
48 changes: 12 additions & 36 deletions flepimop/gempyor_pkg/src/gempyor/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,53 +9,29 @@
# function terminated successfully


from . import seir, model_info
from . import outcomes, file_paths
from .utils import config, Timer, read_df, as_list
import numpy as np
from concurrent.futures import ProcessPoolExecutor

# Logger configuration
import copy
import logging
import os
import multiprocessing as mp
import os

import numba as nb
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import xarray as xr
import numba as nb

from . import seir, model_info
from . import outcomes, file_paths
from .utils import config, Timer, read_df, as_list


logging.basicConfig(level=os.environ.get("FLEPI_LOGLEVEL", "INFO").upper())
logger = logging.getLogger()
handler = logging.StreamHandler()
# '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
formatter = logging.Formatter(
" %(name)s :: %(levelname)-8s :: %(message)s"
# "%(asctime)s [%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s"
)

formatter = logging.Formatter(" %(name)s :: %(levelname)-8s :: %(message)s")
handler.setFormatter(formatter)
# logger.addHandler(handler)

from . import seir, model_info
from . import outcomes
from .utils import config, Timer, read_df
import numpy as np
from concurrent.futures import ProcessPoolExecutor

# Logger configuration
import logging
import os
import multiprocessing as mp
import pandas as pd
import pyarrow.parquet as pq
import xarray as xr
import numba as nb
import copy
import matplotlib.pyplot as plt
import seaborn as sns
import confuse
from . import inference_parameter, logloss, statistics

# TODO: should be able to draw, e.g., from an initial condition folder, but keep the draw as a blob
# so it is saved by emcee, so I can build a posterior
Expand All @@ -66,7 +42,7 @@ def simulation_atomic(
*,
snpi_df_in,
hnpi_df_in,
modinf,
modinf: model_info.ModelInfo,
p_draw,
unique_strings,
transition_array,
Expand Down Expand Up @@ -147,7 +123,7 @@ def simulation_atomic(
return outcomes_df


def get_static_arguments(modinf):
def get_static_arguments(modinf: model_info.ModelInfo):
"""
Get the static arguments for the log likelihood function, these are the same for all walkers
"""
Expand Down
1 change: 1 addition & 0 deletions flepimop/gempyor_pkg/src/gempyor/model_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
)
from .utils import read_df, write_df


logger = logging.getLogger(__name__)


Expand Down
50 changes: 23 additions & 27 deletions flepimop/gempyor_pkg/src/gempyor/outcomes.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
import itertools
import time, random
import logging
import time

from numba import jit
import xarray as xr
import numpy as np
import pandas as pd
import pyarrow as pa
import tqdm.contrib.concurrent
import xarray as xr

from .utils import config, Timer, read_df
import pyarrow as pa
import pandas as pd
from . import NPI, model_info


import logging

logger = logging.getLogger(__name__)


def run_parallel_outcomes(modinf, *, sim_id2write, nslots=1, n_jobs=1):
def run_parallel_outcomes(
modinf: model_info.ModelInfo, *, sim_id2write, nslots=1, n_jobs=1
):
start = time.monotonic()

sim_id2writes = np.arange(sim_id2write, sim_id2write + modinf.nslots)
Expand Down Expand Up @@ -71,23 +73,15 @@ def build_outcome_modifiers(
elif load_ID == True:
loaded_df = modinf.read_simID(ftype="hnpi", sim_id=sim_id2load)

if loaded_df is not None:
npi = NPI.NPIBase.execute(
npi_config=modinf.npi_config_outcomes,
modinf=modinf,
modifiers_library=modinf.outcome_modifiers_library,
subpops=modinf.subpop_struct.subpop_names,
loaded_df=loaded_df,
# TODO: support other operation than product
)
else:
npi = NPI.NPIBase.execute(
npi_config=modinf.npi_config_outcomes,
modinf=modinf,
modifiers_library=modinf.outcome_modifiers_library,
subpops=modinf.subpop_struct.subpop_names,
# TODO: support other operation than product
)
npi = NPI.NPIBase.execute(
npi_config=modinf.npi_config_outcomes,
modinf_ti=modinf.ti,
modinf_tf=modinf.tf,
modifiers_library=modinf.outcome_modifiers_library,
subpops=modinf.subpop_struct.subpop_names,
loaded_df=loaded_df,
# TODO: support operations other than product
)
return npi


Expand Down Expand Up @@ -300,7 +294,9 @@ def read_parameters_from_config(modinf: model_info.ModelInfo):
return parameters


def postprocess_and_write(sim_id, modinf, outcomes_df, hpar, npi, write=True):
def postprocess_and_write(
sim_id, modinf: model_info.ModelInfo, outcomes_df, hpar, npi, write=True
):
if write:
modinf.write_simID(ftype="hosp", sim_id=sim_id, df=outcomes_df)
modinf.write_simID(ftype="hpar", sim_id=sim_id, df=hpar)
Expand Down Expand Up @@ -337,15 +333,15 @@ def dataframe_from_array(data, subpops, dates, comp_name):
return df


def read_seir_sim(modinf, sim_id):
def read_seir_sim(modinf: model_info.ModelInfo, sim_id):
seir_df = modinf.read_simID(ftype="seir", sim_id=sim_id)

return seir_df


def compute_all_multioutcomes(
*,
modinf,
modinf: model_info.ModelInfo,
sim_id2write,
parameters,
loaded_values=None,
Expand Down
41 changes: 6 additions & 35 deletions flepimop/gempyor_pkg/src/gempyor/postprocess_inference.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,10 @@
import gempyor
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from pathlib import Path
import copy

# import seaborn as sns
import matplotlib._color_data as mcd
import pyarrow.parquet as pq
import subprocess
import dask.dataframe as dd
import matplotlib.dates as mdates
import matplotlib.cbook as cbook
from matplotlib.backends.backend_pdf import PdfPages
from gempyor.utils import config, as_list
import os
import multiprocessing as mp
import pandas as pd
import pyarrow.parquet as pq
import xarray as xr
from gempyor import (
config,
model_info,
outcomes,
seir,
inference_parameter,
logloss,
inference,
)
from gempyor.inference import GempyorInference
import tqdm
import os
from multiprocessing import cpu_count
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tqdm

from .model_info import ModelInfo


def find_walkers_to_sample(inferpar, sampler_output, nsamples, nwalker, nthin):
Expand Down Expand Up @@ -145,7 +116,7 @@ def plot_single_chain(frompt, ax, chain, label, gt=None):
plt.close(fig)


def plot_fit(modinf, loss):
def plot_fit(modinf: ModelInfo, loss):
subpop_names = modinf.subpop_struct.subpop_names
fig, axes = plt.subplots(
len(subpop_names),
Expand Down
2 changes: 1 addition & 1 deletion flepimop/gempyor_pkg/src/gempyor/seeding.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict
from typing import Dict, TYPE_CHECKING

import numpy as np
import pandas as pd
Expand Down
Loading