Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: Add support for conditional JointPriors #864

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
2d259ca
Changed rescale-method of JointPrior to always return correct-size ar…
JasperMartins Nov 20, 2024
5ce0bbe
Small fix to rescale of JointPrior
JasperMartins Nov 22, 2024
4e1a3d2
For jointprior rescale, only cast to list once it's safe to lose muta…
JasperMartins Nov 25, 2024
35ac877
Fix to BaseJointPriorDist bound check
JasperMartins Nov 22, 2024
09bc43b
Allow setting of "dist" attributes through JointPrior
JasperMartins Nov 22, 2024
7c74683
Make ConditionalPrior ready for JointPrior
JasperMartins Nov 22, 2024
7bf3399
Make "mode" of MultivariateGaussianDist a settable property to use wit…
JasperMartins Nov 22, 2024
9eb491f
Added TestCase for mode-setting of MultivariateGaussian with Conditio…
JasperMartins Nov 22, 2024
d0d597d
Avoid recreation of ConditionalPriorDict if not necessary
JasperMartins Nov 22, 2024
4e647e9
Ensure rescaling step works for the chosen set of keys and that the c…
JasperMartins Nov 22, 2024
80c2b56
Updated test case for conditional MultivariateGaussian to be more com…
JasperMartins Nov 22, 2024
7b65f88
Added and updated Test cases
JasperMartins Nov 22, 2024
2863dce
Fixed ConditionalPrior.__repr__ after changes to ConditionalPrior.
JasperMartins Nov 22, 2024
d3f535b
Improve ConditionalPriorDict.rescale and ConditionalPriorDict.sample:…
JasperMartins Nov 22, 2024
665e136
Handle JointPriors better in rescale, sample, (ln)prob, and _check_c…
JasperMartins Nov 22, 2024
a15c8ed
Added test cases for new behavior
JasperMartins Nov 22, 2024
2f71ffe
Fix to mode setter
JasperMartins Nov 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 24 additions & 14 deletions bilby/core/prior/conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
from .analytical import DeltaFunction, PowerLaw, Uniform, LogUniform, \
SymmetricLogUniform, Cosine, Sine, Gaussian, TruncatedGaussian, HalfGaussian, \
LogNormal, Exponential, StudentT, Beta, Logistic, Cauchy, Gamma, ChiSquared, FermiDirac
from ..utils import infer_args_from_method, infer_parameters_from_function
from .joint import JointPrior
from ..utils import infer_args_from_method, infer_parameters_from_function, get_dict_with_properties


def conditional_prior_factory(prior_class):
class ConditionalPrior(prior_class):
def __init__(self, condition_func, name=None, latex_label=None, unit=None,
boundary=None, **reference_params):
def __init__(self, condition_func, name=None, latex_label=None, unit=None, boundary=None, dist=None,
**reference_params):
"""

Parameters
Expand Down Expand Up @@ -41,23 +42,26 @@ def condition_func(reference_params, y):
See superclass
boundary: str, optional
See superclass
dist: BaseJointPriorDist, optional
See superclass
reference_params:
Initial values for attributes such as `minimum`, `maximum`.
This differs on the `prior_class`, for example for the Gaussian
prior this is `mu` and `sigma`.
"""
if 'boundary' in infer_args_from_method(super(ConditionalPrior, self).__init__):
super(ConditionalPrior, self).__init__(name=name, latex_label=latex_label,
unit=unit, boundary=boundary, **reference_params)
else:
super(ConditionalPrior, self).__init__(name=name, latex_label=latex_label,
unit=unit, **reference_params)
kwargs = {"name": name, "latex_label": latex_label, "unit": unit, "boundary": boundary, "dist": dist}
needed_kwargs = infer_args_from_method(super(ConditionalPrior, self).__init__)
for kw in kwargs.copy():
if kw not in needed_kwargs:
kwargs.pop(kw)

super(ConditionalPrior, self).__init__(**kwargs, **reference_params)

self._required_variables = None
self.condition_func = condition_func
self._reference_params = reference_params
self.__class__.__name__ = 'Conditional{}'.format(prior_class.__name__)
self.__class__.__qualname__ = 'Conditional{}'.format(prior_class.__qualname__)
self.__class__.__name__ = "Conditional{}".format(prior_class.__name__)
self.__class__.__qualname__ = "Conditional{}".format(prior_class.__qualname__)

