Skip to content

Commit

Permalink
Apply black formatter post-merge
Browse files Browse the repository at this point in the history
  • Loading branch information
TimothyWillard committed Nov 19, 2024
1 parent 0c67060 commit 029f8c8
Show file tree
Hide file tree
Showing 9 changed files with 313 additions and 200 deletions.
4 changes: 3 additions & 1 deletion flepimop/gempyor_pkg/src/gempyor/NPI/MultiPeriodModifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,9 @@ def __get_affected_subpops_grp(self, grp_config):
affected_subpops_grp = self.subpops
else:
affected_subpops_grp = [str(n.get()) for n in grp_config["subpop"]]
affected_subpops_grp = list(set(affected_subpops_grp).intersection(self.subpops))
affected_subpops_grp = list(
set(affected_subpops_grp).intersection(self.subpops)
)

return affected_subpops_grp

Expand Down
8 changes: 6 additions & 2 deletions flepimop/gempyor_pkg/src/gempyor/NPI/SinglePeriodModifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,9 @@ def __createFromConfig(self, npi_config):
if npi_config["subpop"].exists() and npi_config["subpop"].get() != "all":
self.affected_subpops = {str(n.get()) for n in npi_config["subpop"]}
# subsamples the subpopulations to only the ones specified in the geodata (self.subpops)
self.affected_subpops = list(set(self.affected_subpops).intersection(self.subpops))
self.affected_subpops = list(
set(self.affected_subpops).intersection(self.subpops)
)

self.parameters = self.parameters[self.parameters.index.isin(self.affected_subpops)]
# Create reduction
Expand Down Expand Up @@ -166,7 +168,9 @@ def __createFromDf(self, loaded_df, npi_config):
self.affected_subpops = set(self.subpops)
if npi_config["subpop"].exists() and npi_config["subpop"].get() != "all":
self.affected_subpops = {str(n.get()) for n in npi_config["subpop"]}
self.affected_subpops = list(set(self.affected_subpops).intersection(self.subpops))
self.affected_subpops = list(
set(self.affected_subpops).intersection(self.subpops)
)

self.parameters = self.parameters[self.parameters.index.isin(self.affected_subpops)]
self.parameters["modifier_name"] = self.name
Expand Down
232 changes: 138 additions & 94 deletions flepimop/gempyor_pkg/src/gempyor/NPI/base.py

Large diffs are not rendered by default.

15 changes: 10 additions & 5 deletions flepimop/gempyor_pkg/src/gempyor/NPI/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ def get_spatial_groups(grp_config, affected_subpops: list) -> dict:
flat_grouped_list = flatten_list_of_lists(spatial_groups["grouped"])
# check that all subpops are either grouped or ungrouped

#if set(flat_grouped_list + spatial_groups["ungrouped"]) != set(affected_subpops):
# if set(flat_grouped_list + spatial_groups["ungrouped"]) != set(affected_subpops):
# print("set of grouped and ungrouped subpops", set(flat_grouped_list + spatial_groups["ungrouped"]))
# print("set of affected subpops ", set(affected_subpops))
# raise ValueError(f"The two above sets are differs for for intervention with config \n {grp_config}")
#if len(set(flat_grouped_list + spatial_groups["ungrouped"])) != len(
# if len(set(flat_grouped_list + spatial_groups["ungrouped"])) != len(
# flat_grouped_list + spatial_groups["ungrouped"]
#):
# ):
# raise ValueError(
# f"subpop_groups error. For intervention with config \n {grp_config} \n duplicate entries in the set of grouped and ungrouped subpops"
# f" {flat_grouped_list + spatial_groups['ungrouped']} vs {set(flat_grouped_list + spatial_groups['ungrouped'])}"
Expand All @@ -66,8 +66,13 @@ def get_spatial_groups(grp_config, affected_subpops: list) -> dict:
spatial_groups["grouped"] = make_list_of_list(spatial_groups["grouped"])

