Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework to make use of frozen SciPy distributions #170

Merged
merged 186 commits into from
Dec 14, 2023
Merged

Conversation

daniel-klein
Copy link
Contributor

This PR includes major updates.

SciPy distributions as parameters

We often used parameters that were later fed into distributions. For example, a parameter named mean_partners could be later fed into a normal distribution to determine the number of partners that each individual would seek.

self.pars['mean_partners'] = 5
...
self.partners = rng.normal(mean=self.pars['mean_partners'], std=3)

We had a nascent effort to create our own "frozen" distribution class to enable

import stisim as ss

self.pars['partner_distrib'] = ss.Normal(mean=5, std=3)
...
self.partners = self.pars.partner_distrib.sample(uids)

But this required us to wrap each numpy distribution individually.

New in this PR we can directly use frozen distributions from SciPy:

import scipy.stats as sps
self.pars['mean_partners'] = sps.norm(loc=5, scale=3)
...
self.partners = self.pars.partner_distrib.rvs(uids)

Callable parameters!

This is so cool... the parameters of SciPy distributions can be callable! This enables users to easily make any distribution parameter vary as function of anything else. Callables can be lambda or functions. For example (from test_base.py):

def test_ppl_construction():

    def init_debut(self, sim, uids):
        loc = np.full(len(uids), 16) # Mean of 16 for male
        loc[sim.people.female[uids]] = 21 # Mean of 21 for female
        return loc

    mf_pars = {
        'debut_dist': sps.norm(loc=init_debut, scale=2),  # Age of debut is normally distributed with mean that varies by sex
    }
    sim_pars = {'networks': [ss.mf(mf_pars)], 'n_agents': 100}
    gon_pars = {'beta': {'mf': [0.08, 0.04]}, 'p_death': 0.2}
    gon = ss.Gonorrhea(pars=gon_pars)

    sim = ss.Sim(pars=sim_pars, diseases=[gon])
    sim.initialize()
    sim.run()
    plt.figure()
    plt.plot(sim.tivec, sim.results.gonorrhea.n_infected)
    plt.title('Number of gonorrhea infections')

    return sim

Here you can see that the call signature is self, sim, uids. While a static method, self is the context of the calling instance, in this case an instance of mf - this is super handy! Users also gain access to the sim and the uids of the agents in question. The return value should be a scalar or array of length len(uids).

