Skip to content

Commit

Permalink
Refactor MCMCSampler class to separate the EMCEE sampling and sampler…
Browse files Browse the repository at this point in the history
… retrieval
  • Loading branch information
ajshajib committed May 16, 2024
1 parent 549246d commit 5fdae6b
Showing 1 changed file with 38 additions and 3 deletions.
41 changes: 38 additions & 3 deletions hierarc/Sampling/mcmc_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(self, *args, **kwargs):
self.chain = CosmoLikelihood(*args, **kwargs)
self.param = self.chain.param

def mcmc_emcee(
def get_emcee_sampler(
self,
n_walkers,
n_burn,
Expand All @@ -25,7 +25,7 @@ def mcmc_emcee(
continue_from_backend=False,
**kwargs_emcee
):
"""Runs the EMCEE MCMC sampling.
"""Runs the EMCEE MCMC sampling and returns the sampler.
:param n_walkers: number of walkers
:param n_burn: number of iteration of burn in (not stored in the output sample
Expand All @@ -36,7 +36,7 @@ def mcmc_emcee(
:param continue_from_backend: bool, if True and 'backend' in kwargs_emcee, will
continue a chain sampling from backend
:param kwargs_emcee: keyword argument for the emcee (e.g. to specify backend)
:return: samples of the EMCEE run
:return: sampler of the EMCEE run
"""

num_param = self.param.num_param
Expand All @@ -53,6 +53,41 @@ def mcmc_emcee(
else:
backend.reset(n_walkers, num_param)
sampler.run_mcmc(p0, n_burn + n_run, progress=True)

return sampler

def mcmc_emcee(
self,
n_walkers,
n_burn,
n_run,
kwargs_mean_start,
kwargs_sigma_start,
continue_from_backend=False,
**kwargs_emcee
):
"""Runs the EMCEE MCMC sampling and returns the flat chain.
:param n_walkers: number of walkers
:param n_burn: number of iteration of burn in (not stored in the output sample
:param n_run: number of iterations (after burn in) to be sampled
:param kwargs_mean_start: keyword arguments of the mean starting position
:param kwargs_sigma_start: keyword arguments of the spread in the initial
particles per parameter
:param continue_from_backend: bool, if True and 'backend' in kwargs_emcee, will
continue a chain sampling from backend
:param kwargs_emcee: keyword argument for the emcee (e.g. to specify backend)
:return: samples of the EMCEE run
"""
sampler = self.get_emcee_sampler(
n_walkers,
n_burn,
n_run,
kwargs_mean_start,
kwargs_sigma_start,
continue_from_backend=continue_from_backend,
**kwargs_emcee
)
flat_samples = sampler.get_chain(discard=n_burn, thin=1, flat=True)
log_prob = sampler.get_log_prob(discard=n_burn, thin=1, flat=True)
return flat_samples, log_prob
Expand Down

0 comments on commit 5fdae6b

Please sign in to comment.