distribution.py

import jax
import numpy as np
from jax import numpy as jnp
from jax.scipy.special import logsumexp

import geomstats.backend as gs
from distrax import MultivariateNormalDiag
from geomstats.geometry.special_orthogonal import (
    _SpecialOrthogonalMatrices,
    _SpecialOrthogonal3Vectors,
)


class UniformDistribution:
    """Uniform density on a compact manifold."""

    def __init__(self, manifold):
        self.manifold = manifold

    def sample(self, rng, shape):
        return self.manifold.random_uniform(state=rng, n_samples=shape[0])

    def log_prob(self, z):
        # Constant density: -log(manifold volume) for every sample.
        return -np.ones([z.shape[0]]) * self.manifold.log_volume
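
# A minimal usage sketch (an assumption, not part of this module: it relies on the
# JAX-enabled geomstats backend used here, where `random_uniform` accepts a PRNG key
# via `state` and manifolds expose `log_volume`; Hypersphere is only an illustrative choice):
#
#   from geomstats.geometry.hypersphere import Hypersphere
#
#   sphere = Hypersphere(dim=2)
#   uniform = UniformDistribution(sphere)
#   rng = jax.random.PRNGKey(0)
#   x = uniform.sample(rng, (16,))   # 16 points drawn uniformly on the 2-sphere
#   logp = uniform.log_prob(x)       # constant -log(volume) for every point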


class Wrapped:
    """Wrapped normal density on a Riemannian manifold."""

    def __init__(self, scale, batch_dims, manifold, mean_type, seed=0, **kwargs):
        self.batch_dims = batch_dims
        self.manifold = manifold
        rng = jax.random.PRNGKey(seed)
        rng, next_rng = jax.random.split(rng)
        self.rng = rng
        if mean_type == 'random':
            self.mean = manifold.random_uniform(state=next_rng, n_samples=1)
        elif mean_type == 'hyperbolic':
            self.mean = jnp.expand_dims(self.manifold.identity, axis=0)
        elif mean_type == 'mixture':
            self.mean = kwargs['mean']
        else:
            raise NotImplementedError(f'mean_type: {mean_type} not implemented.')
        self.scale = jnp.ones(self.mean.shape) * scale if isinstance(scale, float) \
            else jnp.array(scale)

    def __iter__(self):
        return self

    def __next__(self):
        return self.sample(self.rng, self.batch_dims)

    def sample(self, rng, n_samples):
        if not isinstance(n_samples, int):
            n_samples = n_samples[0]
        mean = self.mean
        scale = self.scale
        # Draw a Gaussian tangent vector at the mean and push it through the exponential map.
        tangent_vec = self.manifold.random_normal_tangent(rng, mean, n_samples)[1]
        tangent_vec = scale * tangent_vec
        samples = self.manifold.exp(tangent_vec, mean)
        return samples

    # Used for SO(3) and hyperbolic manifolds.
    def log_prob(self, samples):
        tangent_vec = self.manifold.metric.log(samples, self.mean)
        tangent_vec = self.manifold.metric.transpback0(self.mean, tangent_vec)
        zero = jnp.zeros(self.manifold.dim)
        # TODO: refactor axis concatenation / removal
        if self.scale.shape[-1] == self.manifold.dim:  # Poincare ball
            scale = self.scale
        else:  # hyperboloid
            scale = self.scale[..., 1:]
        norm_pdf = MultivariateNormalDiag(zero, scale).log_prob(tangent_vec)
        logdetexp = self.manifold.metric.logdetexp(self.mean, samples)
        return norm_pdf - logdetexp
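
# Comment-level sketch of what `Wrapped` computes: sampling draws a tangent vector with
# independent Gaussian coordinates of standard deviation `scale` at the mean and maps it
# through the exponential map, x = exp_mean(v). The log-density then follows the
# change-of-variables rule
#
#   log p(x) = log N(v; 0, diag(scale)^2) - log |det d exp_mean|,
#
# where v is Log_mean(x) transported back to the reference point, i.e. `norm_pdf - logdetexp`
# above. Illustrative usage (hypothetical values; `manifold` stands for any manifold object
# from this codebase exposing `random_normal_tangent`, `exp`, and `metric.log`):
#
#   dist = Wrapped(scale=0.5, batch_dims=(64,), manifold=manifold, mean_type='random')
#   x = next(iter(dist))        # a batch of 64 wrapped-normal samples
#   logp = dist.log_prob(x)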


class WrappedMixture:
    """Mixture of wrapped normal densities on a Riemannian manifold."""

    def __init__(self, scale, batch_dims, manifold, mean_type, seed, rng=None, **kwargs):
        self.batch_dims = batch_dims
        self.manifold = manifold
        rng = jax.random.PRNGKey(seed)
        rng, next_rng = jax.random.split(rng)
        self.rng = rng
        if mean_type == 'random':
            self.mean = manifold.random_uniform(state=next_rng, n_samples=4)
        elif mean_type == 'so3':
            assert isinstance(manifold, _SpecialOrthogonalMatrices)
            means = []
            self.centers = [[0.0, 0.0, 0.0], [0.0, 0.0, np.pi], [np.pi, 0.0, np.pi]]
            for v in self.centers:
                s = _SpecialOrthogonal3Vectors().matrix_from_tait_bryan_angles(np.array(v))
                means.append(s)
            self.mean = jnp.stack(means)
        elif mean_type == 'poincare_disk':
            self.mean = jnp.array([[-0.8, 0.0], [0.8, 0.0], [0.0, -0.8], [0.0, 0.8]])
        elif mean_type == 'hyperboloid4':
            mean = jnp.array([[-0.4, 0.0], [0.4, 0.0], [0.0, -0.4], [0.0, 0.4]])
            self.mean = self.manifold._ball_to_extrinsic_coordinates(mean)
        elif mean_type == 'hyperboloid6':
            # Vertices of a regular hexagon in the Poincare ball, scaled towards the origin.
            hexagon = [[0., 2.], [np.sqrt(3), 1.], [np.sqrt(3), -1.], [0., -2.],
                       [-np.sqrt(3), -1.], [-np.sqrt(3), 1.]]
            mean = jnp.array(hexagon) * 0.3
            self.mean = self.manifold._ball_to_extrinsic_coordinates(mean)
        elif mean_type == 'test':
            self.mean = kwargs['mean']
        else:
            raise NotImplementedError(f'mean_type: {mean_type} not implemented.')
        self.scale = jnp.ones(self.mean.shape) * scale if isinstance(scale, float) \
            else jnp.array(scale)

    def __iter__(self):
        return self

    def __next__(self):
        return self.sample(self.rng, self.batch_dims)

    def sample(self, rng, n_samples):
        if not isinstance(n_samples, int):
            n_samples = n_samples[0]
        # Pick a mixture component uniformly at random for each sample, then sample
        # from the corresponding wrapped normal.
        ks = jnp.arange(self.mean.shape[0])
        self.rng, next_rng = jax.random.split(self.rng)
        _, k = gs.random.choice(state=next_rng, a=ks, n=n_samples)
        mean = self.mean[k]
        scale = self.scale[k]
        tangent_vec = self.manifold.random_normal_tangent(next_rng, mean, n_samples)[1]
        tangent_vec = tangent_vec * scale
        samples = self.manifold.exp(tangent_vec, mean)
        return samples

    def log_prob(self, samples):
        def component_log_prob(mean, scale):
            dist = Wrapped(scale, self.batch_dims, self.manifold, 'mixture', mean=mean)
            return dist.log_prob(samples)

        # Evaluate every component density, then combine with uniform mixture weights.
        component_log_like = jax.vmap(component_log_prob)(self.mean, self.scale)
        b = 1 / self.mean.shape[0] * jnp.ones_like(component_log_like)
        return logsumexp(component_log_like, axis=0, b=b)
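
# The mixture log-density combines the K component densities with uniform weights 1/K:
#
#   log p(x) = logsumexp_k( log p_k(x) + log(1/K) )
#
# which is what `logsumexp(..., b=1/K)` computes above. A minimal usage sketch
# (hypothetical values; `manifold` stands for any manifold object from this codebase):
#
#   mixture = WrappedMixture(scale=0.3, batch_dims=(128,), manifold=manifold,
#                            mean_type='random', seed=0)
#   x = next(iter(mixture))      # a batch of 128 samples from the 4-component mixture
#   logp = mixture.log_prob(x)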