This is very powerful! Users can make parameters vary by time, perhaps using interpolation of data or other functional forms, or by any other attribute in the simulation. The callable has access to data in self, in case data or other information is resident on the class object. (This is really just a shortcut to allow self.cd4 in place of e.g. sim.people.hiv.cd4. I am still finding clever ways to use this functionality, one is to enable users to directly override specific scientific methods, like this example from hiv.py:

class HIV(STI):
    def __init__(self, pars=None):
        ...
        self.death_prob_per_dt = sps.bernoulli(p=self.death_prob)
        ...
        return

    @staticmethod
    def death_prob(self, sim, uids):
        return 0.05 / (self.pars.cd4_min - self.pars.cd4_max)**2 *  (self.cd4[uids] - self.pars.cd4_max)**2

Also, this is amazing for initialization. While the default in many modules is sps.bernoulli on the whole population, users could quickly "subtarget" a desired subgroup without changing model code, subclassing, or writing a custom "intervention" acting at time 0 only.

Oh, and users don't have to completely overwrite the distribution. From text_examples.py:

sir.pars['dur_inf'].kwds['loc'] = 5

Finally, I expect this new functionality to be super powerful when used in combination with extra_states.

Common random numbers

SciPy distributions in __dict__ and Module pars are automatically connected to a MultiRNG or SingleRNG automatically. The MultiRNG is now thin as it is now only primarily responsible for setting the seed, jumping, and resets. The SingleRNG is improved as well, and now uses the numpy random singleton directly. The user no longer needs to create random number generators and doesn't even have to specific the (unique) RNG names as these too are determined automatically using the .name and .label. Subclasses should change the label. Please note #167.

Bernoulli filtering

For Bernoulli draws, we previously had the bernoulli_filter. This functionality is now available on top of SciPy distributions:

import scipy.stats as sps
self.will_recover = sps.bernoulli(p=0.8)
...
recover_uids = self.will_recover.filter(uids)

Automatic initialization

I mention this briefly above, but distributions of type rv_frozen are automatically initialized and connected to appropriate RNGs. However, this only works if they are in __dict__ or pars. Other distributions would need to be manually initialized.

Status

Still a work in progress. As of the current state, the code will print lots of "DJK TODO" messages where more attention is needed. I'm working on it! FEEDBACK WELCOME, this is a big one, so let's discuss.

…f variable between simulations. Experimenting in ART.

Added logic to do transmission by acquiring node rather than edge. Slow, but rng safe.

Discovered np.random.poisson in Gonorrhea set_prognosis - likely needs to be passed index of infecting agent, which I do not currently calculate. TODO!
* Redo transmission and identification of source
* Added a poisson stream
* Initializing intervention streams
* Now set prognosis based on infector UID
…f of concept. Added a PrEP intervention that's like ART but modifies rel_sus.
* Breaking relationships on death
* Now calling finalize
* ART reduces transmission
* In ART and PrEP, changed capacity to coverage
* Bug fix in networks
* Updated "simple.py" example, need more testing
…mit that simple.py now runs two simulations and visualizes them side by side for each time step.
* Making ART reduce transmission
* Added a simple "ladder" network in which 1<-->2, 3<-->4, ... for testing purposes
* Improved graph visualization in simple.py, but will likely revert this to a simple test later.
* Adding simple_ladder.py to run and plot the ladder network
* One member of each dyad is initally HIV+
* 50% of HIV+ are on ART
* Added checking to ensure a stream is not called twice in a row without a reset() or a step()
* Improving test for Stream object.
* Improving streams.py:
 - Exception handling
 - reset()
 - Avoiding dependence of Stream on Sim object, now taking a Streams object and a block_size_object.
* module streams by layer, about to fix.
* Checking for repeat stream names.
* Adding run_multistream.py to compare a simple HIV simulation with and without random stream coherence.

* Stream block size now determined automatically

* Stream seed offsets are now determined by hashing the stream name instead of sequentially. Thanks to @RomeshA for the suggestion.

* New transmission calcs. Now determining acquisition probability across all layers and edge directions before using one random draw to determine acquisition.
* This change makes it more challenging to determine whom the infection came from, that is currently a N^2 operation that needs optimization.

* Adding ability to disable multiple streams through parameter named "multistream"
* The ability switch from MultiStream (the default) to CentralizedStream adds some complexity. Consider removing this feature later.
* As a result, everything now has parameters.

* Fixed prevalence calculation to include only people who are alive in the denominator.

* Additional error checking for streams, for example when the index array is boolean or empty.

* The _pre_draw decorator is now applied to more MultiStream calls for error checking.

* New CentralizedStream class that behaves like MultiStream but calls into a centralized random number generator.

* Fixed a small bug in states.py regarding the size of grow().

* Improved stream and streams tests.
* renaming stable_monogamy.py to run_stable_monogamy.py
* Adjusting test_base.py, it need some work to restore test_networks()
* Now using scipy to compute the distance matrix, it's faster
* Added a normal draw to streams.
* run_multistream now compares multstrem on and off for either ART or PrEP
* New run_sweep.py sweeps beta with multstream on or off
Interface still a bit messy.
Need to fix Centralized RNG now.
@cliffckerr cliffckerr merged commit c0063b6 into main Dec 14, 2023
2 checks passed
@cliffckerr cliffckerr deleted the rng-merge-scipy branch December 14, 2023 17:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants