Skip to content

Commit

Permalink
merged
Browse files Browse the repository at this point in the history
  • Loading branch information
cliffckerr committed Dec 11, 2023
2 parents 41cf86b + 38e531d commit a1f7b1c
Show file tree
Hide file tree
Showing 27 changed files with 1,134 additions and 1,156 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,6 @@
'matplotlib',
'seaborn',
'numba',
'networkx',
],
)
167 changes: 44 additions & 123 deletions stisim/demographics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import stisim as ss
import sciris as sc
import pandas as pd
import scipy.stats as sps

__all__ = ['DemographicModule', 'births', 'background_deaths', 'Pregnancy']

Expand Down Expand Up @@ -104,75 +105,21 @@ def __init__(self, pars=None):
super().__init__(pars)

self.pars = ss.omerge({
'death_rates': 0,
'rel_death': 1,
'data_cols': {'year': 'Time', 'sex': 'Sex', 'age': 'AgeGrpStart', 'value': 'mx'},
'sex_keys': {'f': 'Female', 'm': 'Male'},
'units_per_100': 1e-3 # assumes birth rates are per 1000. If using percentages, switch this to 1
#'rel_death': 1,
#'units_per_100': 1e-3, # assumes birth rates are per 1000. If using percentages, switch this to 1
'death_prob_func': self.death_prob,
}, self.pars)

# Validate death rate inputs
self.set_death_rates(self.pars['death_rates'])

# Set death probs
self.death_probs = ss.State('death_probs', float, 0)

# Define random number generators
self.rng_dead = ss.RNG(f'dead_{self.name}')

self.death_prob_dist = sps.bernoulli(p=self.pars['death_prob_func'])
return

def set_death_rates(self, death_rates):
"""Standardize/validate death rates"""

if sc.checktype(death_rates, pd.DataFrame):
if not set(self.pars.data_cols.values()).issubset(death_rates.columns):
errormsg = 'Please ensure the columns of the death rate data match the values in pars.data_cols.'
raise ValueError(errormsg)
df = death_rates

elif sc.checktype(death_rates, pd.Series):
if (death_rates.index < 120).all(): # Assume index is age bins
df = pd.DataFrame({
self.pars.data_cols['year']: 2000,
self.pars.data_cols['age']: death_rates.index.values,
self.pars.data_cols['value']: death_rates.values,
})
elif (death_rates.index > 1900).all(): # Assume index year
df = pd.DataFrame({
self.pars.data_cols['year']: death_rates.index.values,
self.pars.data_cols['age']: 0,
self.pars.data_cols['value']: death_rates.values,

})
else:
errormsg = 'Could not understand index of death rate series: should be age or year.'
raise ValueError(errormsg)

df = pd.concat([df, df])
df[self.pars.data_cols['sex']] = np.repeat(list(self.pars.sex_keys.values()), len(death_rates))

elif sc.checktype(death_rates, dict):
if not set(self.pars.data_cols.values()).issubset(death_rates.keys()):
errormsg = 'Please ensure the keys of the death rate data dict match the values in pars.data_cols.'
raise ValueError(errormsg)
df = pd.DataFrame(death_rates)

elif sc.isnumber(death_rates):
df = pd.DataFrame({
self.pars.data_cols['year']: [2000, 2000],
self.pars.data_cols['age']: [0, 0],
self.pars.data_cols['sex']: self.pars.sex_keys.values(),
self.pars.data_cols['value']: [death_rates, death_rates],
})

else:
errormsg = f'Death rate data type {type(death_rates)} not understood.'
raise ValueError(errormsg)

self.pars.death_rates = df

return
@staticmethod
def death_prob(self, sim, uids):
#p = self.pars
#year = p.death_rates[p.data_cols['year']]
#val = p.death_rates[p.data_cols['cbr']]
#this_death_rate = np.interp(sim.year, year, val) * p.units_per_100 * p.rel_death * sim.pars.dt_demog
#n_new = int(np.floor(np.count_nonzero(sim.people.alive) * this_death_rate))
return 0.01

def init_results(self, sim):
self.results += ss.Result(self.name, 'new', sim.npts, dtype=int)
Expand All @@ -187,31 +134,8 @@ def update(self, sim):

def apply_deaths(self, sim):
""" Select people to die """

p = self.pars
year_label = p.data_cols['year']
age_label = p.data_cols['age']
sex_label = p.data_cols['sex']
val_label = p.data_cols['value']
sex_keys = p.sex_keys

available_years = p.death_rates[year_label].unique()
year_ind = sc.findnearest(available_years, sim.year)
nearest_year = available_years[year_ind]

df = p.death_rates.loc[p.death_rates[year_label] == nearest_year]
age_bins = df[age_label].unique()
age_inds = np.digitize(sim.people.age, age_bins) - 1

f_arr = df[val_label].loc[df[sex_label] == sex_keys['f']].values
m_arr = df[val_label].loc[df[sex_label] == sex_keys['m']].values
self.death_probs[sim.people.female] = f_arr[age_inds[sim.people.female]]
self.death_probs[sim.people.male] = m_arr[age_inds[sim.people.male]]
self.death_probs *= p.rel_death # Adjust overall death probabilities

# Get indices of people who die of other causes
alive_uids = ss.true(sim.people.alive)
death_uids = self.rng_dead.bernoulli_filter(self.death_probs[alive_uids], alive_uids)
death_uids = self.death_prob_dist.filter(alive_uids)
sim.people.request_death(death_uids)