def sample(self, size=None, **required_variables):
"""Draw a sample from the prior
Expand Down Expand Up @@ -202,7 +206,9 @@ def required_variables(self):
return self._required_variables

def get_instantiation_dict(self):
instantiation_dict = super(ConditionalPrior, self).get_instantiation_dict()
superclass_args = infer_args_from_method(super(ConditionalPrior, self).__init__)
dict_with_properties = get_dict_with_properties(self)
instantiation_dict = {key: dict_with_properties[key] for key in superclass_args}
for key, value in self.reference_params.items():
instantiation_dict[key] = value
return instantiation_dict
Expand All @@ -228,8 +234,8 @@ def __repr__(self):
prior_name = self.__class__.__name__
instantiation_dict = self.get_instantiation_dict()
instantiation_dict["condition_func"] = ".".join([
instantiation_dict["condition_func"].__module__,
instantiation_dict["condition_func"].__name__
self.condition_func.__module__,
self.condition_func.__name__
])
args = ', '.join(['{}={}'.format(key, repr(instantiation_dict[key]))
for key in instantiation_dict])
Expand Down Expand Up @@ -322,6 +328,10 @@ class ConditionalInterped(conditional_prior_factory(Interped)):
pass


class ConditionalJointPrior(conditional_prior_factory(JointPrior)):
    """Conditional variant of :class:`JointPrior`.

    Produced by ``conditional_prior_factory``: a ``JointPrior`` whose
    reference parameters (including its ``dist``-related attributes) are
    updated by a user-supplied ``condition_func`` before each
    sample/prob/rescale call.
    """
    pass


class DirichletElement(ConditionalBeta):
r"""
Single element in a dirichlet distribution
Expand Down
171 changes: 119 additions & 52 deletions bilby/core/prior/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,9 @@ def sample_subset(self, keys=iter([]), size=None):
samples[key] = self[key].sample(size=size)
else:
logger.debug("{} not a known prior.".format(key))
# ensure that `reset_sampled()` of all JointPrior.dist
# with missing dependencies is called
self._reset_jointprior_dists_with_missed_dependencies(keys, "reset_sampled")
return samples

@property
Expand All @@ -430,6 +433,27 @@ def non_fixed_keys(self):
keys = [k for k in keys if k not in self.constraint_keys]
return keys

@property
def jointprior_dependencies(self):
    """Map each ``JointPrior`` key to the other keys sharing its dist.

    Returns
    =======
    dict:
        For every key in this dict whose prior is a ``JointPrior``, the
        list of the remaining parameter names of that prior's joint
        distribution, i.e. ``dist.names`` with the key itself removed.
    """
    joint_keys = [k for k in self.keys() if isinstance(self[k], JointPrior)]
    # {k} instead of set([k]): same semantics, idiomatic set literal
    return {k: list(set(self[k].dist.names) - {k}) for k in joint_keys}

def _reset_jointprior_dists_with_missed_dependencies(self, keys, reset_func):
    """Reset shared joint-prior dists whose dependencies were not all requested.

    When only a subset of a ``JointPrior``'s parameters takes part in a
    sample/prob/rescale call, the shared ``dist`` object is left in a
    partially-filled state; this calls the appropriate reset method on each
    such dist exactly once.

    Parameters
    ==========
    keys: iterable
        The parameter keys that took part in the current operation.
    reset_func: str
        Name of the reset method to call on each affected dist, e.g.
        ``"reset_sampled"``, ``"reset_request"`` or ``"reset_rescale"``.
    """
    keys = set(keys)
    dependencies = self.jointprior_dependencies
    # BUG FIX: the original called ``set(dependencies).intersection()`` with
    # no argument, which returns a copy of the whole set, so *every*
    # JointPrior was treated as requested. Intersect with ``keys`` so only
    # the JointPriors actually involved in this call are considered.
    requested_jointpriors = set(dependencies).intersection(keys)
    # Dependencies of the requested JointPriors that were not requested
    # themselves: their shared dists are left partially filled and need a
    # reset. NOTE(review): subtracting ``keys`` assumes a fully-requested
    # dist needs no reset — confirm against BaseJointPriorDist behavior.
    missing_dependencies = {
        dep for key in requested_jointpriors for dep in dependencies[key]
    } - keys
    reset_dist_ids = set()  # several keys may share one dist; reset it once
    for key in missing_dependencies:
        dist = self[key].dist
        if id(dist) not in reset_dist_ids:
            getattr(dist, reset_func)()
            reset_dist_ids.add(id(dist))

@property
def fixed_keys(self):
return [
Expand Down Expand Up @@ -499,13 +523,16 @@ def _estimate_normalization(self, keys, min_accept, sampling_chunk):
factor = len(keep) / np.count_nonzero(keep)
return factor

def prob(self, sample, **kwargs):
def prob(self, sample, normalized=True, **kwargs):
"""

Parameters
==========
sample: dict
Dictionary of the samples of which we want to have the probability of
normalized: bool
When False, disables calculation of constraint normalization factor
during prior probability computation. Default value is True.
kwargs:
The keyword arguments are passed directly to `np.prod`

Expand All @@ -516,10 +543,16 @@ def prob(self, sample, **kwargs):
"""
prob = np.prod([self[key].prob(sample[key]) for key in sample], **kwargs)

return self.check_prob(sample, prob)
# ensure that `reset_request()` of all JointPrior.dist
# with missing dependencies is called
self._reset_jointprior_dists_with_missed_dependencies(sample.keys(), reset_func="reset_request")
return self.check_prob(sample, prob, normalized)

def check_prob(self, sample, prob):
ratio = self.normalize_constraint_factor(tuple(sample.keys()))
def check_prob(self, sample, prob, normalized=True):
if normalized:
ratio = self.normalize_constraint_factor(tuple(sample.keys()))
else:
ratio = 1
if np.all(prob == 0.0):
return prob * ratio
else:
Expand All @@ -534,26 +567,30 @@ def check_prob(self, sample, prob):
constrained_prob[keep] = prob[keep] * ratio
return constrained_prob

def ln_prob(self, sample, axis=None, normalized=True):
def ln_prob(self, sample, normalized=True, **kwargs):
"""

Parameters
==========
sample: dict
Dictionary of the samples of which to calculate the log probability
axis: None or int
Axis along which the summation is performed
normalized: bool
When False, disables calculation of constraint normalization factor
during prior probability computation. Default value is True.
kwargs:
The keyword arguments are passed directly to `np.prod`

Returns
=======
float or ndarray:
Joint log probability of all the individual sample probabilities

"""
ln_prob = np.sum([self[key].ln_prob(sample[key]) for key in sample], axis=axis)
ln_prob = np.sum([self[key].ln_prob(sample[key]) for key in sample], **kwargs)

# ensure that `reset_request()` of all JointPrior.dist
# with missing dependencies is called
self._reset_jointprior_dists_with_missed_dependencies(sample.keys(), "reset_request")
return self.check_ln_prob(sample, ln_prob,
normalized=normalized)

Expand Down Expand Up @@ -600,18 +637,24 @@ def rescale(self, keys, theta):
==========
keys: list
List of prior keys to be rescaled
theta: list
List of randomly drawn values on a unit cube associated with the prior keys
theta: dict or array-like
Randomly drawn values on a unit cube associated with the prior keys

Returns
=======
list: List of floats containing the rescaled sample
list:
If theta is 1D, returns list of floats containing the rescaled sample.
If theta is 2D, returns list of lists containing the rescaled samples.
"""
theta = list(theta)
theta = [theta[key] for key in keys] if isinstance(theta, dict) else list(theta)
samples = []
for key, units in zip(keys, theta):
samps = self[key].rescale(units)
samples += list(np.asarray(samps).flatten())
samples.append(samps)
for i, samps in enumerate(samples):
# turns 0d-arrays into scalars
samples[i] = np.squeeze(samps).tolist()
self._reset_jointprior_dists_with_missed_dependencies(keys, "reset_rescale")
return samples

def test_redundancy(self, key, disable_logging=False):
Expand Down Expand Up @@ -667,8 +710,6 @@ def __init__(self, dictionary=None, filename=None, conversion_function=None):
self._conditional_keys = []
self._unconditional_keys = []
self._rescale_keys = []
self._rescale_indexes = []
self._least_recently_rescaled_keys = []
super(ConditionalPriorDict, self).__init__(
dictionary=dictionary,
filename=filename,
Expand Down Expand Up @@ -712,42 +753,54 @@ def _check_conditions_resolved(self, key, sampled_keys):
for k in self[key].required_variables:
if k not in sampled_keys:
conditions_resolved = False
break
elif isinstance(self[k], JointPrior):
dependencies = self.jointprior_dependencies[k]
if len(set(dependencies) - set(sampled_keys)) > 0:
conditions_resolved = False
break
return conditions_resolved

def sample_subset(self, keys=iter([]), size=None):
keys = list(keys)
self.convert_floats_to_delta_functions()
add_delta_keys = [
key
for key in self.keys()
if key not in keys and isinstance(self[key], DeltaFunction)
]
use_keys = add_delta_keys + list(keys)
subset_dict = ConditionalPriorDict({key: self[key] for key in use_keys})
if not subset_dict._resolved:
raise IllegalConditionsException(
"The current set of priors contains unresolvable conditions."
)
add_delta_keys = []
for key in self.keys():
if key not in keys and isinstance(self[key], DeltaFunction):
add_delta_keys.append(key)

use_keys = add_delta_keys + keys
unconditional_use_keys = [key for key in self.unconditional_keys if key in use_keys]
sorted_conditional_use_keys = [key for key in self.conditional_keys if key in use_keys]

for i, key in enumerate(sorted_conditional_use_keys):
if not self._check_conditions_resolved(key, unconditional_use_keys + sorted_conditional_use_keys[:i]):
raise IllegalConditionsException(
"The current set of priors contains unresolvable conditions."
)
sorted_use_keys = unconditional_use_keys + sorted_conditional_use_keys
samples = dict()
for key in subset_dict.sorted_keys:
for key in sorted_use_keys:
if key not in keys or isinstance(self[key], Constraint):
continue
if isinstance(self[key], Prior):
try:
samples[key] = subset_dict[key].sample(
size=size, **subset_dict.get_required_variables(key)
samples[key] = self[key].sample(
size=size, **self.get_required_variables(key)
)
except ValueError:
# Some prior classes can not handle an array of conditional parameters (e.g. alpha for PowerLaw)
# If that is the case, we sample each sample individually.
required_variables = subset_dict.get_required_variables(key)
required_variables = self.get_required_variables(key)
samples[key] = np.zeros(size)
for i in range(size):
rvars = {
key: value[i] for key, value in required_variables.items()
}
samples[key][i] = subset_dict[key].sample(**rvars)
samples[key][i] = self[key].sample(**rvars)
else:
logger.debug("{} not a known prior.".format(key))
self._reset_jointprior_dists_with_missed_dependencies(keys, "reset_sampled")
return samples

def get_required_variables(self, key):
Expand Down Expand Up @@ -788,6 +841,7 @@ def prob(self, sample, **kwargs):
for key in sample
]
prob = np.prod(res, **kwargs)
self._reset_jointprior_dists_with_missed_dependencies(sample.keys(), "reset_request")
return self.check_prob(sample, prob)

def ln_prob(self, sample, axis=None, normalized=True):
Expand All @@ -814,6 +868,7 @@ def ln_prob(self, sample, axis=None, normalized=True):
for key in sample
]
ln_prob = np.sum(res, axis=axis)
self._reset_jointprior_dists_with_missed_dependencies(sample.keys(), "reset_request")
return self.check_ln_prob(sample, ln_prob,
normalized=normalized)

Expand All @@ -832,38 +887,50 @@ def rescale(self, keys, theta):
==========
keys: list
List of prior keys to be rescaled
theta: list
List of randomly drawn values on a unit cube associated with the prior keys
theta: dict or array-like
Randomly drawn values on a unit cube associated with the prior keys

Returns
=======
list: List of floats containing the rescaled sample
list:
If theta is float for each key, returns list of floats containing the rescaled sample.
If theta is array-like for each key, returns list of lists containing the rescaled samples.
"""
keys = list(keys)
theta = list(theta)
self._check_resolved()
self._update_rescale_keys(keys)

unconditional_keys = [key for key in self.unconditional_keys if key in keys]
sorted_conditional_keys = [key for key in self.conditional_keys if key in keys]

for i, key in enumerate(sorted_conditional_keys):
if not self._check_conditions_resolved(key, unconditional_keys + sorted_conditional_keys[:i]):
raise IllegalConditionsException(
"The current set of priors contains unresolvable conditions."
)
sorted_keys = unconditional_keys + sorted_conditional_keys
theta = [theta[key] for key in sorted_keys] if isinstance(theta, dict) else list(theta)
result = dict()
for key, index in zip(
self.sorted_keys_without_fixed_parameters, self._rescale_indexes
):
result[key] = self[key].rescale(
theta[index], **self.get_required_variables(key)
)
for key, vals in zip(sorted_keys, theta):
try:
result[key] = self[key].rescale(vals, **self.get_required_variables(key))
except ValueError:
# Some prior classes can not handle an array of conditional parameters (e.g. alpha for PowerLaw)
# If that is the case, we sample each sample individually.
required_variables = self.get_required_variables(key)
result[key] = np.zeros_like(vals)
for i in range(len(vals)):
rvars = {
key: value[i] for key, value in required_variables.items()
}
result[key][i] = self[key].rescale(vals[i], **rvars)
self[key].least_recently_sampled = result[key]
samples = []
for key in keys:
samples += list(np.asarray(result[key]).flatten())
# turns 0d-arrays into scalars
res = np.squeeze(result[key]).tolist()
samples.append(res)
self._reset_jointprior_dists_with_missed_dependencies(keys, "reset_rescale")
return samples

def _update_rescale_keys(self, keys):
if not keys == self._least_recently_rescaled_keys:
self._rescale_indexes = [
keys.index(element)
for element in self.sorted_keys_without_fixed_parameters
]
self._least_recently_rescaled_keys = keys

def _prepare_evaluation(self, keys, theta):
self._check_resolved()
for key, value in zip(keys, theta):
Expand Down
Loading
Loading