# sort the lists
spatial_groups["grouped"] = [sorted(list(set(x).intersection(affected_subpops))) for x in spatial_groups["grouped"]]
spatial_groups["ungrouped"] = sorted(list(set(spatial_groups["ungrouped"]).intersection(affected_subpops)))
spatial_groups["grouped"] = [
sorted(list(set(x).intersection(affected_subpops)))
for x in spatial_groups["grouped"]
]
spatial_groups["ungrouped"] = sorted(
list(set(spatial_groups["ungrouped"]).intersection(affected_subpops))
)

# remove empty sublist in grp{'grouped': [[], ['01000_14to15', '01000_17to18', '01000_22to23'], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], []],
spatial_groups["grouped"] = [x for x in spatial_groups["grouped"] if x]
Expand Down
107 changes: 70 additions & 37 deletions flepimop/gempyor_pkg/src/gempyor/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,6 @@ def calibrate(
else:
run_id = input_run_id



# Select a file name and create the backend/resume
if resume or resume_location is not None:
if resume_location is None:
Expand All @@ -145,14 +143,16 @@ def calibrate(
if not os.path.exists(filename):
print(f"File {filename} does not exist, cannot resume")
return
print(f"Doing a resume from {filename}, this only work with the same number of slot/walkers and parameters right now")
print(
f"Doing a resume from {filename}, this only work with the same number of slot/walkers and parameters right now"
)
else:
filename = f"{run_id}_backend.h5"
if os.path.exists(filename):
if not resume:
print(f"File {filename} already exists, remove it or use --resume")
return

gempyor_inference = GempyorInference(
config_filepath=config_filepath,
run_id=run_id,
Expand Down Expand Up @@ -194,45 +194,64 @@ def calibrate(
gempyor_inference.perform_test_run()
with multiprocessing.Pool(ncpu) as pool:
lliks = pool.starmap(
gempyor_inference.get_logloss_as_single_number, [(p0[0],), (p0[0],), (p0[1],),]
gempyor_inference.get_logloss_as_single_number,
[
(p0[0],),
(p0[0],),
(p0[1],),
],
)
if lliks[0] != lliks[1]:
print(f"Test run failed, logloss with the same parameters is different: {lliks[0]} != {lliks[1]} ❌")
print("This means that there is config variability not captured in the emcee fits")
print(
f"Test run failed, logloss with the same parameters is different: {lliks[0]} != {lliks[1]} ❌"
)
print(
"This means that there is config variability not captured in the emcee fits"
)
return
# TODO THIS Test in fact does nnot work.
else:
print(f"Test run done, logloss with same parameters: {lliks[0]}=={lliks[1]} ✅ ")
#assert lliks[1] != lliks[2], "Test run failed, logloss with different parameters is the same, perturbation are not taken into account"


print(
f"Test run done, logloss with same parameters: {lliks[0]}=={lliks[1]} ✅ "
)
# assert lliks[1] != lliks[2], "Test run failed, logloss with different parameters is the same, perturbation are not taken into account"

# Make a plot of the runs directly from config
n_config_samples = min(30, nwalkers // 2)
print(f"Making {n_config_samples} simulations from config to plot")
with multiprocessing.Pool(ncpu) as pool:
results = pool.starmap(
gempyor_inference.simulate_proposal, [(p0[i],) for i in range(n_config_samples)]
)
gempyor.postprocess_inference.plot_fit(modinf=gempyor_inference.modinf,
loss=gempyor_inference.logloss,
plot_projections=True,
list_of_df=results, save_to=f"{run_id}_config.pdf")

gempyor.postprocess_inference.plot_fit(
modinf=gempyor_inference.modinf,
loss=gempyor_inference.logloss,
plot_projections=True,
list_of_df=results,
save_to=f"{run_id}_config.pdf",
)

# >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
# @JOSEPH: find below a "cocktail" move proposal
moves = [(emcee.moves.DEMove(live_dangerously=True), 0.5*0.5*0.5),
(emcee.moves.DEMove(gamma0=1.0,live_dangerously=True),0.5*0.5*0.5),
(emcee.moves.DESnookerMove(live_dangerously=True),0.5*0.5), # First three moves: DEMove --> DE is good at "optimizing". Moves based on the (really great!) discussion in https://groups.google.com/g/emcee-users/c/FCAq459Y9OE
(emcee.moves.StretchMove(live_dangerously=True), 0.5), # Stretch gives good chain movement
#(emcee.moves.KDEMove(live_dangerously=True, bw_method='scott'), 0.25)
] # Based on personal experience with pySODM (Tijs) - KDEMove works really well but I think it's important for this one to have at least 3x more walkers than parameters.
# @JOSEPH: find below a "cocktail" move proposal
moves = [
(emcee.moves.DEMove(live_dangerously=True), 0.5 * 0.5 * 0.5),
(emcee.moves.DEMove(gamma0=1.0, live_dangerously=True), 0.5 * 0.5 * 0.5),
(
emcee.moves.DESnookerMove(live_dangerously=True),
0.5 * 0.5,
), # First three moves: DEMove --> DE is good at "optimizing". Moves based on the (really great!) discussion in https://groups.google.com/g/emcee-users/c/FCAq459Y9OE
(
emcee.moves.StretchMove(live_dangerously=True),
0.5,
), # Stretch gives good chain movement
# (emcee.moves.KDEMove(live_dangerously=True, bw_method='scott'), 0.25)
] # Based on personal experience with pySODM (Tijs) - KDEMove works really well but I think it's important for this one to have at least 3x more walkers than parameters.
# >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
#moves = [(emcee.moves.StretchMove(live_dangerously=True), 1)]
# moves = [(emcee.moves.StretchMove(live_dangerously=True), 1)]

gempyor_inference.set_silent(False)
#with multiprocessing.Pool(ncpu) as pool:

# with multiprocessing.Pool(ncpu) as pool:
# sampler = emcee.EnsembleSampler(
# nwalkers,
# gempyor_inference.inferpar.get_dim(),
Expand Down Expand Up @@ -261,13 +280,19 @@ def calibrate(
backend=backend,
moves=moves,
)
state = sampler.run_mcmc(start_val, nbatch, progress=True, skip_initial_state_check=True)
state = sampler.run_mcmc(
start_val, nbatch, progress=True, skip_initial_state_check=True
)
print(f"Done, mean acceptance fraction: {np.mean(sampler.acceptance_fraction):.3f}")

# plotting the chain
sampler = emcee.backends.HDFBackend(filename, read_only=True)
gempyor.postprocess_inference.plot_chains(
inferpar=gempyor_inference.inferpar, chains = sampler.get_chain(), llik = sampler.get_log_prob(), sampled_slots=None, save_to=f"{run_id}_chains.pdf"
inferpar=gempyor_inference.inferpar,
chains=sampler.get_chain(),
llik=sampler.get_log_prob(),
sampled_slots=None,
save_to=f"{run_id}_chains.pdf",
)
print("EMCEE Run done, doing sampling")

Expand All @@ -286,19 +311,27 @@ def calibrate(
)

results = []
for fn in gempyor.utils.list_filenames(folder=os.path.join(project_path,"model_output/"), filters=[run_id, "hosp.parquet"]):
for fn in gempyor.utils.list_filenames(
folder=os.path.join(project_path, "model_output/"), filters=[run_id, "hosp.parquet"]
):
df = gempyor.read_df(fn)
df = df.set_index("date")
results.append(df)

gempyor.postprocess_inference.plot_fit(modinf=gempyor_inference.modinf,
loss=gempyor_inference.logloss,
list_of_df=results, save_to=f"{run_id}_fit.pdf")

gempyor.postprocess_inference.plot_fit(modinf=gempyor_inference.modinf,
loss=gempyor_inference.logloss,
plot_projections=True,
list_of_df=results, save_to=f"{run_id}_fit_w_proj.pdf")
gempyor.postprocess_inference.plot_fit(
modinf=gempyor_inference.modinf,
loss=gempyor_inference.logloss,
list_of_df=results,
save_to=f"{run_id}_fit.pdf",
)

gempyor.postprocess_inference.plot_fit(
modinf=gempyor_inference.modinf,
loss=gempyor_inference.logloss,
plot_projections=True,
list_of_df=results,
save_to=f"{run_id}_fit_w_proj.pdf",
)


if __name__ == "__main__":
Expand Down
15 changes: 10 additions & 5 deletions flepimop/gempyor_pkg/src/gempyor/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def simulation_atomic(
seeding_amounts,
outcomes_parameters,
save=False,

):
# We need to reseed because subprocess inherit of the same random generator state.
np.random.seed(int.from_bytes(os.urandom(4), byteorder="little"))
Expand Down Expand Up @@ -407,7 +406,6 @@ def __init__(
subpop_struct=self.modinf.subpop_struct,
time_setup=self.modinf.time_setup,
)


print("Running Gempyor Inference")
print(self.logloss)
Expand All @@ -424,8 +422,15 @@ def set_save(self, save):

def get_all_sim_arguments(self):
# inferpar, logloss, static_sim_arguments, modinf, proposal, silent, save
return [self.inferpar, self.logloss, self.static_sim_arguments, self.modinf, self.silent, self.save]

return [
self.inferpar,
self.logloss,
self.static_sim_arguments,
self.modinf,
self.silent,
self.save,
]

def simulate_proposal(self, proposal):
snpi_df_mod, hnpi_df_mod = self.inferpar.inject_proposal(
proposal=proposal,
Expand All @@ -448,7 +453,7 @@ def get_logloss(self, proposal):
if not self.silent:
print("OUT OF BOUND!!")
return -np.inf, -np.inf, -np.inf

outcomes_df = self.simulate_proposal(proposal=proposal)

ll_total, logloss, regularizations = self.logloss.compute_logloss(
Expand Down
20 changes: 13 additions & 7 deletions flepimop/gempyor_pkg/src/gempyor/inference_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,15 @@ def add_modifier(self, pname, ptype, parameter_config, subpops):
# identify spatial group
affected_subpops = set(subpops)

if parameter_config["method"].get() == "SinglePeriodModifier":
if parameter_config["subpop"].exists() and parameter_config["subpop"].get() != "all":
if parameter_config["method"].get() == "SinglePeriodModifier":
if (
parameter_config["subpop"].exists()
and parameter_config["subpop"].get() != "all"
):
affected_subpops = {str(n.get()) for n in parameter_config["subpop"]}
spatial_groups = NPI.helpers.get_spatial_groups(parameter_config, list(affected_subpops))
spatial_groups = NPI.helpers.get_spatial_groups(
parameter_config, list(affected_subpops)
)

# ungrouped subpop (all affected subpop by default) have one parameter per subpop
if spatial_groups["ungrouped"]:
Expand All @@ -64,7 +69,7 @@ def add_modifier(self, pname, ptype, parameter_config, subpops):
pdist=parameter_config["value"].as_random_distribution(),
lb=parameter_config["value"]["a"].get(),
ub=parameter_config["value"]["b"].get(),
)
)
elif parameter_config["method"].get() == "MultiPeriodModifier":
affected_subpops_grp = []
for grp_config in parameter_config["groups"]:
Expand All @@ -80,7 +85,9 @@ def add_modifier(self, pname, ptype, parameter_config, subpops):
else:
affected_subpops_grp = [str(n.get()) for n in grp_config["subpop"]]

this_spatial_group = NPI.helpers.get_spatial_groups(grp_config, affected_subpops_grp)
this_spatial_group = NPI.helpers.get_spatial_groups(
grp_config, affected_subpops_grp
)

# ungrouped subpop (all affected subpop by default) have one parameter per subpop
if this_spatial_group["ungrouped"]:
Expand All @@ -104,11 +111,10 @@ def add_modifier(self, pname, ptype, parameter_config, subpops):
pdist=parameter_config["value"].as_random_distribution(),
lb=parameter_config["value"]["a"].get(),
ub=parameter_config["value"]["b"].get(),
)
)
else:
raise ValueError(f"Unknown method {parameter_config['method']}")


def add_single_parameter(self, ptype, pname, subpop, pdist, lb, ub):
"""
Adds a single parameter to the parameters list.
Expand Down
Loading

0 comments on commit 029f8c8

Please sign in to comment.