return len(death_uids)
Expand Down Expand Up @@ -242,21 +166,19 @@ def __init__(self, pars=None):
self.pars = ss.omerge({
'dur_pregnancy': 0.75, # Make this a distribution?
'dur_postpartum': 0.5, # Make this a distribution?
'inci': 0.03, # Replace this with age-specific rates
'p_death': ss.bernoulli(0), # Probability of maternal death. Question, should this be linked to age and/or duration?
'pregnancy_prob_per_dt': sps.bernoulli(p=0.3), # Probabilty of acquiring pregnancy on each time step. Can replace with callable parameters for age-specific rates, etc. NOTE: You will manually need to adjust for the simulation timestep dt!
'p_death': sps.bernoulli(p=0.15), # Probability of maternal death.
'p_female': sps.bernoulli(p=0.5),
'init_prev': 0.3, # Number of women initially pregnant # TODO: Default value
}, self.pars)

self.rng_female = ss.RNG(f'female_{self.name}')
self.female_dist = ss.bernoulli(p=0.5)
self.rng_conception = ss.RNG('conception')
self.rng_dead = ss.RNG(f'dead_{self.name}')
self.rng_choose_slots = ss.RNG('choose_slots')

self.choose_slots = sps.randint(low=0, high=1) # Low and high will be reset upon initialization
return

def initialize(self, sim):
super().initialize(sim)
self.choose_slots.kwds['low'] = sim.pars['n_agents']+1
self.choose_slots.kwds['high'] = int(sim.pars['slot_scale']*sim.pars['n_agents'])
sim.pars['birth_rates'] = None # This turns off birth rate pars so births only come from this module
return

Expand Down Expand Up @@ -313,35 +235,34 @@ def make_pregnancies(self, sim):

# If incidence of pregnancy is non-zero, make some cases
# Think about how to deal with age/time-varying fertility
if self.pars.inci > 0:
denom_conds = ppl.female & ppl.active & self.susceptible
inds_to_choose_from = ss.true(denom_conds)
uids = self.rng_conception.bernoulli_filter(ss.bernoulli(self.pars.inci), inds_to_choose_from)
denom_conds = ppl.female & ppl.active & self.susceptible
inds_to_choose_from = ss.true(denom_conds)
uids = self.pars['pregnancy_prob_per_dt'].filter(inds_to_choose_from)

# Add UIDs for the as-yet-unborn agents so that we can track prognoses and transmission patterns
n_unborn_agents = len(uids)
if n_unborn_agents > 0:
# Add UIDs for the as-yet-unborn agents so that we can track prognoses and transmission patterns
n_unborn_agents = len(uids)
if n_unborn_agents > 0:

# Choose slots for the unborn agents
new_slots = self.rng_choose_slots.sample(ss.uniform_int(sim.pars['n_agents'],sim.pars['slot_scale']*sim.pars['n_agents']), uids)
# Choose slots for the unborn agents
new_slots = self.choose_slots.rvs(uids)

# Grow the arrays and set properties for the unborn agents
new_uids = sim.people.grow(len(new_slots))
# Grow the arrays and set properties for the unborn agents
new_uids = sim.people.grow(len(new_slots))

sim.people.age[new_uids] = -self.pars.dur_pregnancy
sim.people.slot[new_uids] = new_slots # Before sampling female_dist
sim.people.female[new_uids] = self.rng_female.sample(self.female_dist,uids)
sim.people.age[new_uids] = -self.pars.dur_pregnancy
sim.people.slot[new_uids] = new_slots # Before sampling female_dist
sim.people.female[new_uids] = self.pars['p_female'].rvs(uids)

# Add connections to any vertical transmission layers
# Placeholder code to be moved / refactored. The maternal network may need to be
# handled separately to the sexual networks, TBC how to handle this most elegantly
for lkey, layer in sim.people.networks.items():
if layer.vertical: # What happens if there's more than one vertical layer?
durs = np.full(n_unborn_agents, fill_value=self.pars.dur_pregnancy + self.pars.dur_postpartum)
layer.add_pairs(uids, new_uids, dur=durs)
# Add connections to any vertical transmission layers
# Placeholder code to be moved / refactored. The maternal network may need to be
# handled separately to the sexual networks, TBC how to handle this most elegantly
for lkey, layer in sim.people.networks.items():
if layer.vertical: # What happens if there's more than one vertical layer?
durs = np.full(n_unborn_agents, fill_value=self.pars.dur_pregnancy + self.pars.dur_postpartum)
layer.add_pairs(uids, new_uids, dur=durs)

# Set prognoses for the pregnancies
self.set_prognoses(sim, uids) # Could set from_uids to network partners?
# Set prognoses for the pregnancies
self.set_prognoses(sim, uids) # Could set from_uids to network partners?

return

Expand All @@ -360,7 +281,7 @@ def set_prognoses(self, sim, uids, from_uids=None):

# Outcomes for pregnancies
dur = np.full(len(uids), sim.ti + self.pars.dur_pregnancy / sim.dt)
dead = self.rng_dead.sample(self.pars.p_death, uids)
dead = self.pars.p_death.rvs(uids)
self.ti_delivery[uids] = dur # Currently assumes maternal deaths still result in a live baby
dur_post_partum = np.full(len(uids), dur + self.pars.dur_postpartum / sim.dt)
self.ti_postpartum[uids] = dur_post_partum
Expand All @@ -372,4 +293,4 @@ def set_prognoses(self, sim, uids, from_uids=None):
def update_results(self, sim):
self.results['pregnancies'][sim.ti] = np.count_nonzero(self.ti_pregnant == sim.ti)
self.results['births'][sim.ti] = np.count_nonzero(self.ti_delivery == sim.ti)
return
return
Loading

0 comments on commit a1f7b1c

Please sign in to comment.