Skip to content

Commit

Permalink
merged
Browse files Browse the repository at this point in the history
  • Loading branch information
cliffckerr committed Dec 12, 2023
2 parents b226f06 + 0ceecef commit 1c5b9d3
Show file tree
Hide file tree
Showing 10 changed files with 74 additions and 66 deletions.
2 changes: 1 addition & 1 deletion stisim/demographics.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,4 +293,4 @@ def set_prognoses(self, sim, uids, from_uids=None):
def update_results(self, sim):
self.results['pregnancies'][sim.ti] = np.count_nonzero(self.ti_pregnant == sim.ti)
self.results['births'][sim.ti] = np.count_nonzero(self.ti_delivery == sim.ti)
return
return
2 changes: 1 addition & 1 deletion stisim/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,4 +245,4 @@ def sample(self, size=None):
vals = self.rng.random(size=n_samples)
vals = self._select(vals, size)
return vals < prob
'''
'''
2 changes: 1 addition & 1 deletion stisim/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,4 @@ def scipy_dbns(self):
:return:
"""
return [x for x in self.__dict__.values() if isinstance(x, ss.ScipyDistribution)] \
+ [x for x in self.pars.values() if isinstance(x, ss.ScipyDistribution)]
+ [x for x in self.pars.values() if isinstance(x, ss.ScipyDistribution)]
13 changes: 6 additions & 7 deletions stisim/networks/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,13 +439,13 @@ def __init__(self, pars=None):
desired_std = 3
mu = np.log(desired_mean**2 / np.sqrt(desired_mean**2 + desired_std**2))
sigma = np.sqrt(np.log(1 + desired_std**2 / desired_mean**2))
default_pars['duration_dist'] = sps.lognorm(s=sigma, scale=np.exp(mu)), # Can vary by age, year, and individual pair. Set scale=exp(mu) and s=sigma where mu,sigma are of the underlying normal distribution.
default_pars['duration_dist'] = sps.lognorm(s=sigma, scale=np.exp(mu)) # Can vary by age, year, and individual pair. Set scale=exp(mu) and s=sigma where mu,sigma are of the underlying normal distribution.

desired_mean = 18
desired_std = 2
mu = np.log(desired_mean**2 / np.sqrt(desired_mean**2 + desired_std**2))
sigma = np.sqrt(np.log(1 + desired_std**2 / desired_mean**2))
default_pars['debut_dist'] = sps.lognormal(s=sigma, scale=np.exp(mu)),
default_pars['debut_dist'] = sps.lognorm(s=sigma, scale=np.exp(mu))

pars = ss.omerge(default_pars, pars)
DynamicNetwork.__init__(self)
Expand All @@ -470,7 +470,7 @@ def set_network_states(self, people, upper_age=None):
# Participation
self.participant[people.female] = False
pr = self.pars.part_rates
dist = ss.choice([True, False], probabilities=[pr, 1-pr])(len(uids))
dist = sps.bernoulli.rvs(p=pr, size=len(uids))
self.participant[uids] = dist

# Debut
Expand Down Expand Up @@ -590,7 +590,7 @@ def __init__(self, pars=None):
}, pars)
super().__init__(networks=networks, pars=pars)

self.bi_dist = sps.binom(p=self.pars.prop_pi)
self.bi_dist = sps.bernoulli(p=self.pars.prop_bi)
return

def initialize(self, sim):
Expand All @@ -614,7 +614,7 @@ def set_participation(self, people, upper_age=None):
# Male participation rate uses info about cross-network participation.
# First, we determine who's participating in the MSM network
pr = msm.pars.part_rates
dist = ss.choice([True, False], probabilities=[pr, 1-pr])(len(uids))
dist = sps.bernoulli.rvs(p=pr, size=len(uids))
msm.participant[uids] = dist

# Now we take the MSM participants and determine which are also in the MF network
Expand Down Expand Up @@ -896,7 +896,7 @@ def add_pairs(self, mother_inds, unborn_inds, dur):
self.contacts.p2 = np.concatenate([self.contacts.p2, unborn_inds])
self.contacts.beta = np.concatenate([self.contacts.beta, beta])
self.contacts.dur = np.concatenate([self.contacts.dur, dur])
return
return len(mother_inds)

class static(Network):
"""
Expand Down Expand Up @@ -949,4 +949,3 @@ def get_contacts(self):
self.contacts.beta = np.concatenate([self.contacts.beta, np.ones_like(p1s)])
return


77 changes: 40 additions & 37 deletions stisim/people.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,34 +52,18 @@ def __init__(self, n):
def rngs(self):
return [x for x in self.__dict__.values() if isinstance(x, (ss.MultiRNG, ss.SingleRNG))]

def initialize(self, sim):
""" Initialization """

# For People initialization, first initialize slots, then initialize RNGs, then initialize remaining states
# This is because some states may depend on RNGs being initialized to generate initial values
self.slot.initialize(sim)
self.slot[:] = self.uid

# Initialize all RNGs (noting that includes those that are declared in child classes)
for rng in self.rngs:
rng.initialize(sim.rng_container, self.slot)

# Initialize remaining states
for name, state in self.states.items():
self.add_state(state) # Register the state internally for dynamic growth
state.initialize(sim) # Connect the state to this people instance
setattr(self, name, state)

self.initialized = True
return

def __len__(self):
""" Length of people """
return len(self.uid)

def add_state(self, state):
def add_state(self, state, die=True):
if id(state) not in self._states:
self._states[id(state)] = state
self.states.append(state) # Expose these states with their original names
setattr(self, state.name, state)
elif die:
errormsg = f'Cannot add state {state} since already added'
raise ValueError(errormsg)
return

