Skip to content

Commit

Permalink
black formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
pearsonca committed Nov 12, 2024
1 parent 73c7234 commit 801b5f6
Show file tree
Hide file tree
Showing 12 changed files with 61 additions and 35 deletions.
3 changes: 2 additions & 1 deletion flepimop/gempyor_pkg/src/gempyor/NPI/MultiPeriodModifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@

from gempyor import model_info


class MultiPeriodModifier(NPIBase):
def __init__(
self,
*,
npi_config,
modinf : model_info.ModelInfo,
modinf: model_info.ModelInfo,
modifiers_library,
subpops,
loaded_df=None,
Expand Down
3 changes: 2 additions & 1 deletion flepimop/gempyor_pkg/src/gempyor/NPI/SinglePeriodModifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@

from gempyor import model_info


class SinglePeriodModifier(NPIBase):
def __init__(
self,
*,
npi_config,
modinf : model_info.ModelInfo,
modinf: model_info.ModelInfo,
modifiers_library,
subpops,
loaded_df=None,
Expand Down
2 changes: 1 addition & 1 deletion flepimop/gempyor_pkg/src/gempyor/NPI/StackedModifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(
self,
*,
npi_config,
modinf : model_info.ModelInfo,
modinf: model_info.ModelInfo,
modifiers_library,
subpops,
loaded_df=None,
Expand Down
3 changes: 2 additions & 1 deletion flepimop/gempyor_pkg/src/gempyor/NPI/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import gempyor.model_info as model_info


class NPIBase(abc.ABC):
__plugins__ = {}

Expand All @@ -28,7 +29,7 @@ def getReductionDF(self):
def execute(
*,
npi_config,
modinf : model_info.ModelInfo,
modinf: model_info.ModelInfo,
modifiers_library,
subpops,
loaded_df=None,
Expand Down
4 changes: 2 additions & 2 deletions flepimop/gempyor_pkg/src/gempyor/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def simulation_atomic(
*,
snpi_df_in,
hnpi_df_in,
modinf : model_info.ModelInfo,
modinf: model_info.ModelInfo,
p_draw,
unique_strings,
transition_array,
Expand Down Expand Up @@ -147,7 +147,7 @@ def simulation_atomic(
return outcomes_df


def get_static_arguments(modinf : model_info.ModelInfo):
def get_static_arguments(modinf: model_info.ModelInfo):
"""
Get the static arguments for the log likelihood function, these are the same for all walkers
"""
Expand Down
16 changes: 11 additions & 5 deletions flepimop/gempyor_pkg/src/gempyor/initial_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
bool
)

def get_from_config(self, sim_id: int, modinf : model_info.ModelInfo) -> np.ndarray:
def get_from_config(self, sim_id: int, modinf: model_info.ModelInfo) -> np.ndarray:
method = "Default"
if (
self.initial_conditions_config is not None
Expand Down Expand Up @@ -129,14 +129,16 @@ def get_from_config(self, sim_id: int, modinf : model_info.ModelInfo) -> np.ndar

return y0

def get_from_file(self, sim_id: int, modinf : model_info.ModelInfo) -> np.ndarray:
def get_from_file(self, sim_id: int, modinf: model_info.ModelInfo) -> np.ndarray:
return self.get_from_config(sim_id=sim_id, modinf=modinf)


# TODO: rename config to initial_conditions_config as it shadows the global config


def check_population(y0, modinf : model_info.ModelInfo, ignore_population_checks : bool = False):
def check_population(
y0, modinf: model_info.ModelInfo, ignore_population_checks: bool = False
):
# check that the inputed values sums to the subpop population:
error = False
for pl_idx, pl in enumerate(modinf.subpop_struct.subpop_names):
Expand All @@ -163,7 +165,11 @@ def check_population(y0, modinf : model_info.ModelInfo, ignore_population_checks


def read_initial_condition_from_tidydataframe(
ic_df, modinf : model_info.ModelInfo, allow_missing_subpops, allow_missing_compartments, proportional_ic=False
ic_df,
modinf: model_info.ModelInfo,
allow_missing_subpops,
allow_missing_compartments,
proportional_ic=False,
):
rests = [] # Places to allocate the rest of the population
y0 = np.zeros((modinf.compartments.compartments.shape[0], modinf.nsubpops))
Expand Down Expand Up @@ -235,7 +241,7 @@ def read_initial_condition_from_tidydataframe(


def read_initial_condition_from_seir_output(
ic_df, modinf : model_info.ModelInfo, allow_missing_subpops, allow_missing_compartments
ic_df, modinf: model_info.ModelInfo, allow_missing_subpops, allow_missing_compartments
):
"""
Read the initial conditions from the SEIR output.
Expand Down
12 changes: 8 additions & 4 deletions flepimop/gempyor_pkg/src/gempyor/outcomes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
logger = logging.getLogger(__name__)


def run_parallel_outcomes(modinf : model_info.ModelInfo, *, sim_id2write, nslots=1, n_jobs=1):
def run_parallel_outcomes(
modinf: model_info.ModelInfo, *, sim_id2write, nslots=1, n_jobs=1
):
start = time.monotonic()

sim_id2writes = np.arange(sim_id2write, sim_id2write + modinf.nslots)
Expand Down Expand Up @@ -300,7 +302,9 @@ def read_parameters_from_config(modinf: model_info.ModelInfo):
return parameters


def postprocess_and_write(sim_id, modinf : model_info.ModelInfo, outcomes_df, hpar, npi, write=True):
def postprocess_and_write(
sim_id, modinf: model_info.ModelInfo, outcomes_df, hpar, npi, write=True
):
if write:
modinf.write_simID(ftype="hosp", sim_id=sim_id, df=outcomes_df)
modinf.write_simID(ftype="hpar", sim_id=sim_id, df=hpar)
Expand Down Expand Up @@ -337,15 +341,15 @@ def dataframe_from_array(data, subpops, dates, comp_name):
return df


def read_seir_sim(modinf : model_info.ModelInfo, sim_id):
def read_seir_sim(modinf: model_info.ModelInfo, sim_id):
seir_df = modinf.read_simID(ftype="seir", sim_id=sim_id)

return seir_df


def compute_all_multioutcomes(
*,
modinf : model_info.ModelInfo,
modinf: model_info.ModelInfo,
sim_id2write,
parameters,
loaded_values=None,
Expand Down
2 changes: 1 addition & 1 deletion flepimop/gempyor_pkg/src/gempyor/postprocess_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def plot_single_chain(frompt, ax, chain, label, gt=None):
plt.close(fig)


def plot_fit(modinf : model_info.ModelInfo, loss):
def plot_fit(modinf: model_info.ModelInfo, loss):
subpop_names = modinf.subpop_struct.subpop_names
fig, axes = plt.subplots(
len(subpop_names),
Expand Down
6 changes: 3 additions & 3 deletions flepimop/gempyor_pkg/src/gempyor/seeding.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
## TODO: ideally here path_prefix should not be used and all files loaded from modinf


def _DataFrame2NumbaDict(df, amounts, modinf : model_info.ModelInfo) -> nb.typed.Dict:
def _DataFrame2NumbaDict(df, amounts, modinf: model_info.ModelInfo) -> nb.typed.Dict:
if not df["date"].is_monotonic_increasing:
raise ValueError(
"_DataFrame2NumbaDict got an unsorted dataframe, exposing itself to non-sense"
Expand Down Expand Up @@ -99,7 +99,7 @@ def __init__(self, config: confuse.ConfigView, path_prefix: str = "."):
self.seeding_config = config
self.path_prefix = path_prefix

def get_from_config(self, sim_id: int, modinf : model_info.ModelInfo) -> nb.typed.Dict:
def get_from_config(self, sim_id: int, modinf: model_info.ModelInfo) -> nb.typed.Dict:
method = "NoSeeding"
if self.seeding_config is not None and "method" in self.seeding_config.keys():
method = self.seeding_config["method"].as_str()
Expand Down Expand Up @@ -168,7 +168,7 @@ def get_from_config(self, sim_id: int, modinf : model_info.ModelInfo) -> nb.type

return _DataFrame2NumbaDict(df=seeding, amounts=amounts, modinf=modinf)

def get_from_file(self, sim_id: int, modinf : model_info.ModelInfo) -> nb.typed.Dict:
def get_from_file(self, sim_id: int, modinf: model_info.ModelInfo) -> nb.typed.Dict:
"""only difference with draw seeding is that the sim_id is now sim_id2load"""
return self.get_from_config(sim_id=sim_id, modinf=modinf)

Expand Down
21 changes: 14 additions & 7 deletions flepimop/gempyor_pkg/src/gempyor/seir.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


def build_step_source_arg(
modinf : model_info.ModelInfo,
modinf: model_info.ModelInfo,
parsed_parameters,
transition_array,
proportion_array,
Expand Down Expand Up @@ -118,7 +118,7 @@ def build_step_source_arg(


def steps_SEIR(
modinf : model_info.ModelInfo,
modinf: model_info.ModelInfo,
parsed_parameters,
transition_array,
proportion_array,
Expand Down Expand Up @@ -215,7 +215,14 @@ def steps_SEIR(
return states


def build_npi_SEIR(modinf : model_info.ModelInfo, load_ID, sim_id2load, config, bypass_DF=None, bypass_FN=None):
def build_npi_SEIR(
modinf: model_info.ModelInfo,
load_ID,
sim_id2load,
config,
bypass_DF=None,
bypass_FN=None,
):
with Timer("SEIR.NPI"):
loaded_df = None
if bypass_DF is not None:
Expand Down Expand Up @@ -335,7 +342,7 @@ def onerun_SEIR(
return out_df


def run_parallel_SEIR(modinf : model_info.ModelInfo, config, *, n_jobs=1):
def run_parallel_SEIR(modinf: model_info.ModelInfo, config, *, n_jobs=1):
start = time.monotonic()
sim_ids = np.arange(1, modinf.nslots + 1)

Expand Down Expand Up @@ -364,7 +371,7 @@ def run_parallel_SEIR(modinf : model_info.ModelInfo, config, *, n_jobs=1):
)


def states2Df(modinf : model_info.ModelInfo, states):
def states2Df(modinf: model_info.ModelInfo, states):
# Tidyup data for R, to save it:
#
# Write output to .snpi.*, .spar.*, and .seir.* files
Expand Down Expand Up @@ -428,7 +435,7 @@ def states2Df(modinf : model_info.ModelInfo, states):
return out_df


def write_spar_snpi(sim_id : int, modinf : model_info.ModelInfo, p_draw, npi):
def write_spar_snpi(sim_id: int, modinf: model_info.ModelInfo, p_draw, npi):
# NPIs
if npi is not None:
modinf.write_simID(ftype="snpi", sim_id=sim_id, df=npi.getReductionDF())
Expand All @@ -438,7 +445,7 @@ def write_spar_snpi(sim_id : int, modinf : model_info.ModelInfo, p_draw, npi):
)


def write_seir(sim_id, modinf : model_info.ModelInfo, states):
def write_seir(sim_id, modinf: model_info.ModelInfo, states):
# print_disk_diagnosis()
out_df = states2Df(modinf, states)
modinf.write_simID(ftype="seir", sim_id=sim_id, df=out_df)
Expand Down
8 changes: 7 additions & 1 deletion flepimop/gempyor_pkg/src/gempyor/subpopulation_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@


class SubpopulationStructure:
def __init__(self, *, setup_name : str, subpop_config : confuse.Subview, path_prefix=pathlib.Path(".")):
def __init__(
self,
*,
setup_name: str,
subpop_config: confuse.Subview,
path_prefix=pathlib.Path("."),
):
"""Important attributes:
- self.setup_name: Name of the setup
- self.data: DataFrame with subpopulations and populations
Expand Down
16 changes: 8 additions & 8 deletions flepimop/gempyor_pkg/src/gempyor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def command_safe_run(
return sr.returncode, stdout, stderr


def add_method(cls : Any):
def add_method(cls: Any):
"""
A function which adds a function to a class.
Expand Down Expand Up @@ -864,13 +864,13 @@ def get_filetype_for_resume(


def create_resume_file_names_map(
resume_discard_seeding : str,
flepi_block_index : str,
resume_run_index : str,
flepi_prefix : str,
flepi_slot_index : str,
flepi_run_index : str,
last_job_output : str,
resume_discard_seeding: str,
flepi_block_index: str,
resume_run_index: str,
flepi_prefix: str,
flepi_slot_index: str,
flepi_run_index: str,
last_job_output: str,
) -> dict[str, str]:
"""
Generates a mapping of input file names to output file names for a resume process based on
Expand Down

0 comments on commit 801b5f6

Please sign in to comment.