diff --git a/bilby/core/prior/conditional.py b/bilby/core/prior/conditional.py index 797cbd1c..fcad7253 100644 --- a/bilby/core/prior/conditional.py +++ b/bilby/core/prior/conditional.py @@ -3,13 +3,14 @@ from .analytical import DeltaFunction, PowerLaw, Uniform, LogUniform, \ SymmetricLogUniform, Cosine, Sine, Gaussian, TruncatedGaussian, HalfGaussian, \ LogNormal, Exponential, StudentT, Beta, Logistic, Cauchy, Gamma, ChiSquared, FermiDirac -from ..utils import infer_args_from_method, infer_parameters_from_function +from .joint import JointPrior +from ..utils import infer_args_from_method, infer_parameters_from_function, get_dict_with_properties def conditional_prior_factory(prior_class): class ConditionalPrior(prior_class): - def __init__(self, condition_func, name=None, latex_label=None, unit=None, - boundary=None, **reference_params): + def __init__(self, condition_func, name=None, latex_label=None, unit=None, boundary=None, dist=None, + **reference_params): """ Parameters @@ -41,23 +42,26 @@ def condition_func(reference_params, y): See superclass boundary: str, optional See superclass + dist: BaseJointPriorDist, optional + See superclass reference_params: Initial values for attributes such as `minimum`, `maximum`. This differs on the `prior_class`, for example for the Gaussian prior this is `mu` and `sigma`. """ - if 'boundary' in infer_args_from_method(super(ConditionalPrior, self).__init__): - super(ConditionalPrior, self).__init__(name=name, latex_label=latex_label, - unit=unit, boundary=boundary, **reference_params) - else: - super(ConditionalPrior, self).__init__(name=name, latex_label=latex_label, - unit=unit, **reference_params) + kwargs = {"name": name, "latex_label": latex_label, "unit": unit, "boundary": boundary, "dist": dist} + needed_kwargs = infer_args_from_method(super(ConditionalPrior, self).__init__) + for kw in kwargs.copy(): + if kw not in needed_kwargs: + kwargs.pop(kw) + + super(ConditionalPrior, self).__init__(**kwargs, **reference_params) self._required_variables = None self.condition_func = condition_func self._reference_params = reference_params - self.__class__.__name__ = 'Conditional{}'.format(prior_class.__name__) - self.__class__.__qualname__ = 'Conditional{}'.format(prior_class.__qualname__) + self.__class__.__name__ = "Conditional{}".format(prior_class.__name__) + self.__class__.__qualname__ = "Conditional{}".format(prior_class.__qualname__) def sample(self, size=None, **required_variables): """Draw a sample from the prior @@ -202,7 +206,9 @@ def required_variables(self): return self._required_variables def get_instantiation_dict(self): - instantiation_dict = super(ConditionalPrior, self).get_instantiation_dict() + superclass_args = infer_args_from_method(super(ConditionalPrior, self).__init__) + dict_with_properties = get_dict_with_properties(self) + instantiation_dict = {key: dict_with_properties[key] for key in superclass_args} for key, value in self.reference_params.items(): instantiation_dict[key] = value return instantiation_dict @@ -228,8 +234,8 @@ def __repr__(self): prior_name = self.__class__.__name__ instantiation_dict = self.get_instantiation_dict() instantiation_dict["condition_func"] = ".".join([ - instantiation_dict["condition_func"].__module__, - instantiation_dict["condition_func"].__name__ + self.condition_func.__module__, + self.condition_func.__name__ ]) args = ', '.join(['{}={}'.format(key, repr(instantiation_dict[key])) for key in instantiation_dict]) @@ -322,6 +328,10 @@ class ConditionalInterped(conditional_prior_factory(Interped)): pass +class ConditionalJointPrior(conditional_prior_factory(JointPrior)): + pass + + class DirichletElement(ConditionalBeta): r""" Single element in a dirichlet distribution diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index be3d543a..b9a4a97e 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -420,6 +420,9 @@ def sample_subset(self, keys=iter([]), size=None): samples[key] = self[key].sample(size=size) else: logger.debug("{} not a known prior.".format(key)) + # ensure that `reset_sampled()` of all JointPrior.dist + # with missing dependencies is called + self._reset_jointprior_dists_with_missed_dependencies(keys, "reset_sampled") return samples @property @@ -430,6 +433,27 @@ def non_fixed_keys(self): keys = [k for k in keys if k not in self.constraint_keys] return keys + @property + def jointprior_dependencies(self): + keys = self.keys() + keys = [k for k in keys if isinstance(self[k], JointPrior)] + dependencies = {k: list(set(self[k].dist.names) - set([k])) for k in keys} + return dependencies + + def _reset_jointprior_dists_with_missed_dependencies(self, keys, reset_func): + keys = set(keys) + dependencies = self.jointprior_dependencies + requested_jointpriors = set(dependencies).intersection() + missing_dependencies = {value for key in requested_jointpriors for value in dependencies[key]} + reset_dists = [] + for key in missing_dependencies: + dist = self[key].dist + if id(dist) in reset_dists: + pass + else: + getattr(dist, reset_func)() + reset_dists.append(id(dist)) + @property def fixed_keys(self): return [ @@ -499,13 +523,16 @@ def _estimate_normalization(self, keys, min_accept, sampling_chunk): factor = len(keep) / np.count_nonzero(keep) return factor - def prob(self, sample, **kwargs): + def prob(self, sample, normalized=True, **kwargs): """ Parameters ========== sample: dict Dictionary of the samples of which we want to have the probability of + normalized: bool + When False, disables calculation of constraint normalization factor + during prior probability computation. Default value is True. kwargs: The keyword arguments are passed directly to `np.prod` @@ -516,10 +543,16 @@ def prob(self, sample, **kwargs): """ prob = np.prod([self[key].prob(sample[key]) for key in sample], **kwargs) - return self.check_prob(sample, prob) + # ensure that `reset_request()` of all JointPrior.dist + # with missing dependencies is called + self._reset_jointprior_dists_with_missed_dependencies(sample.keys(), reset_func="reset_request") + return self.check_prob(sample, prob, normalized) - def check_prob(self, sample, prob): - ratio = self.normalize_constraint_factor(tuple(sample.keys())) + def check_prob(self, sample, prob, normalized=True): + if normalized: + ratio = self.normalize_constraint_factor(tuple(sample.keys())) + else: + ratio = 1 if np.all(prob == 0.0): return prob * ratio else: @@ -534,18 +567,18 @@ def check_prob(self, sample, prob): constrained_prob[keep] = prob[keep] * ratio return constrained_prob - def ln_prob(self, sample, axis=None, normalized=True): + def ln_prob(self, sample, normalized=True, **kwargs): """ Parameters ========== sample: dict Dictionary of the samples of which to calculate the log probability - axis: None or int - Axis along which the summation is performed normalized: bool When False, disables calculation of constraint normalization factor during prior probability computation. Default value is True. + kwargs: + The keyword arguments are passed directly to `np.prod` Returns ======= @@ -553,7 +586,11 @@ def ln_prob(self, sample, axis=None, normalized=True): Joint log probability of all the individual sample probabilities """ - ln_prob = np.sum([self[key].ln_prob(sample[key]) for key in sample], axis=axis) + ln_prob = np.sum([self[key].ln_prob(sample[key]) for key in sample], **kwargs) + + # ensure that `reset_request()` of all JointPrior.dist + # with missing dependencies is called + self._reset_jointprior_dists_with_missed_dependencies(sample.keys(), "reset_request") return self.check_ln_prob(sample, ln_prob, normalized=normalized) @@ -600,18 +637,24 @@ def rescale(self, keys, theta): ========== keys: list List of prior keys to be rescaled - theta: list - List of randomly drawn values on a unit cube associated with the prior keys + theta: dict or array-like + Randomly drawn values on a unit cube associated with the prior keys Returns ======= - list: List of floats containing the rescaled sample + list: + If theta is 1D, returns list of floats containing the rescaled sample. + If theta is 2D, returns list of lists containing the rescaled samples. """ - theta = list(theta) + theta = [theta[key] for key in keys] if isinstance(theta, dict) else list(theta) samples = [] for key, units in zip(keys, theta): samps = self[key].rescale(units) - samples += list(np.asarray(samps).flatten()) + samples.append(samps) + for i, samps in enumerate(samples): + # turns 0d-arrays into scalars + samples[i] = np.squeeze(samps).tolist() + self._reset_jointprior_dists_with_missed_dependencies(keys, "reset_rescale") return samples def test_redundancy(self, key, disable_logging=False): @@ -667,8 +710,6 @@ def __init__(self, dictionary=None, filename=None, conversion_function=None): self._conditional_keys = [] self._unconditional_keys = [] self._rescale_keys = [] - self._rescale_indexes = [] - self._least_recently_rescaled_keys = [] super(ConditionalPriorDict, self).__init__( dictionary=dictionary, filename=filename, @@ -712,42 +753,54 @@ def _check_conditions_resolved(self, key, sampled_keys): for k in self[key].required_variables: if k not in sampled_keys: conditions_resolved = False + break + elif isinstance(self[k], JointPrior): + dependencies = self.jointprior_dependencies[k] + if len(set(dependencies) - set(sampled_keys)) > 0: + conditions_resolved = False + break return conditions_resolved def sample_subset(self, keys=iter([]), size=None): + keys = list(keys) self.convert_floats_to_delta_functions() - add_delta_keys = [ - key - for key in self.keys() - if key not in keys and isinstance(self[key], DeltaFunction) - ] - use_keys = add_delta_keys + list(keys) - subset_dict = ConditionalPriorDict({key: self[key] for key in use_keys}) - if not subset_dict._resolved: - raise IllegalConditionsException( - "The current set of priors contains unresolvable conditions." - ) + add_delta_keys = [] + for key in self.keys(): + if key not in keys and isinstance(self[key], DeltaFunction): + add_delta_keys.append(key) + + use_keys = add_delta_keys + keys + unconditional_use_keys = [key for key in self.unconditional_keys if key in use_keys] + sorted_conditional_use_keys = [key for key in self.conditional_keys if key in use_keys] + + for i, key in enumerate(sorted_conditional_use_keys): + if not self._check_conditions_resolved(key, unconditional_use_keys + sorted_conditional_use_keys[:i]): + raise IllegalConditionsException( + "The current set of priors contains unresolvable conditions." + ) + sorted_use_keys = unconditional_use_keys + sorted_conditional_use_keys samples = dict() - for key in subset_dict.sorted_keys: + for key in sorted_use_keys: if key not in keys or isinstance(self[key], Constraint): continue if isinstance(self[key], Prior): try: - samples[key] = subset_dict[key].sample( - size=size, **subset_dict.get_required_variables(key) + samples[key] = self[key].sample( + size=size, **self.get_required_variables(key) ) except ValueError: # Some prior classes can not handle an array of conditional parameters (e.g. alpha for PowerLaw) # If that is the case, we sample each sample individually. - required_variables = subset_dict.get_required_variables(key) + required_variables = self.get_required_variables(key) samples[key] = np.zeros(size) for i in range(size): rvars = { key: value[i] for key, value in required_variables.items() } - samples[key][i] = subset_dict[key].sample(**rvars) + samples[key][i] = self[key].sample(**rvars) else: logger.debug("{} not a known prior.".format(key)) + self._reset_jointprior_dists_with_missed_dependencies(keys, "reset_sampled") return samples def get_required_variables(self, key): @@ -788,6 +841,7 @@ def prob(self, sample, **kwargs): for key in sample ] prob = np.prod(res, **kwargs) + self._reset_jointprior_dists_with_missed_dependencies(sample.keys(), "reset_request") return self.check_prob(sample, prob) def ln_prob(self, sample, axis=None, normalized=True): @@ -814,6 +868,7 @@ def ln_prob(self, sample, axis=None, normalized=True): for key in sample ] ln_prob = np.sum(res, axis=axis) + self._reset_jointprior_dists_with_missed_dependencies(sample.keys(), "reset_request") return self.check_ln_prob(sample, ln_prob, normalized=normalized) @@ -832,38 +887,50 @@ def rescale(self, keys, theta): ========== keys: list List of prior keys to be rescaled - theta: list - List of randomly drawn values on a unit cube associated with the prior keys + theta: dict or array-like + Randomly drawn values on a unit cube associated with the prior keys Returns ======= - list: List of floats containing the rescaled sample + list: + If theta is float for each key, returns list of floats containing the rescaled sample. + If theta is array-like for each key, returns list of lists containing the rescaled samples. """ keys = list(keys) - theta = list(theta) - self._check_resolved() - self._update_rescale_keys(keys) + + unconditional_keys = [key for key in self.unconditional_keys if key in keys] + sorted_conditional_keys = [key for key in self.conditional_keys if key in keys] + + for i, key in enumerate(sorted_conditional_keys): + if not self._check_conditions_resolved(key, unconditional_keys + sorted_conditional_keys[:i]): + raise IllegalConditionsException( + "The current set of priors contains unresolvable conditions." + ) + sorted_keys = unconditional_keys + sorted_conditional_keys + theta = [theta[key] for key in sorted_keys] if isinstance(theta, dict) else list(theta) result = dict() - for key, index in zip( - self.sorted_keys_without_fixed_parameters, self._rescale_indexes - ): - result[key] = self[key].rescale( - theta[index], **self.get_required_variables(key) - ) + for key, vals in zip(sorted_keys, theta): + try: + result[key] = self[key].rescale(vals, **self.get_required_variables(key)) + except ValueError: + # Some prior classes can not handle an array of conditional parameters (e.g. alpha for PowerLaw) + # If that is the case, we sample each sample individually. + required_variables = self.get_required_variables(key) + result[key] = np.zeros_like(vals) + for i in range(len(vals)): + rvars = { + key: value[i] for key, value in required_variables.items() + } + result[key][i] = self[key].rescale(vals[i], **rvars) self[key].least_recently_sampled = result[key] samples = [] for key in keys: - samples += list(np.asarray(result[key]).flatten()) + # turns 0d-arrays into scalars + res = np.squeeze(result[key]).tolist() + samples.append(res) + self._reset_jointprior_dists_with_missed_dependencies(keys, "reset_rescale") return samples - def _update_rescale_keys(self, keys): - if not keys == self._least_recently_rescaled_keys: - self._rescale_indexes = [ - keys.index(element) - for element in self.sorted_keys_without_fixed_parameters - ] - self._least_recently_rescaled_keys = keys - def _prepare_evaluation(self, keys, theta): self._check_resolved() for key, value in zip(keys, theta): diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py index 43c8913e..edb0c77b 100644 --- a/bilby/core/prior/joint.py +++ b/bilby/core/prior/joint.py @@ -63,8 +63,9 @@ def __init__(self, names, bounds=None): self.requested_parameters = dict() self.reset_request() - # a dictionary of the rescaled parameters - self.rescale_parameters = dict() + # a dictionary of the rescale(d) parameters + self._rescale_parameters = dict() + self._rescaled_parameters = dict() self.reset_rescale() # a list of sampled parameters @@ -94,7 +95,12 @@ def filled_rescale(self): Check if all the rescaled parameters have been filled. """ - return not np.any([val is None for val in self.rescale_parameters.values()]) + return not np.any([val is None for val in self._rescale_parameters.values()]) + + def set_rescale(self, key, values): + values = np.array(values) + self._rescale_parameters[key] = values + self._rescaled_parameters[key] = np.atleast_1d(np.ones_like(values)) * np.nan def reset_rescale(self): """ @@ -102,7 +108,11 @@ def reset_rescale(self): """ for name in self.names: - self.rescale_parameters[name] = None + self._rescale_parameters[name] = None + self._rescaled_parameters[name] = None + + def get_rescaled(self, key): + return self._rescaled_parameters[key] def get_instantiation_dict(self): subclass_args = infer_args_from_method(self.__init__) @@ -209,11 +219,9 @@ def _check_samp(self, value): raise ValueError("Array is the wrong shape") # check sample(s) is within bounds - outbounds = np.ones(samp.shape[0], dtype=bool) + outbounds = np.zeros(samp.shape[0], dtype=bool) for s, bound in zip(samp.T, self.bounds.values()): - outbounds = (s < bound[0]) | (s > bound[1]) - if np.any(outbounds): - break + outbounds += (s < bound[0]) | (s > bound[1]) return samp, outbounds def ln_prob(self, value): @@ -303,10 +311,11 @@ def rescale(self, value, **kwargs): Parameters ========== - value: array - A 1d vector sample (one for each parameter) drawn from a uniform + value: array or None + If given, a 1d vector sample (one for each parameter) drawn from a uniform distribution between 0 and 1, or a 2d NxM array of samples where N is the number of samples and M is the number of parameters. + If None, values previously set using BaseJointPriorDist.set_rescale() are used. kwargs: dict All keyword args that need to be passed to _rescale method, these keyword args are called in the JointPrior rescale methods for each parameter @@ -317,7 +326,11 @@ def rescale(self, value, **kwargs): An vector sample drawn from the multivariate Gaussian distribution. """ - samp = np.array(value) + if value is None: + samp = np.array(list(self._rescale_parameters.values())).T + else: + samp = np.array(value) + if len(samp.shape) == 1: samp = samp.reshape(1, self.num_vars) @@ -327,6 +340,11 @@ def rescale(self, value, **kwargs): raise ValueError("Array is the wrong shape") samp = self._rescale(samp, **kwargs) + if value is None: + for i, key in enumerate(self.names): + output = self.get_rescaled(key) + # update in-place for proper handling in PriorDict-instances + output[:] = samp[:, i] return np.squeeze(samp) def _rescale(self, samp, **kwargs): @@ -612,76 +630,83 @@ def add_mode(self, mus=None, sigmas=None, corrcoef=None, cov=None, weight=1.0): ) def _rescale(self, samp, **kwargs): - try: - mode = kwargs["mode"] - except KeyError: - mode = None + mode = kwargs.get("mode", self.mode) if mode is None: if self.nmodes == 1: mode = 0 else: - mode = np.argwhere(self.cumweights - random.rng.uniform(0, 1) > 0)[0][0] + mode = random.rng.choice( + self.nmodes, size=len(samp), p=self.weights + ) samp = erfinv(2.0 * samp - 1) * 2.0 ** 0.5 # rotate and scale to the multivariate normal shape - samp = self.mus[mode] + self.sigmas[mode] * np.einsum( - "ij,kj->ik", samp * self.sqeigvalues[mode], self.eigvectors[mode] - ) + uniques = np.unique(mode) + if len(uniques) == 1: + unique = uniques[0] + samp = self.mus[unique] + self.sigmas[unique] * np.einsum( + "ij,kj->ik", samp * self.sqeigvalues[unique], self.eigvectors[unique] + ) + else: + for m in uniques: + mask = m == mode + samp[mask] = self.mus[m] + self.sigmas[m] * np.einsum( + "ij,kj->ik", samp[mask] * self.sqeigvalues[m], self.eigvectors[m] + ) + return samp def _sample(self, size, **kwargs): - try: - mode = kwargs["mode"] - except KeyError: - mode = None + mode = kwargs.get("mode", self.mode) + + samps = np.zeros((size, len(self))) + outbound = np.ones(size, dtype=bool) if mode is None: if self.nmodes == 1: mode = 0 - else: - if size == 1: - mode = np.argwhere(self.cumweights - random.rng.uniform(0, 1) > 0)[0][0] - else: - # pick modes - mode = [ - np.argwhere(self.cumweights - r > 0)[0][0] - for r in random.rng.uniform(0, 1, size) - ] + while np.any(outbound): + # sample the multivariate Gaussian keys + vals = random.rng.uniform(0, 1, (np.sum(outbound), len(self))) - samps = np.zeros((size, len(self))) - for i in range(size): - inbound = False - while not inbound: - # sample the multivariate Gaussian keys - vals = random.rng.uniform(0, 1, len(self)) - - if isinstance(mode, list): - samp = np.atleast_1d(self.rescale(vals, mode=mode[i])) - else: - samp = np.atleast_1d(self.rescale(vals, mode=mode)) - samps[i, :] = samp + if mode is None: + mode = random.rng.choice( + self.nmodes, size=np.sum(outbound), p=self.weights + ) - # check sample is in bounds (otherwise perform another draw) - outbound = False - for name, val in zip(self.names, samp): - if val < self.bounds[name][0] or val > self.bounds[name][1]: - outbound = True - break + samps[outbound] = np.atleast_1d(self.rescale(vals, mode=mode)) - if not outbound: - inbound = True + # check sample is in bounds and redraw those which are not + samps, outbound = self._check_samp(samps) return samps - def _ln_prob(self, samp, lnprob, outbounds): - for j in range(samp.shape[0]): - # loop over the modes and sum the probabilities - for i in range(self.nmodes): - # self.mvn[i] is a "standard" multivariate normal distribution; see add_mode() - z = (samp[j] - self.mus[i]) / self.sigmas[i] - lnprob[j] = np.logaddexp(lnprob[j], self.mvn[i].logpdf(z) - self.logprodsigmas[i]) + def _ln_prob(self, samp, lnprob, outbounds, **kwargs): + mode = kwargs.get("mode", self.mode) + + if mode is None: + for j in range(samp.shape[0]): + # loop over the modes and sum the probabilities + for i in range(self.nmodes): + # self.mvn[i] is a "standard" multivariate normal distribution; see add_mode() + z = (samp[j] - self.mus[i]) / self.sigmas[i] + lnprob[j] = np.logaddexp(lnprob[j], + self.mvn[i].logpdf(z) - self.logprodsigmas[i] + np.log(self.weights[i])) + else: + uniques = np.unique(np.asarray(mode, dtype=int)) + if len(uniques) == 1: + unique = uniques[0] + z = (samp[j] - self.mus[unique]) / self.sigmas[unique] + # don't multiply by the mode weight if the mode is given (ie. prob(mode|mode) = 1) + lnprob[j] = np.logaddexp(lnprob[j], self.mvn[unique].logpdf(z) - self.logprodsigmas[unique]) + else: + for m in uniques: + mask = mode == m + z = (samp[mask] - self.mus[m]) / self.sigmas[m] + # don't multiply by the mode weight if the mode is given (ie. prob(mode|mode) = 1) + lnprob[mask] = np.logaddexp(lnprob[mask], self.mvn[m].logpdf(z) - self.logprodsigmas[m]) # set out-of-bounds values to -inf lnprob[outbounds] = -np.inf @@ -721,13 +746,29 @@ def __eq__(self, other): return False return True + @property + def mode(self): + if hasattr(self, "_mode"): + return self._mode + else: + return None + + @mode.setter + def mode(self, mode): + mode = np.asarray(mode) + if not mode.dtype == int: + raise ValueError("The mode to set must have integral data type.") + if np.any(mode >= self.nmodes) or np.any(mode < 0): + raise ValueError("The value of mode cannot be higher than the number of modes or smaller than zero.") + self._mode = mode + class MultivariateNormalDist(MultivariateGaussianDist): """A synonym for the :class:`~bilby.core.prior.MultivariateGaussianDist` distribution.""" class JointPrior(Prior): - def __init__(self, dist, name=None, latex_label=None, unit=None): + def __init__(self, dist, name=None, latex_label=None, unit=None, **kwargs): """This defines the single parameter Prior object for parameters that belong to a JointPriorDist Parameters @@ -778,6 +819,17 @@ def maximum(self, maximum): self._maximum = maximum self.dist.bounds[self.name] = (self.dist.bounds[self.name][0], maximum) + def __setattr__(self, name, value): + # first check that the JointPrior has an explicit setter method for the attribute, which should take presedence + if hasattr(self.__class__, name) and getattr(self.__class__, name).fset is not None: + return super().__setattr__(name, value) + # then check if the BaseJointPriorDist-!subclass! has an explicit setter method for the attribute + elif hasattr(self, "dist") and hasattr(self.dist, name) and getattr(self.dist.__class__, name).fset is not None: + return self.dist.__setattr__(name, value) + # if not, use the default settattr + else: + return super().__setattr__(name, value) + def rescale(self, val, **kwargs): """ Scale a unit hypercube sample to the prior. @@ -790,19 +842,23 @@ def rescale(self, val, **kwargs): all kwargs passed to the dist.rescale method Returns ======= - float: - A sample from the prior parameter. + np.ndarray: + The samples from the prior parameter. If not all names in "dist" have been filled, + the array contains only np.nan. *This* specific array instance will be filled with + the rescaled value once all parameters have been requested """ - self.dist.rescale_parameters[self.name] = val + self.dist.set_rescale(self.name, val) if self.dist.filled_rescale(): - values = np.array(list(self.dist.rescale_parameters.values())).T - samples = self.dist.rescale(values, **kwargs) + self.dist.rescale(value=None, **kwargs) + output = self.dist.get_rescaled(self.name) self.dist.reset_rescale() - return samples else: - return [] # return empty list + output = self.dist.get_rescaled(self.name) + + # have to return raw output to conserve in-place modifications + return output def sample(self, size=1, **kwargs): """ diff --git a/test/core/prior/conditional_test.py b/test/core/prior/conditional_test.py index 20c0cda9..23f4f97d 100644 --- a/test/core/prior/conditional_test.py +++ b/test/core/prior/conditional_test.py @@ -338,30 +338,93 @@ def test_rescale_with_joint_prior(self): cov = [[[0.03, 0.], [0., 0.04]]] mvg = bilby.core.prior.MultivariateGaussianDist(names, mus=mu, covs=cov) + names_2 = ["mvgvar_a", "mvgvar_b"] + mvg_dual_mode = bilby.core.prior.MultivariateGaussianDist( + names=names_2, + nmodes=2, + mus=[mu[0], (np.array(mu[0]) + np.ones_like(mu[0])).tolist()], + covs=[cov[0], cov[0]], + weights=[1, 2] + ) + + def condition_func_2(reference_params, var_0): + return dict(mode=np.searchsorted(np.cumsum(np.array([1, 2]) / 3), var_0)) + + def condition_func_1(reference_params, var_0, var_1): + return {"minimum": var_0 - 1, "maximum": var_1 + 1} + + def condition_func_5(reference_parameters, mvgvar_a): + return dict(minimum=reference_parameters["minimum"], maximum=mvgvar_a) + + prior_5 = bilby.core.prior.ConditionalUniform( + condition_func=condition_func_5, minimum=self.minimum, maximum=self.maximum + ) + priordict = bilby.core.prior.ConditionalPriorDict( dict( + var_5=prior_5, + mvgvar_a=bilby.core.prior.ConditionalJointPrior( + condition_func_2, dist=mvg_dual_mode, name="mvgvar_a", + minimum=self.minimum, maximum=self.maximum, mode=None), var_3=self.prior_3, var_2=self.prior_2, var_0=self.prior_0, var_1=self.prior_1, - mvgvar_0=bilby.core.prior.MultivariateGaussian(mvg, "mvgvar_0"), - mvgvar_1=bilby.core.prior.MultivariateGaussian(mvg, "mvgvar_1"), + mvgvar_0=bilby.core.prior.ConditionalJointPrior( + condition_func_1, dist=mvg, name="mvgvar_0", minimum=self.minimum, maximum=self.maximum), + mvgvar_1=bilby.core.prior.ConditionalJointPrior( + condition_func_1, dist=mvg, name="mvgvar_1", minimum=self.minimum, maximum=self.maximum), + mvgvar_b=bilby.core.prior.ConditionalJointPrior( + condition_func_2, dist=mvg_dual_mode, name="mvgvar_b", + minimum=self.minimum, maximum=self.maximum, mode=None), ) ) - ref_variables = list(self.test_sample.values()) + [0.4, 0.1] - keys = list(self.test_sample.keys()) + names + ref_variables = self.test_sample.copy() + ref_variables.update({"mvgvar_0": 0.4, "mvgvar_1": 0.1, "mvgvar_a": 0.5, "mvgvar_b": 0.2, "var_5": 0.5}) + keys = list(self.test_sample.keys()) + names + names_2 + ["var_5"] res = priordict.rescale(keys=keys, theta=ref_variables) self.assertIsInstance(res, list) - self.assertEqual(np.shape(res), (6,)) - self.assertListEqual([isinstance(r, float) for r in res], 6 * [True]) + self.assertEqual(np.shape(res), (9,)) + self.assertListEqual([isinstance(r, float) for r in res], 9 * [True]) # check conditional values are still as expected expected = [self.test_sample["var_0"]] for ii in range(1, 4): expected.append(expected[-1] * self.test_sample[f"var_{ii}"]) self.assertListEqual(expected, res[0:4]) + res_sample = priordict.sample(1) + self.assertEqual(list(res_sample.keys()), priordict.sorted_keys_without_fixed_parameters) + res_sample = priordict.sample(1000) + self.assertListEqual([len(val) for val in res_sample.values()], [1000] * len(res_sample.keys())) + lnprobs = priordict.ln_prob(priordict.sample(10), axis=0) + self.assertEqual(len(lnprobs), 10) + + with self.assertRaises(bilby.core.prior.IllegalConditionsException): + keys = set(priordict.keys()) - set(["mvgvar_a"]) + priordict.rescale(keys=keys, theta=ref_variables) + + def condition_func_6(reference_params, var_5): + return dict(mode=np.searchsorted(np.cumsum(np.array([1, 2]) / 3), var_5)) + + priordict_unresolveable = bilby.core.prior.ConditionalPriorDict( + dict( + var_5=prior_5, + var_3=self.prior_3, + var_2=self.prior_2, + var_0=self.prior_0, + var_1=self.prior_1, + mvgvar_a=bilby.core.prior.ConditionalJointPrior( + condition_func_6, dist=mvg_dual_mode, name="mvgvar_a", + minimum=self.minimum, maximum=self.maximum, mode=None), + mvgvar_b=bilby.core.prior.ConditionalJointPrior( + condition_func_6, dist=mvg_dual_mode, name="mvgvar_b", + minimum=self.minimum, maximum=self.maximum, mode=None), + + ) + ) + self.assertEqual(priordict_unresolveable._resolved, False) def test_cdf(self): """ @@ -378,10 +441,11 @@ def test_cdf(self): ) def test_rescale_illegal_conditions(self): - del self.conditional_priors["var_0"] + test_sample = self.test_sample.copy() + test_sample.pop("var_0") with self.assertRaises(bilby.core.prior.IllegalConditionsException): self.conditional_priors.rescale( - keys=list(self.test_sample.keys()), + keys=list(test_sample.keys()), theta=list(self.test_sample.values()), ) diff --git a/test/core/prior/dict_test.py b/test/core/prior/dict_test.py index d6e6239f..edc2fb20 100644 --- a/test/core/prior/dict_test.py +++ b/test/core/prior/dict_test.py @@ -287,6 +287,11 @@ def test_sample_subset_with_actual_subset(self): expected = dict(length=np.array([42.0, 42.0, 42.0])) self.assertTrue(np.array_equal(expected["length"], samples["length"])) + joint_prior = self.joint_prior_from_file + samples = joint_prior.sample_subset(keys=["testAbase"], size=size) + self.assertTrue(joint_prior["testAbase"].dist.sampled_parameters == []) + self.assertTrue(joint_prior["testBbase"].dist.sampled_parameters == []) + def test_sample_subset_constrained_as_array(self): size = 3 keys = ["mass", "speed"] @@ -320,6 +325,15 @@ def test_ln_prob(self): ) + self.second_prior.ln_prob(samples["speed"]) self.assertEqual(expected, self.prior_set_from_dict.ln_prob(samples)) + def test_ln_prob_actual_subset(self): + joint_prior = self.joint_prior_from_file + keys = ["testAbase"] + samples = joint_prior.sample_subset(keys=keys, size=1) + lnprob = joint_prior.ln_prob(samples) + self.assertTrue(joint_prior["testAbase"].dist.requested_parameters["testAbase"] is None) + self.assertTrue(joint_prior["testBbase"].dist.requested_parameters["testBbase"] is None) + self.assertTrue(lnprob == 0) + def test_rescale(self): theta = [0.5, 0.5, 0.5] expected = [ @@ -336,6 +350,16 @@ def test_rescale(self): ), ) + def test_rescale_actual_subset(self): + theta = [0.5] + keys = ["testAbase"] + joint_prior = self.joint_prior_from_file + samples = joint_prior.rescale(keys=keys, theta=theta) + print(joint_prior["testAbase"].dist._rescale_parameters) + self.assertTrue(joint_prior["testAbase"].dist._rescale_parameters["testAbase"] is None) + self.assertTrue(joint_prior["testBbase"].dist._rescale_parameters["testBbase"] is None) + self.assertTrue(np.all(np.isnan(samples))) + def test_cdf(self): """ Test that the CDF method is the inverse of the rescale method.