From 2d259caf6b47d5e81ac9542bb5d62f4de592172d Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Wed, 20 Nov 2024 18:01:36 +0100 Subject: [PATCH 01/17] Changed rescale-method of JointPrior to always return correct-size array and update in-place once all keys are requested. Changed (Conditional)PriorDict.rescale to always return samples in right shape. --- bilby/core/prior/dict.py | 35 ++++++++++++++------------ bilby/core/prior/joint.py | 52 ++++++++++++++++++++++++++++----------- 2 files changed, 57 insertions(+), 30 deletions(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index be3d543a..2908bd08 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -600,18 +600,21 @@ def rescale(self, keys, theta): ========== keys: list List of prior keys to be rescaled - theta: list - List of randomly drawn values on a unit cube associated with the prior keys + theta: dict or array-like + Randomly drawn values on a unit cube associated with the prior keys Returns ======= - list: List of floats containing the rescaled sample + list: + If theta is 1D, returns list of floats containing the rescaled sample. + If theta is 2D, returns list of lists containing the rescaled samples. """ - theta = list(theta) + theta = [theta[key] for key in keys] if isinstance(theta, dict) else list(theta) samples = [] for key, units in zip(keys, theta): samps = self[key].rescale(units) - samples += list(np.asarray(samps).flatten()) + # turns 0d-arrays into scalars + samples.append(np.squeeze(samps).tolist()) return samples def test_redundancy(self, key, disable_logging=False): @@ -832,28 +835,28 @@ def rescale(self, keys, theta): ========== keys: list List of prior keys to be rescaled - theta: list - List of randomly drawn values on a unit cube associated with the prior keys + theta: dict or array-like + Randomly drawn values on a unit cube associated with the prior keys Returns ======= - list: List of floats containing the rescaled sample + list: + If theta is float for each key, returns list of floats containing the rescaled sample. + If theta is array-like for each key, returns list of lists containing the rescaled samples. 
""" keys = list(keys) - theta = list(theta) + theta = [theta[key] for key in keys] if isinstance(theta, dict) else list(theta) self._check_resolved() self._update_rescale_keys(keys) result = dict() - for key, index in zip( - self.sorted_keys_without_fixed_parameters, self._rescale_indexes - ): - result[key] = self[key].rescale( - theta[index], **self.get_required_variables(key) - ) + for key, index in zip(self.sorted_keys_without_fixed_parameters, self._rescale_indexes): + result[key] = self[key].rescale(theta[index], **self.get_required_variables(key)) self[key].least_recently_sampled = result[key] samples = [] for key in keys: - samples += list(np.asarray(result[key]).flatten()) + # turns 0d-arrays into scalars + res = np.squeeze(result[key]).tolist() + samples.append(res) return samples def _update_rescale_keys(self, keys): diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py index 43c8913e..b088b15b 100644 --- a/bilby/core/prior/joint.py +++ b/bilby/core/prior/joint.py @@ -63,8 +63,9 @@ def __init__(self, names, bounds=None): self.requested_parameters = dict() self.reset_request() - # a dictionary of the rescaled parameters - self.rescale_parameters = dict() + # a dictionary of the rescale(d) parameters + self._rescale_parameters = dict() + self._rescaled_parameters = dict() self.reset_rescale() # a list of sampled parameters @@ -94,7 +95,12 @@ def filled_rescale(self): Check if all the rescaled parameters have been filled. """ - return not np.any([val is None for val in self.rescale_parameters.values()]) + return not np.any([val is None for val in self._rescale_parameters.values()]) + + def set_rescale(self, key, values): + values = np.array(values) + self._rescale_parameters[key] = values + self._rescaled_parameters[key] = np.atleast_1d(np.ones_like(values)) * np.nan def reset_rescale(self): """ @@ -102,7 +108,11 @@ def reset_rescale(self): """ for name in self.names: - self.rescale_parameters[name] = None + self._rescale_parameters[name] = None + self._rescaled_parameters[name] = None + + def get_rescaled(self, key): + return self._rescaled_parameters[key] def get_instantiation_dict(self): subclass_args = infer_args_from_method(self.__init__) @@ -303,10 +313,11 @@ def rescale(self, value, **kwargs): Parameters ========== - value: array - A 1d vector sample (one for each parameter) drawn from a uniform + value: array or None + If given, a 1d vector sample (one for each parameter) drawn from a uniform distribution between 0 and 1, or a 2d NxM array of samples where N is the number of samples and M is the number of parameters. + If None, values previously set using BaseJointPriorDist.set_rescale() are used. kwargs: dict All keyword args that need to be passed to _rescale method, these keyword args are called in the JointPrior rescale methods for each parameter @@ -317,7 +328,11 @@ def rescale(self, value, **kwargs): An vector sample drawn from the multivariate Gaussian distribution. 
""" - samp = np.array(value) + if value is None: + samp = np.array(list(self._rescale_parameters.values())).T + else: + samp = np.array(value) + if len(samp.shape) == 1: samp = samp.reshape(1, self.num_vars) @@ -327,6 +342,11 @@ def rescale(self, value, **kwargs): raise ValueError("Array is the wrong shape") samp = self._rescale(samp, **kwargs) + if value is None: + for i, key in enumerate(self.names): + output = self.get_rescaled(key) + # update in-place for proper handling in PriorDict-instances + output[:] = samp[:, i] return np.squeeze(samp) def _rescale(self, samp, **kwargs): @@ -790,19 +810,23 @@ def rescale(self, val, **kwargs): all kwargs passed to the dist.rescale method Returns ======= - float: - A sample from the prior parameter. + np.ndarray: + The samples from the prior parameter. If not all names in "dist" have been filled, + the array contains only np.nan. *This* specific array instance will be filled with + the rescaled value once all parameters have been requested """ - self.dist.rescale_parameters[self.name] = val + self.dist.set_rescale(self.name, val) if self.dist.filled_rescale(): - values = np.array(list(self.dist.rescale_parameters.values())).T - samples = self.dist.rescale(values, **kwargs) + self.dist.rescale(values=None, **kwargs) + output = self.dist.get_rescaled(self.name) self.dist.reset_rescale() - return samples else: - return [] # return empty list + output = self.dist.get_rescaled(self.name) + + # have to return raw output to conserve in-place modifications + return output def sample(self, size=1, **kwargs): """ From 5ce0bbe40a7f169181c4676611b58015bd026eb8 Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Fri, 22 Nov 2024 09:55:06 +0100 Subject: [PATCH 02/17] Small fix to rescale of JointPrior --- bilby/core/prior/joint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py index b088b15b..7e6655ee 100644 --- a/bilby/core/prior/joint.py +++ b/bilby/core/prior/joint.py @@ -819,7 +819,7 @@ def rescale(self, val, **kwargs): self.dist.set_rescale(self.name, val) if self.dist.filled_rescale(): - self.dist.rescale(values=None, **kwargs) + self.dist.rescale(value=None, **kwargs) output = self.dist.get_rescaled(self.name) self.dist.reset_rescale() else: From 4e1a3d2d50f5abc97a3cb95153b7dac1eefa410b Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Mon, 25 Nov 2024 16:50:02 +0100 Subject: [PATCH 03/17] For jointprior rescale, only cast to list once its save to loose mutability --- bilby/core/prior/dict.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index 2908bd08..0490d194 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -613,8 +613,10 @@ def rescale(self, keys, theta): samples = [] for key, units in zip(keys, theta): samps = self[key].rescale(units) + samples.append(samps) + for i, samps in enumerate(samples): # turns 0d-arrays into scalars - samples.append(np.squeeze(samps).tolist()) + samples[i] = np.squeeze(samps).tolist() return samples def test_redundancy(self, key, disable_logging=False): From 35ac877edd82df445e0449f8b07d25aebe23b50c Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Fri, 22 Nov 2024 09:57:00 +0100 Subject: [PATCH 04/17] Fix to BaseJointPriorDist bound check --- bilby/core/prior/joint.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py index 7e6655ee..a4520d39 100644 --- 
a/bilby/core/prior/joint.py +++ b/bilby/core/prior/joint.py @@ -219,11 +219,9 @@ def _check_samp(self, value): raise ValueError("Array is the wrong shape") # check sample(s) is within bounds - outbounds = np.ones(samp.shape[0], dtype=bool) + outbounds = np.zeros(samp.shape[0], dtype=bool) for s, bound in zip(samp.T, self.bounds.values()): - outbounds = (s < bound[0]) | (s > bound[1]) - if np.any(outbounds): - break + outbounds += (s < bound[0]) | (s > bound[1]) return samp, outbounds def ln_prob(self, value): From 09bc43b34eced33cacbd7aa6294762e38bddf5bf Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Fri, 22 Nov 2024 09:58:11 +0100 Subject: [PATCH 05/17] Allow setting of "dist" attributes through JointPrior --- bilby/core/prior/joint.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py index a4520d39..8d2d9188 100644 --- a/bilby/core/prior/joint.py +++ b/bilby/core/prior/joint.py @@ -745,7 +745,7 @@ class MultivariateNormalDist(MultivariateGaussianDist): class JointPrior(Prior): - def __init__(self, dist, name=None, latex_label=None, unit=None): + def __init__(self, dist, name=None, latex_label=None, unit=None, **kwargs): """This defines the single parameter Prior object for parameters that belong to a JointPriorDist Parameters @@ -796,6 +796,17 @@ def maximum(self, maximum): self._maximum = maximum self.dist.bounds[self.name] = (self.dist.bounds[self.name][0], maximum) + def __setattr__(self, name, value): + # first check that the JointPrior has an explicit setter method for the attribute, which should take precedence + if hasattr(self.__class__, name) and getattr(self.__class__, name).fset is not None: + return super().__setattr__(name, value) + # then check if the BaseJointPriorDist-!subclass! has an explicit setter method for the attribute + elif hasattr(self, "dist") and hasattr(self.dist, name) and getattr(self.dist.__class__, name).fset is not None: + return self.dist.__setattr__(name, value) + # if not, use the default setattr + else: + return super().__setattr__(name, value) + def rescale(self, val, **kwargs): """ Scale a unit hypercube sample to the prior.
From 7c746831aa3875d7bc9b35ad41676e87e5809a41 Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Fri, 22 Nov 2024 10:02:38 +0100 Subject: [PATCH 06/17] Make ConditionalPrior ready for JointPrior --- bilby/core/prior/conditional.py | 34 +++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/bilby/core/prior/conditional.py b/bilby/core/prior/conditional.py index 797cbd1c..1a6ca735 100644 --- a/bilby/core/prior/conditional.py +++ b/bilby/core/prior/conditional.py @@ -3,13 +3,14 @@ from .analytical import DeltaFunction, PowerLaw, Uniform, LogUniform, \ SymmetricLogUniform, Cosine, Sine, Gaussian, TruncatedGaussian, HalfGaussian, \ LogNormal, Exponential, StudentT, Beta, Logistic, Cauchy, Gamma, ChiSquared, FermiDirac -from ..utils import infer_args_from_method, infer_parameters_from_function +from .joint import JointPrior +from ..utils import infer_args_from_method, infer_parameters_from_function, get_dict_with_properties def conditional_prior_factory(prior_class): class ConditionalPrior(prior_class): - def __init__(self, condition_func, name=None, latex_label=None, unit=None, - boundary=None, **reference_params): + def __init__(self, condition_func, name=None, latex_label=None, unit=None, boundary=None, dist=None, + **reference_params): """ Parameters @@ -41,23 +42,26 @@ def condition_func(reference_params, y): See superclass boundary: str, optional See superclass + dist: BaseJointPriorDist, optional + See superclass reference_params: Initial values for attributes such as `minimum`, `maximum`. This differs on the `prior_class`, for example for the Gaussian prior this is `mu` and `sigma`. """ - if 'boundary' in infer_args_from_method(super(ConditionalPrior, self).__init__): - super(ConditionalPrior, self).__init__(name=name, latex_label=latex_label, - unit=unit, boundary=boundary, **reference_params) - else: - super(ConditionalPrior, self).__init__(name=name, latex_label=latex_label, - unit=unit, **reference_params) + kwargs = {"name": name, "latex_label": latex_label, "unit": unit, "boundary": boundary, "dist": dist} + needed_kwargs = infer_args_from_method(super(ConditionalPrior, self).__init__) + for kw in kwargs.copy(): + if kw not in needed_kwargs: + kwargs.pop(kw) + + super(ConditionalPrior, self).__init__(**kwargs, **reference_params) self._required_variables = None self.condition_func = condition_func self._reference_params = reference_params - self.__class__.__name__ = 'Conditional{}'.format(prior_class.__name__) - self.__class__.__qualname__ = 'Conditional{}'.format(prior_class.__qualname__) + self.__class__.__name__ = "Conditional{}".format(prior_class.__name__) + self.__class__.__qualname__ = "Conditional{}".format(prior_class.__qualname__) def sample(self, size=None, **required_variables): """Draw a sample from the prior @@ -202,7 +206,9 @@ def required_variables(self): return self._required_variables def get_instantiation_dict(self): - instantiation_dict = super(ConditionalPrior, self).get_instantiation_dict() + superclass_args = infer_args_from_method(super(ConditionalPrior, self).__init__) + dict_with_properties = get_dict_with_properties(self) + instantiation_dict = {key: dict_with_properties[key] for key in superclass_args} for key, value in self.reference_params.items(): instantiation_dict[key] = value return instantiation_dict @@ -322,6 +328,10 @@ class ConditionalInterped(conditional_prior_factory(Interped)): pass +class ConditionalJointPrior(conditional_prior_factory(JointPrior)): + pass + + class 
DirichletElement(ConditionalBeta): r""" Single element in a dirichlet distribution From 7bf33995971991bbc87491ba00c50fa45f0cfe07 Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Fri, 22 Nov 2024 10:04:28 +0100 Subject: [PATCH 07/17] Make "mode" of MultivariateGaussianDist a settable property to use with ConditionalPriors --- bilby/core/prior/joint.py | 118 ++++++++++++++++++++++---------------- 1 file changed, 70 insertions(+), 48 deletions(-) diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py index 8d2d9188..1bcdcc28 100644 --- a/bilby/core/prior/joint.py +++ b/bilby/core/prior/joint.py @@ -630,76 +630,83 @@ def add_mode(self, mus=None, sigmas=None, corrcoef=None, cov=None, weight=1.0): ) def _rescale(self, samp, **kwargs): - try: - mode = kwargs["mode"] - except KeyError: - mode = None + mode = kwargs.get("mode", self.mode) if mode is None: if self.nmodes == 1: mode = 0 else: - mode = np.argwhere(self.cumweights - random.rng.uniform(0, 1) > 0)[0][0] + mode = random.rng.choice( + self.nmodes, size=len(samp), p=self.weights + ) samp = erfinv(2.0 * samp - 1) * 2.0 ** 0.5 # rotate and scale to the multivariate normal shape - samp = self.mus[mode] + self.sigmas[mode] * np.einsum( - "ij,kj->ik", samp * self.sqeigvalues[mode], self.eigvectors[mode] - ) + uniques = np.unique(mode) + if len(uniques) == 1: + unique = uniques[0] + samp = self.mus[unique] + self.sigmas[unique] * np.einsum( + "ij,kj->ik", samp * self.sqeigvalues[unique], self.eigvectors[unique] + ) + else: + for m in uniques: + mask = m == mode + samp[mask] = self.mus[m] + self.sigmas[m] * np.einsum( + "ij,kj->ik", samp[mask] * self.sqeigvalues[m], self.eigvectors[m] + ) + return samp def _sample(self, size, **kwargs): - try: - mode = kwargs["mode"] - except KeyError: - mode = None + mode = kwargs.get("mode", self.mode) + + samps = np.zeros((size, len(self))) + outbound = np.ones(size, dtype=bool) if mode is None: if self.nmodes == 1: mode = 0 - else: - if size == 1: - mode = np.argwhere(self.cumweights - random.rng.uniform(0, 1) > 0)[0][0] - else: - # pick modes - mode = [ - np.argwhere(self.cumweights - r > 0)[0][0] - for r in random.rng.uniform(0, 1, size) - ] + while np.any(outbound): + # sample the multivariate Gaussian keys + vals = random.rng.uniform(0, 1, (np.sum(outbound), len(self))) - samps = np.zeros((size, len(self))) - for i in range(size): - inbound = False - while not inbound: - # sample the multivariate Gaussian keys - vals = random.rng.uniform(0, 1, len(self)) - - if isinstance(mode, list): - samp = np.atleast_1d(self.rescale(vals, mode=mode[i])) - else: - samp = np.atleast_1d(self.rescale(vals, mode=mode)) - samps[i, :] = samp + if mode is None: + mode = random.rng.choice( + self.nmodes, size=np.sum(outbound), p=self.weights + ) - # check sample is in bounds (otherwise perform another draw) - outbound = False - for name, val in zip(self.names, samp): - if val < self.bounds[name][0] or val > self.bounds[name][1]: - outbound = True - break + samps[outbound] = np.atleast_1d(self.rescale(vals, mode=mode)) - if not outbound: - inbound = True + # check sample is in bounds and redraw those which are not + samps, outbound = self._check_samp(samps) return samps - def _ln_prob(self, samp, lnprob, outbounds): - for j in range(samp.shape[0]): - # loop over the modes and sum the probabilities - for i in range(self.nmodes): - # self.mvn[i] is a "standard" multivariate normal distribution; see add_mode() - z = (samp[j] - self.mus[i]) / self.sigmas[i] - lnprob[j] = np.logaddexp(lnprob[j],
self.mvn[i].logpdf(z) - self.logprodsigmas[i]) + def _ln_prob(self, samp, lnprob, outbounds, **kwargs): + mode = kwargs.get("mode", self.mode) + + if mode is None: + for j in range(samp.shape[0]): + # loop over the modes and sum the probabilities + for i in range(self.nmodes): + # self.mvn[i] is a "standard" multivariate normal distribution; see add_mode() + z = (samp[j] - self.mus[i]) / self.sigmas[i] + lnprob[j] = np.logaddexp(lnprob[j], + self.mvn[i].logpdf(z) - self.logprodsigmas[i] + np.log(self.weights[i])) + else: + uniques = np.unique(np.asarray(mode, dtype=int)) + if len(uniques) == 1: + unique = uniques[0] + z = (samp - self.mus[unique]) / self.sigmas[unique] + # don't multiply by the mode weight if the mode is given (ie. prob(mode|mode) = 1) + lnprob = np.logaddexp(lnprob, self.mvn[unique].logpdf(z) - self.logprodsigmas[unique]) + else: + for m in uniques: + mask = mode == m + z = (samp[mask] - self.mus[m]) / self.sigmas[m] + # don't multiply by the mode weight if the mode is given (ie. prob(mode|mode) = 1) + lnprob[mask] = np.logaddexp(lnprob[mask], self.mvn[m].logpdf(z) - self.logprodsigmas[m]) # set out-of-bounds values to -inf lnprob[outbounds] = -np.inf @@ -739,6 +746,21 @@ def __eq__(self, other): return False return True + @property + def mode(self): + if hasattr(self, "_mode"): + return self._mode + else: + return None + + @mode.setter + def mode(self, mode): + if not np.isdtype(np.asarray(mode).dtype, "integral"): + raise ValueError("The mode to set must have integral data type.") + if np.any(mode >= self.nmodes) or np.any(mode < 0): + raise ValueError("The value of mode cannot be higher than the number of modes or smaller than zero.") + self._mode = mode + class MultivariateNormalDist(MultivariateGaussianDist): """A synonym for the :class:`~bilby.core.prior.MultivariateGaussianDist` distribution.""" From 9eb491f191ad73529e3ae0a27f2718b39730d7b2 Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Fri, 22 Nov 2024 10:07:15 +0100 Subject: [PATCH 08/17] Added TestCase for mode-setting of MultivariateGaussian with ConditionalPrior --- test/core/prior/conditional_test.py | 35 ++++++++++++++++++++++------ 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/test/core/prior/conditional_test.py b/test/core/prior/conditional_test.py index 20c0cda9..fcbd9d98 100644 --- a/test/core/prior/conditional_test.py +++ b/test/core/prior/conditional_test.py @@ -338,24 +338,47 @@ def test_rescale_with_joint_prior(self): cov = [[[0.03, 0.], [0., 0.04]]] mvg = bilby.core.prior.MultivariateGaussianDist(names, mus=mu, covs=cov) + names_2 = ["mvgvar_a", "mvgvar_b"] + mvg_dual_mode = bilby.core.prior.MultivariateGaussianDist( + names=names_2, + nmodes=2, + mus=[mu[0], (np.array(mu[0]) + np.ones_like(mu[0])).tolist()], + covs=[cov[0], cov[0]], + weights=[1, 2] + ) + + def condition_func_2(reference_params, var_0): + return dict(mode=np.searchsorted(np.cumsum(np.array([1, 2]) / 3), var_0)) + + def condition_func_1(reference_params, var_0, var_1): + return {"minimum": var_0, "maximum": var_1} + priordict = bilby.core.prior.ConditionalPriorDict( dict( var_3=self.prior_3, var_2=self.prior_2, var_0=self.prior_0, var_1=self.prior_1, - mvgvar_0=bilby.core.prior.MultivariateGaussian(mvg, "mvgvar_0"), - mvgvar_1=bilby.core.prior.MultivariateGaussian(mvg, "mvgvar_1"), + mvgvar_0=bilby.core.prior.ConditionalJointPrior( + condition_func_1, dist=mvg, name="mvgvar_0", minimum=self.minimum, maximum=self.maximum), + mvgvar_1=bilby.core.prior.ConditionalJointPrior( + condition_func_1,
dist=mvg, name="mvgvar_1", minimum=self.minimum, maximum=self.maximum), + mvgvar_a=bilby.core.prior.ConditionalJointPrior( + condition_func_2, dist=mvg_dual_mode, name="mvgvar_a", + minimum=self.minimum, maximum=self.maximum, mode=None), + mvgvar_b=bilby.core.prior.ConditionalJointPrior( + condition_func_2, dist=mvg_dual_mode, name="mvgvar_b", + minimum=self.minimum, maximum=self.maximum, mode=None), ) ) - ref_variables = list(self.test_sample.values()) + [0.4, 0.1] - keys = list(self.test_sample.keys()) + names + ref_variables = list(self.test_sample.values()) + [0.4, 0.1] + [0.5, 0.2] + keys = list(self.test_sample.keys()) + names + names_2 res = priordict.rescale(keys=keys, theta=ref_variables) self.assertIsInstance(res, list) - self.assertEqual(np.shape(res), (6,)) - self.assertListEqual([isinstance(r, float) for r in res], 6 * [True]) + self.assertEqual(np.shape(res), (8,)) + self.assertListEqual([isinstance(r, float) for r in res], 8 * [True]) # check conditional values are still as expected expected = [self.test_sample["var_0"]] From d0d597dd843901c35c06a05b555eed4fadd6c277 Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Fri, 22 Nov 2024 10:54:55 +0100 Subject: [PATCH 09/17] Avoid recreation of ConditionalPriorDict if not necessary --- bilby/core/prior/dict.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index 0490d194..ed91e54f 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -727,7 +727,10 @@ def sample_subset(self, keys=iter([]), size=None): if key not in keys and isinstance(self[key], DeltaFunction) ] use_keys = add_delta_keys + list(keys) - subset_dict = ConditionalPriorDict({key: self[key] for key in use_keys}) + if set(use_keys) == set(self.keys()): + subset_dict = self + else: + subset_dict = ConditionalPriorDict({key: self[key] for key in use_keys}) if not subset_dict._resolved: raise IllegalConditionsException( "The current set of priors contains unresolvable conditions." From 4e647e9b1589d809e3b941e2235144c7bb9bb1b9 Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Fri, 22 Nov 2024 11:12:03 +0100 Subject: [PATCH 10/17] Ensure rescaling step works for the chosen set of keys and that the conditional properties of the priors can be set to arrays or loop over rescale values if not --- bilby/core/prior/dict.py | 28 +++++++++++++++++++++++----- test/core/prior/conditional_test.py | 5 +++-- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index ed91e54f..28257303 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -851,12 +851,30 @@ def rescale(self, keys, theta): """ keys = list(keys) theta = [theta[key] for key in keys] if isinstance(theta, dict) else list(theta) - self._check_resolved() - self._update_rescale_keys(keys) + if set(keys) == set(self.non_fixed_keys): + subset_dict = self + else: + subset_dict = ConditionalPriorDict({key: self[key] for key in keys}) + if not subset_dict._resolved: + raise IllegalConditionsException( + "The current set of priors contains unresolvable conditions." 
+ ) + subset_dict._update_rescale_keys(keys) result = dict() - for key, index in zip(self.sorted_keys_without_fixed_parameters, self._rescale_indexes): - result[key] = self[key].rescale(theta[index], **self.get_required_variables(key)) - self[key].least_recently_sampled = result[key] + for key, index in zip(subset_dict.sorted_keys_without_fixed_parameters, subset_dict._rescale_indexes): + try: + result[key] = subset_dict[key].rescale(theta[index], **subset_dict.get_required_variables(key)) + except ValueError: + # Some prior classes can not handle an array of conditional parameters (e.g. alpha for PowerLaw) + # If that is the case, we sample each sample individually. + required_variables = subset_dict.get_required_variables(key) + result[key] = np.zeros_like(theta[key]) + for i in range(len(theta[key])): + rvars = { + key: value[i] for key, value in required_variables.items() + } + result[key][i] = subset_dict[key].rescale(theta[index][i], **rvars) + subset_dict[key].least_recently_sampled = result[key] samples = [] for key in keys: # turns 0d-arrays into scalars diff --git a/test/core/prior/conditional_test.py b/test/core/prior/conditional_test.py index fcbd9d98..a80c5e40 100644 --- a/test/core/prior/conditional_test.py +++ b/test/core/prior/conditional_test.py @@ -401,10 +401,11 @@ def test_cdf(self): ) def test_rescale_illegal_conditions(self): - del self.conditional_priors["var_0"] + test_sample = self.test_sample.copy() + test_sample.pop("var_0") with self.assertRaises(bilby.core.prior.IllegalConditionsException): self.conditional_priors.rescale( - keys=list(self.test_sample.keys()), + keys=list(test_sample.keys()), theta=list(self.test_sample.values()), ) From 80c2b56bbf2cdd4be9d5e8b32d7efad13845c786 Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Fri, 22 Nov 2024 11:23:40 +0100 Subject: [PATCH 11/17] Updated test case for conditional MultivariateGaussian to be more comprehensive --- test/core/prior/conditional_test.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/core/prior/conditional_test.py b/test/core/prior/conditional_test.py index a80c5e40..8564ef89 100644 --- a/test/core/prior/conditional_test.py +++ b/test/core/prior/conditional_test.py @@ -351,7 +351,7 @@ def condition_func_2(reference_params, var_0): return dict(mode=np.searchsorted(np.cumsum(np.array([1, 2]) / 3), var_0)) def condition_func_1(reference_params, var_0, var_1): - return {"minimum": var_0, "maximum": var_1} + return {"minimum": var_0 - 1, "maximum": var_1 + 1} priordict = bilby.core.prior.ConditionalPriorDict( dict( @@ -385,6 +385,12 @@ def condition_func_1(reference_params, var_0, var_1): for ii in range(1, 4): expected.append(expected[-1] * self.test_sample[f"var_{ii}"]) self.assertListEqual(expected, res[0:4]) + res_sample = priordict.sample(1) + self.assertEqual(list(res_sample.keys()), priordict.sorted_keys_without_fixed_parameters) + res_sample = priordict.sample(10) + self.assertListEqual([len(val) for val in res_sample.values()], [10] * len(res_sample.keys())) + lnprobs = priordict.ln_prob(priordict.sample(10), axis=0) + self.assertEqual(len(lnprobs), 10) def test_cdf(self): """ From 7b65f8818d49f4d571f46fbff28b4448a9ef59d3 Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Fri, 22 Nov 2024 16:46:06 +0100 Subject: [PATCH 12/17] Added and updated Test cases --- test/core/prior/conditional_test.py | 52 ++++++++++++++++++++++++----- 1 file changed, 43 insertions(+), 9 deletions(-) diff --git a/test/core/prior/conditional_test.py 
b/test/core/prior/conditional_test.py index 8564ef89..23f4f97d 100644 --- a/test/core/prior/conditional_test.py +++ b/test/core/prior/conditional_test.py @@ -353,8 +353,19 @@ def condition_func_2(reference_params, var_0): def condition_func_1(reference_params, var_0, var_1): return {"minimum": var_0 - 1, "maximum": var_1 + 1} + def condition_func_5(reference_parameters, mvgvar_a): + return dict(minimum=reference_parameters["minimum"], maximum=mvgvar_a) + + prior_5 = bilby.core.prior.ConditionalUniform( + condition_func=condition_func_5, minimum=self.minimum, maximum=self.maximum + ) + priordict = bilby.core.prior.ConditionalPriorDict( dict( + var_5=prior_5, + mvgvar_a=bilby.core.prior.ConditionalJointPrior( + condition_func_2, dist=mvg_dual_mode, name="mvgvar_a", + minimum=self.minimum, maximum=self.maximum, mode=None), var_3=self.prior_3, var_2=self.prior_2, var_0=self.prior_0, @@ -363,22 +374,20 @@ def condition_func_1(reference_params, var_0, var_1): condition_func_1, dist=mvg, name="mvgvar_0", minimum=self.minimum, maximum=self.maximum), mvgvar_1=bilby.core.prior.ConditionalJointPrior( condition_func_1, dist=mvg, name="mvgvar_1", minimum=self.minimum, maximum=self.maximum), - mvgvar_a=bilby.core.prior.ConditionalJointPrior( - condition_func_2, dist=mvg_dual_mode, name="mvgvar_a", - minimum=self.minimum, maximum=self.maximum, mode=None), mvgvar_b=bilby.core.prior.ConditionalJointPrior( condition_func_2, dist=mvg_dual_mode, name="mvgvar_b", minimum=self.minimum, maximum=self.maximum, mode=None), ) ) - ref_variables = list(self.test_sample.values()) + [0.4, 0.1] + [0.5, 0.2] - keys = list(self.test_sample.keys()) + names + names_2 + ref_variables = self.test_sample.copy() + ref_variables.update({"mvgvar_0": 0.4, "mvgvar_1": 0.1, "mvgvar_a": 0.5, "mvgvar_b": 0.2, "var_5": 0.5}) + keys = list(self.test_sample.keys()) + names + names_2 + ["var_5"] res = priordict.rescale(keys=keys, theta=ref_variables) self.assertIsInstance(res, list) - self.assertEqual(np.shape(res), (8,)) - self.assertListEqual([isinstance(r, float) for r in res], 8 * [True]) + self.assertEqual(np.shape(res), (9,)) + self.assertListEqual([isinstance(r, float) for r in res], 9 * [True]) # check conditional values are still as expected expected = [self.test_sample["var_0"]] @@ -387,11 +396,36 @@ def condition_func_1(reference_params, var_0, var_1): self.assertListEqual(expected, res[0:4]) res_sample = priordict.sample(1) self.assertEqual(list(res_sample.keys()), priordict.sorted_keys_without_fixed_parameters) - res_sample = priordict.sample(10) - self.assertListEqual([len(val) for val in res_sample.values()], [10] * len(res_sample.keys())) + res_sample = priordict.sample(1000) + self.assertListEqual([len(val) for val in res_sample.values()], [1000] * len(res_sample.keys())) lnprobs = priordict.ln_prob(priordict.sample(10), axis=0) self.assertEqual(len(lnprobs), 10) + with self.assertRaises(bilby.core.prior.IllegalConditionsException): + keys = set(priordict.keys()) - set(["mvgvar_a"]) + priordict.rescale(keys=keys, theta=ref_variables) + + def condition_func_6(reference_params, var_5): + return dict(mode=np.searchsorted(np.cumsum(np.array([1, 2]) / 3), var_5)) + + priordict_unresolveable = bilby.core.prior.ConditionalPriorDict( + dict( + var_5=prior_5, + var_3=self.prior_3, + var_2=self.prior_2, + var_0=self.prior_0, + var_1=self.prior_1, + mvgvar_a=bilby.core.prior.ConditionalJointPrior( + condition_func_6, dist=mvg_dual_mode, name="mvgvar_a", + minimum=self.minimum, maximum=self.maximum, mode=None), + 
mvgvar_b=bilby.core.prior.ConditionalJointPrior( + condition_func_6, dist=mvg_dual_mode, name="mvgvar_b", + minimum=self.minimum, maximum=self.maximum, mode=None), + + ) + ) + self.assertEqual(priordict_unresolveable._resolved, False) + def test_cdf(self): """ Test that the CDF method is the inverse of the rescale method. From 2863dce676879c57fe7b1c4a1938000d18e91de3 Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Fri, 22 Nov 2024 17:12:06 +0100 Subject: [PATCH 13/17] Fixed ConditionalPrior.__repr__ after changes to ConditionalPrior. --- bilby/core/prior/conditional.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bilby/core/prior/conditional.py b/bilby/core/prior/conditional.py index 1a6ca735..fcad7253 100644 --- a/bilby/core/prior/conditional.py +++ b/bilby/core/prior/conditional.py @@ -234,8 +234,8 @@ def __repr__(self): prior_name = self.__class__.__name__ instantiation_dict = self.get_instantiation_dict() instantiation_dict["condition_func"] = ".".join([ - instantiation_dict["condition_func"].__module__, - instantiation_dict["condition_func"].__name__ + self.condition_func.__module__, + self.condition_func.__name__ ]) args = ', '.join(['{}={}'.format(key, repr(instantiation_dict[key])) for key in instantiation_dict]) From d3f535b9bddfdd7598a5db93f6567c58c1b2ae7e Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Fri, 22 Nov 2024 17:17:15 +0100 Subject: [PATCH 14/17] Improve ConditionalPriorDict.rescale and ConditionalPriorDict.sample: Remove necessity to initialize a whole new class instance for lists of dicts that do not span all keys of the ConditionalPriorDict --- bilby/core/prior/dict.py | 85 ++++++++++++++++++---------------------- 1 file changed, 39 insertions(+), 46 deletions(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index 28257303..17ebc9b4 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -672,8 +672,6 @@ def __init__(self, dictionary=None, filename=None, conversion_function=None): self._conditional_keys = [] self._unconditional_keys = [] self._rescale_keys = [] - self._rescale_indexes = [] - self._least_recently_rescaled_keys = [] super(ConditionalPriorDict, self).__init__( dictionary=dictionary, filename=filename, @@ -720,40 +718,42 @@ def _check_conditions_resolved(self, key, sampled_keys): return conditions_resolved def sample_subset(self, keys=iter([]), size=None): + keys = list(keys) self.convert_floats_to_delta_functions() - add_delta_keys = [ - key - for key in self.keys() - if key not in keys and isinstance(self[key], DeltaFunction) - ] - use_keys = add_delta_keys + list(keys) - if set(use_keys) == set(self.keys()): - subset_dict = self - else: - subset_dict = ConditionalPriorDict({key: self[key] for key in use_keys}) - if not subset_dict._resolved: - raise IllegalConditionsException( - "The current set of priors contains unresolvable conditions." - ) + add_delta_keys = [] + for key in self.keys(): + if key not in keys and isinstance(self[key], DeltaFunction): + add_delta_keys.append(key) + + use_keys = add_delta_keys + keys + unconditional_use_keys = [key for key in self.unconditional_keys if key in use_keys] + sorted_conditional_use_keys = [key for key in self.conditional_keys if key in use_keys] + + for i, key in enumerate(sorted_conditional_use_keys): + if not self._check_conditions_resolved(key, unconditional_use_keys + sorted_conditional_use_keys[:i]): + raise IllegalConditionsException( + "The current set of priors contains unresolvable conditions." 
+ ) + sorted_use_keys = unconditional_use_keys + sorted_conditional_use_keys samples = dict() - for key in subset_dict.sorted_keys: + for key in sorted_use_keys: if key not in keys or isinstance(self[key], Constraint): continue if isinstance(self[key], Prior): try: - samples[key] = subset_dict[key].sample( - size=size, **subset_dict.get_required_variables(key) + samples[key] = self[key].sample( + size=size, **self.get_required_variables(key) ) except ValueError: # Some prior classes can not handle an array of conditional parameters (e.g. alpha for PowerLaw) # If that is the case, we sample each sample individually. - required_variables = subset_dict.get_required_variables(key) + required_variables = self.get_required_variables(key) samples[key] = np.zeros(size) for i in range(size): rvars = { key: value[i] for key, value in required_variables.items() } - samples[key][i] = subset_dict[key].sample(**rvars) + samples[key][i] = self[key].sample(**rvars) else: logger.debug("{} not a known prior.".format(key)) return samples @@ -850,31 +850,32 @@ def rescale(self, keys, theta): If theta is array-like for each key, returns list of lists containing the rescaled samples. """ keys = list(keys) - theta = [theta[key] for key in keys] if isinstance(theta, dict) else list(theta) - if set(keys) == set(self.non_fixed_keys): - subset_dict = self - else: - subset_dict = ConditionalPriorDict({key: self[key] for key in keys}) - if not subset_dict._resolved: - raise IllegalConditionsException( - "The current set of priors contains unresolvable conditions." - ) - subset_dict._update_rescale_keys(keys) + + unconditional_keys = [key for key in self.unconditional_keys if key in keys] + sorted_conditional_keys = [key for key in self.conditional_keys if key in keys] + + for i, key in enumerate(sorted_conditional_keys): + if not self._check_conditions_resolved(key, unconditional_keys + sorted_conditional_keys[:i]): + raise IllegalConditionsException( + "The current set of priors contains unresolvable conditions." + ) + sorted_keys = unconditional_keys + sorted_conditional_keys + theta = [theta[key] for key in sorted_keys] if isinstance(theta, dict) else list(theta) result = dict() - for key, index in zip(subset_dict.sorted_keys_without_fixed_parameters, subset_dict._rescale_indexes): + for key, vals in zip(sorted_keys, theta): try: - result[key] = subset_dict[key].rescale(theta[index], **subset_dict.get_required_variables(key)) + result[key] = self[key].rescale(vals, **self.get_required_variables(key)) except ValueError: # Some prior classes can not handle an array of conditional parameters (e.g. alpha for PowerLaw) # If that is the case, we sample each sample individually. 
- required_variables = subset_dict.get_required_variables(key) - result[key] = np.zeros_like(theta[key]) - for i in range(len(theta[key])): + required_variables = self.get_required_variables(key) + result[key] = np.zeros_like(vals) + for i in range(len(vals)): rvars = { key: value[i] for key, value in required_variables.items() } - result[key][i] = subset_dict[key].rescale(theta[index][i], **rvars) - subset_dict[key].least_recently_sampled = result[key] + result[key][i] = self[key].rescale(vals[i], **rvars) + self[key].least_recently_sampled = result[key] samples = [] for key in keys: # turns 0d-arrays into scalars @@ -882,14 +883,6 @@ def rescale(self, keys, theta): samples.append(res) return samples - def _update_rescale_keys(self, keys): - if not keys == self._least_recently_rescaled_keys: - self._rescale_indexes = [ - keys.index(element) - for element in self.sorted_keys_without_fixed_parameters - ] - self._least_recently_rescaled_keys = keys - def _prepare_evaluation(self, keys, theta): self._check_resolved() for key, value in zip(keys, theta): From 665e13692a8f87ed4536a923e056f18da2fbdda8 Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Fri, 22 Nov 2024 17:21:53 +0100 Subject: [PATCH 15/17] Handle JointPriors better in rescale, sample, (ln)prob, and _check_conditions_resolved of (Conditional)PriorDict - keep track of dependencies of JointPriors necessary for their complete evaluation and handle cases where not all necessary keys are requested. --- bilby/core/prior/dict.py | 64 +++++++++++++++++++++++++++++++++++----- 1 file changed, 56 insertions(+), 8 deletions(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index 17ebc9b4..b9a4a97e 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -420,6 +420,9 @@ def sample_subset(self, keys=iter([]), size=None): samples[key] = self[key].sample(size=size) else: logger.debug("{} not a known prior.".format(key)) + # ensure that `reset_sampled()` of all JointPrior.dist + # with missing dependencies is called + self._reset_jointprior_dists_with_missed_dependencies(keys, "reset_sampled") return samples @property @@ -430,6 +433,27 @@ def non_fixed_keys(self): keys = [k for k in keys if k not in self.constraint_keys] return keys + @property + def jointprior_dependencies(self): + keys = self.keys() + keys = [k for k in keys if isinstance(self[k], JointPrior)] + dependencies = {k: list(set(self[k].dist.names) - set([k])) for k in keys} + return dependencies + + def _reset_jointprior_dists_with_missed_dependencies(self, keys, reset_func): + keys = set(keys) + dependencies = self.jointprior_dependencies + requested_jointpriors = set(dependencies).intersection(keys) + missing_dependencies = {value for key in requested_jointpriors for value in dependencies[key]} + reset_dists = [] + for key in missing_dependencies: + dist = self[key].dist + if id(dist) in reset_dists: + pass + else: + getattr(dist, reset_func)() + reset_dists.append(id(dist)) + @property def fixed_keys(self): return [ @@ -499,13 +523,16 @@ def _estimate_normalization(self, keys, min_accept, sampling_chunk): factor = len(keep) / np.count_nonzero(keep) return factor - def prob(self, sample, **kwargs): + def prob(self, sample, normalized=True, **kwargs): """ Parameters ========== sample: dict Dictionary of the samples of which we want to have the probability of + normalized: bool + When False, disables calculation of constraint normalization factor + during prior probability computation. Default value is True.
kwargs: The keyword arguments are passed directly to `np.prod` Returns ======= float: Joint probability of all individual sample probabilities """ prob = np.prod([self[key].prob(sample[key]) for key in sample], **kwargs) - return self.check_prob(sample, prob) + # ensure that `reset_request()` of all JointPrior.dist + # with missing dependencies is called + self._reset_jointprior_dists_with_missed_dependencies(sample.keys(), reset_func="reset_request") + return self.check_prob(sample, prob, normalized) - def check_prob(self, sample, prob): - ratio = self.normalize_constraint_factor(tuple(sample.keys())) + def check_prob(self, sample, prob, normalized=True): + if normalized: + ratio = self.normalize_constraint_factor(tuple(sample.keys())) + else: + ratio = 1 if np.all(prob == 0.0): return prob * ratio else: @@ -534,18 +567,18 @@ def check_prob(self, sample, prob): constrained_prob[keep] = prob[keep] * ratio return constrained_prob - def ln_prob(self, sample, axis=None, normalized=True): + def ln_prob(self, sample, normalized=True, **kwargs): """ Parameters ========== sample: dict Dictionary of the samples of which to calculate the log probability - axis: None or int - Axis along which the summation is performed normalized: bool When False, disables calculation of constraint normalization factor during prior probability computation. Default value is True. + kwargs: + The keyword arguments are passed directly to `np.sum` Returns ======= float: Joint log probability of all the individual sample probabilities """ - ln_prob = np.sum([self[key].ln_prob(sample[key]) for key in sample], axis=axis) + ln_prob = np.sum([self[key].ln_prob(sample[key]) for key in sample], **kwargs) + + # ensure that `reset_request()` of all JointPrior.dist + # with missing dependencies is called + self._reset_jointprior_dists_with_missed_dependencies(sample.keys(), "reset_request") return self.check_ln_prob(sample, ln_prob, normalized=normalized) @@ -617,6 +654,7 @@ def rescale(self, keys, theta): for i, samps in enumerate(samples): # turns 0d-arrays into scalars samples[i] = np.squeeze(samps).tolist() + self._reset_jointprior_dists_with_missed_dependencies(keys, "reset_rescale") return samples def test_redundancy(self, key, disable_logging=False): @@ -715,6 +753,12 @@ def _check_conditions_resolved(self, key, sampled_keys): for k in self[key].required_variables: if k not in sampled_keys: conditions_resolved = False + break + elif isinstance(self[k], JointPrior): + dependencies = self.jointprior_dependencies[k] + if len(set(dependencies) - set(sampled_keys)) > 0: + conditions_resolved = False + break return conditions_resolved def sample_subset(self, keys=iter([]), size=None): @@ -756,6 +800,7 @@ def sample_subset(self, keys=iter([]), size=None): samples[key][i] = self[key].sample(**rvars) else: logger.debug("{} not a known prior.".format(key)) + self._reset_jointprior_dists_with_missed_dependencies(keys, "reset_sampled") return samples def get_required_variables(self, key): @@ -796,6 +841,7 @@ def prob(self, sample, **kwargs): for key in sample ] prob = np.prod(res, **kwargs) + self._reset_jointprior_dists_with_missed_dependencies(sample.keys(), "reset_request") return self.check_prob(sample, prob) def ln_prob(self, sample, axis=None, normalized=True): @@ -822,6 +868,7 @@ def ln_prob(self, sample, axis=None, normalized=True): for key in sample ] ln_prob = np.sum(res, axis=axis) + self._reset_jointprior_dists_with_missed_dependencies(sample.keys(), "reset_request") return
self.check_ln_prob(sample, ln_prob, normalized=normalized) @@ -881,6 +928,7 @@ def rescale(self, keys, theta): # turns 0d-arrays into scalars res = np.squeeze(result[key]).tolist() samples.append(res) + self._reset_jointprior_dists_with_missed_dependencies(keys, "reset_rescale") return samples def _prepare_evaluation(self, keys, theta): From a15c8ed3c7f725776a44a4ca92b37e9145595b73 Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Fri, 22 Nov 2024 17:52:57 +0100 Subject: [PATCH 16/17] Added test cases for new behavior --- test/core/prior/dict_test.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/test/core/prior/dict_test.py b/test/core/prior/dict_test.py index d6e6239f..edc2fb20 100644 --- a/test/core/prior/dict_test.py +++ b/test/core/prior/dict_test.py @@ -287,6 +287,11 @@ def test_sample_subset_with_actual_subset(self): expected = dict(length=np.array([42.0, 42.0, 42.0])) self.assertTrue(np.array_equal(expected["length"], samples["length"])) + joint_prior = self.joint_prior_from_file + samples = joint_prior.sample_subset(keys=["testAbase"], size=size) + self.assertTrue(joint_prior["testAbase"].dist.sampled_parameters == []) + self.assertTrue(joint_prior["testBbase"].dist.sampled_parameters == []) + def test_sample_subset_constrained_as_array(self): size = 3 keys = ["mass", "speed"] @@ -320,6 +325,15 @@ def test_ln_prob(self): ) + self.second_prior.ln_prob(samples["speed"]) self.assertEqual(expected, self.prior_set_from_dict.ln_prob(samples)) + def test_ln_prob_actual_subset(self): + joint_prior = self.joint_prior_from_file + keys = ["testAbase"] + samples = joint_prior.sample_subset(keys=keys, size=1) + lnprob = joint_prior.ln_prob(samples) + self.assertTrue(joint_prior["testAbase"].dist.requested_parameters["testAbase"] is None) + self.assertTrue(joint_prior["testBbase"].dist.requested_parameters["testBbase"] is None) + self.assertTrue(lnprob == 0) + def test_rescale(self): theta = [0.5, 0.5, 0.5] expected = [ @@ -336,6 +350,16 @@ def test_rescale(self): ), ) + def test_rescale_actual_subset(self): + theta = [0.5] + keys = ["testAbase"] + joint_prior = self.joint_prior_from_file + samples = joint_prior.rescale(keys=keys, theta=theta) + print(joint_prior["testAbase"].dist._rescale_parameters) + self.assertTrue(joint_prior["testAbase"].dist._rescale_parameters["testAbase"] is None) + self.assertTrue(joint_prior["testBbase"].dist._rescale_parameters["testBbase"] is None) + self.assertTrue(np.all(np.isnan(samples))) + def test_cdf(self): """ Test that the CDF method is the inverse of the rescale method. From 2f71ffe2b2c9e43df2a61cbb2d32dffd7da18c7e Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Mon, 25 Nov 2024 10:25:17 +0100 Subject: [PATCH 17/17] Fix to mode setter --- bilby/core/prior/joint.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py index 1bcdcc28..edb0c77b 100644 --- a/bilby/core/prior/joint.py +++ b/bilby/core/prior/joint.py @@ -755,7 +755,8 @@ def mode(self): @mode.setter def mode(self, mode): - if not np.isdtype(np.asarray(mode).dtype, "integral"): + mode = np.asarray(mode) + if not mode.dtype == int: raise ValueError("The mode to set must have integral data type.") if np.any(mode >= self.nmodes) or np.any(mode < 0): raise ValueError("The value of mode cannot be higher than the number of modes or smaller than zero.")
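A minimal usage sketch of the behaviour this series enables, distilled from the added test cases; the prior names, bounds, weights, and unit-cube values below are illustrative assumptions rather than values taken from the patches:

import numpy as np
import bilby

# two-mode multivariate Gaussian over two correlated parameters
mvg_dual_mode = bilby.core.prior.MultivariateGaussianDist(
    names=["mvgvar_a", "mvgvar_b"],
    nmodes=2,
    mus=[[1.0, 1.0], [2.0, 2.0]],
    covs=[[[0.03, 0.0], [0.0, 0.04]], [[0.03, 0.0], [0.0, 0.04]]],
    weights=[1, 2],
)

# var_0 selects the mode through the new settable "mode" property of MultivariateGaussianDist
def condition_func(reference_params, var_0):
    return dict(mode=np.searchsorted(np.cumsum(np.array([1, 2]) / 3), var_0))

priors = bilby.core.prior.ConditionalPriorDict(
    dict(
        var_0=bilby.core.prior.Uniform(minimum=0, maximum=1, name="var_0"),
        mvgvar_a=bilby.core.prior.ConditionalJointPrior(
            condition_func, dist=mvg_dual_mode, name="mvgvar_a",
            minimum=-5, maximum=5, mode=None),
        mvgvar_b=bilby.core.prior.ConditionalJointPrior(
            condition_func, dist=mvg_dual_mode, name="mvgvar_b",
            minimum=-5, maximum=5, mode=None),
    )
)

# vectorised sampling: every value in the returned dict is an array of length 1000
samples = priors.sample(1000)

# unit-cube rescaling returns one float per requested key, with the mode of the
# joint prior taken from the already-rescaled var_0
rescaled = priors.rescale(
    keys=["var_0", "mvgvar_a", "mvgvar_b"],
    theta={"var_0": 0.8, "mvgvar_a": 0.4, "mvgvar_b": 0.1},
)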