diff --git a/hierarc/Sampling/mcmc_sampling.py b/hierarc/Sampling/mcmc_sampling.py
index 036c0fe..9547714 100644
--- a/hierarc/Sampling/mcmc_sampling.py
+++ b/hierarc/Sampling/mcmc_sampling.py
@@ -15,7 +15,7 @@ def __init__(self, *args, **kwargs):
         self.chain = CosmoLikelihood(*args, **kwargs)
         self.param = self.chain.param
 
-    def mcmc_emcee(
+    def get_emcee_sampler(
         self,
         n_walkers,
         n_burn,
@@ -25,7 +25,7 @@ def mcmc_emcee(
         continue_from_backend=False,
         **kwargs_emcee
     ):
-        """Runs the EMCEE MCMC sampling.
+        """Runs the EMCEE MCMC sampling and returns the sampler.
 
         :param n_walkers: number of walkers
         :param n_burn: number of iteration of burn in (not stored in the output sample
@@ -36,7 +36,7 @@ def mcmc_emcee(
         :param continue_from_backend: bool, if True and 'backend' in kwargs_emcee, will
             continue a chain sampling from backend
         :param kwargs_emcee: keyword argument for the emcee (e.g. to specify backend)
-        :return: samples of the EMCEE run
+        :return: sampler of the EMCEE run
         """
 
         num_param = self.param.num_param
@@ -53,6 +53,41 @@ def mcmc_emcee(
         else:
             backend.reset(n_walkers, num_param)
         sampler.run_mcmc(p0, n_burn + n_run, progress=True)
+
+        return sampler
+
+    def mcmc_emcee(
+        self,
+        n_walkers,
+        n_burn,
+        n_run,
+        kwargs_mean_start,
+        kwargs_sigma_start,
+        continue_from_backend=False,
+        **kwargs_emcee
+    ):
+        """Runs the EMCEE MCMC sampling and returns the flat chain.
+
+        :param n_walkers: number of walkers
+        :param n_burn: number of iteration of burn in (not stored in the output sample
+        :param n_run: number of iterations (after burn in) to be sampled
+        :param kwargs_mean_start: keyword arguments of the mean starting position
+        :param kwargs_sigma_start: keyword arguments of the spread in the initial
+            particles per parameter
+        :param continue_from_backend: bool, if True and 'backend' in kwargs_emcee, will
+            continue a chain sampling from backend
+        :param kwargs_emcee: keyword argument for the emcee (e.g. to specify backend)
+        :return: samples of the EMCEE run
+        """
+        sampler = self.get_emcee_sampler(
+            n_walkers,
+            n_burn,
+            n_run,
+            kwargs_mean_start,
+            kwargs_sigma_start,
+            continue_from_backend=continue_from_backend,
+            **kwargs_emcee
+        )
         flat_samples = sampler.get_chain(discard=n_burn, thin=1, flat=True)
         log_prob = sampler.get_log_prob(discard=n_burn, thin=1, flat=True)
         return flat_samples, log_prob