def grow(self, n, new_slots=None):
Expand Down Expand Up @@ -182,8 +166,6 @@ class People(BasePeople):
ppl = ss.People(2000)
"""

# %% Basic methods

def __init__(self, n, age_data=None, extra_states=None, networks=None, rand_seed=0):
""" Initialize """

Expand All @@ -192,17 +174,20 @@ def __init__(self, n, age_data=None, extra_states=None, networks=None, rand_seed
self.initialized = False
self.version = ss.__version__ # Store version info

self.states.append(ss.State('age', float, 0))
self.states.append(ss.State('female', bool, bernoulli(p=0.5)))
self.states.append(ss.State('debut', float))
self.states.append(ss.State('alive', bool, True)) # Redundant with ti_dead == ss.INT_NAN
self.states.append(ss.State('ti_dead', int, ss.INT_NAN)) # Time index for death
self.states.append(ss.State('scale', float, 1.0))

if extra_states is not None:
for state in extra_states:
self.states.append(state)

# Handle states
states = [
ss.State('age', float, np.nan), # NaN until conceived
ss.State('female', bool, bernoulli(p=0.5)),
ss.State('debut', float),
ss.State('ti_dead', int, ss.INT_NAN), # Time index for death
ss.State('alive', bool, True), # Time index for death
ss.State('scale', float, 1.0),
]
states.extend(sc.promotetolist(extra_states))
for state in states:
self.add_state(state)
self._initialize_states(sim=None) # No sim yet, but initialize what we can

self.networks = ss.Networks(networks)

# Set initial age distribution - likely move this somewhere else later
Expand All @@ -218,10 +203,28 @@ def get_age_dist(age_data):
if sc.checktype(age_data, pd.DataFrame):
return ss.data_dist(vals=age_data['value'].values, bins=age_data['age'].values)

def _initialize_states(self, sim=None):
for state in self.states.values():
state.initialize(sim=sim, people=self) # Connect the state to this people instance
return

def initialize(self, sim):
""" Initialization """
super().initialize(sim)

# For People initialization, first initialize slots, then initialize RNGs, then initialize remaining states
# This is because some states may depend on RNGs being initialized to generate initial values
self.slot.initialize(sim)
self.slot[:] = self.uid

# Initialize all RNGs (noting that includes those that are declared in child classes)
for rng in self.rngs:
rng.initialize(sim.rng_container, self.slot)

# Define age (CK: why is age handled differently than sex?)
self._initialize_states(sim=sim) # Now initialize with the sim
self.age[:] = self.age_data_dist.rvs(len(self))

self.initialized = True
return

def add_module(self, module, force=False):
Expand Down Expand Up @@ -359,4 +362,4 @@ def request_death(self, uids):
# enabled then often such agents would not be present in the simulation anyway
uids = ss.true(self.alive[uids])
self.ti_dead[uids] = self.ti
return
return
2 changes: 1 addition & 1 deletion stisim/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,4 +257,4 @@ def reset(self):
pass

def step(self, ti):
pass
pass
36 changes: 21 additions & 15 deletions stisim/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _get_vals_uids(vals, key, uid_map):
"""
Extract values from a collection of UIDs
This function is used to retreive values based on UID. As indexing a FusedArray returns a new FusedArray,
This function is used to retrieve values based on UID. As indexing a FusedArray returns a new FusedArray,
this method also populates the new UID map for use in the subsequently created FusedArray, avoiding the
need to re-compute it separately.
Expand Down Expand Up @@ -373,23 +373,29 @@ def _new_vals(self, uids):
new_vals = self.fill_value
return new_vals

def initialize(self, sim):
def initialize(self, sim=None, people=None):
if self._initialized:
return


if sim is not None and people is None:
people = sim.people

sim_still_needed = False
if isinstance(self.fill_value, rv_frozen):
self.fill_value = ScipyDistribution(self.fill_value, f'{self.__class__.__name__}_{self.label}')
self.fill_value.initialize(sim, self)

people = sim.people

people.add_state(self)
self._uid_map = people._uid_map
self.uid = people.uid
self._data.grow(len(self.uid))
self._data[:len(self.uid)] = self._new_vals(self.uid)
self.values = self._data._view
self._initialized = True
if sim is not None:
self.fill_value = ScipyDistribution(self.fill_value, f'{self.__class__.__name__}_{self.label}')
self.fill_value.initialize(sim, self)
else:
sim_still_needed = True

people.add_state(self, die=False) # CK: should not be needed
if not sim_still_needed:
self._uid_map = people._uid_map
self.uid = people.uid
self._data.grow(len(self.uid))
self._data[:len(self.uid)] = self._new_vals(self.uid)
self.values = self._data._view
self._initialized = True
return

def grow(self, uids):
Expand Down
2 changes: 1 addition & 1 deletion tests/devtests/devtest_remove_people.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,4 @@ def compare_totals(rand_seed=0):
ax[3].set_title('Number of agents')

fig.tight_layout()
plt.show()
plt.show()
2 changes: 1 addition & 1 deletion tests/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,4 @@ def test_seed(n=5):
test_seed(n)

sc.toc(T)
print('Done.')
print('Done.')
2 changes: 1 addition & 1 deletion tests/test_rngcontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,4 +191,4 @@ def test_repeatname(rng_container, n=5):
rng1 = ss.RNG('test')
with pytest.raises(RepeatNameException):
rng1.initialize(rng_container, slots=n)
return rng0, rng1
return rng0, rng1

0 comments on commit 1c5b9d3

Please sign in to comment.