-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsolver.py
161 lines (128 loc) · 4.83 KB
/
solver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import abc
import jax
from jax import numpy as jnp
from util.registry import register_category
# Name -> class registries for the samplers in this module.
# ``register_category`` (imported from util.registry) is unpacked into a
# lookup function and a registering decorator: ``register_predictor`` /
# ``register_corrector`` are applied as class decorators, and
# ``get_predictor`` / ``get_corrector`` are called with a string name inside
# ``get_pc_sampler`` to obtain the class to instantiate.
# NOTE(review): lookup is presumably keyed by class ``__name__`` (the defaults
# in get_pc_sampler are class names) — confirm in util/registry.
get_predictor, register_predictor = register_category("predictors")
get_corrector, register_corrector = register_category("correctors")
class Predictor(abc.ABC):
    """The abstract class for a predictor algorithm.

    Args:
        sde: the SDE being simulated. Its ``manifold`` attribute is cached on
            the instance as ``self.manifold``.
    """

    def __init__(self, sde):
        super().__init__()
        self.sde = sde
        self.manifold = sde.manifold

    @abc.abstractmethod
    def update_fn(self, rng, x, t, dt):
        """One update of the predictor.

        The signature takes the PRNG key first, matching how concrete
        predictors implement it and how the PC sampler calls it.

        Args:
            rng: JAX PRNG key to use for this step.
            x: current state.
            t: current time.
            dt: step size.

        Returns:
            ``(x, x_mean)``: the new state and its mean. The concrete
            predictors in this module return the new state for both.
        """
        raise NotImplementedError()
class Corrector(abc.ABC):
    """The abstract class for a corrector algorithm.

    Args:
        sde: the SDE being simulated.
        snr: signal-to-noise ratio parameter, stored as ``self.snr``.
        n_steps: number of corrector steps, stored as ``self.n_steps``.
    """

    def __init__(self, sde, snr, n_steps):
        super().__init__()
        self.sde = sde
        self.snr = snr
        self.n_steps = n_steps

    @abc.abstractmethod
    def update_fn(self, rng, x0, x, t):
        """One update of the corrector.

        The signature takes the PRNG key first, matching how concrete
        correctors implement it and how the PC sampler calls it.

        Args:
            rng: JAX PRNG key to use for this step.
            x0: the state the sampling run started from.
            x: current state.
            t: current time.

        Returns:
            ``(x, x_mean)``: the corrected state and its mean.
        """
        raise NotImplementedError()
@register_predictor
# NOTE: Geodesic Random Walk
class EulerMaruyamaPredictor(Predictor):
    """Euler-Maruyama predictor on a manifold (geodesic random walk).

    Each step draws a tangent Gaussian sample at ``x``, scales it by
    ``diffusion * sqrt(|dt|)``, adds ``drift * dt``, and maps the resulting
    tangent vector back onto the manifold with ``manifold.exp``.
    The constructor is inherited unchanged from ``Predictor``.
    """

    def update_fn(self, rng, x, t, dt):
        """Advance ``x`` by one step of size ``dt`` starting at time ``t``.

        Returns the new state twice, as ``(x, x_mean)``.
        """
        batch = x.shape[0]
        noise = self.sde.manifold.random_normal_tangent(
            state=rng, base_point=x, n_samples=batch
        )[1]
        noise = noise.reshape(batch, -1)

        drift, diffusion = self.sde.coefficients(x, t)

        step_scale = jnp.sqrt(jnp.abs(dt))
        stochastic = jnp.einsum(
            "...,...i,...->...i", diffusion, noise, step_scale
        )
        deterministic = drift * dt[..., None]

        x_next = self.manifold.exp(
            tangent_vec=stochastic + deterministic, base_point=x
        )
        return x_next, x_next
@register_corrector
class NoneCorrector(Corrector):
    """An empty corrector that does nothing."""

    def __init__(self, sde, snr, n_steps):
        # Delegate to the base initializer so this no-op corrector exposes the
        # same ``sde`` / ``snr`` / ``n_steps`` attributes as every other
        # Corrector, instead of leaving them undefined.
        super().__init__(sde, snr, n_steps)

    def update_fn(self, rng, x0, x, t):
        """Return ``x`` unchanged, as both the new state and its mean."""
        return x, x
def get_pc_sampler(
    sde,
    N,
    predictor="EulerMaruyamaPredictor",
    corrector="NoneCorrector",
    snr=0.0, n_steps=1, eps=1.0e-3
):
    """Create a Predictor-Corrector (PC) sampler.

    Args:
        sde: the SDE to simulate. It must have a truthy ``approx`` attribute
            and expose ``t0`` / ``tf``. It is passed to the predictor and
            corrector constructors.
        N: number of sampler steps.
        predictor: registry name of the predictor class.
        corrector: registry name of the corrector class.
        snr: forwarded to the corrector constructor.
        n_steps: forwarded to the corrector constructor.
        eps: the time grid stops at ``tf - eps`` rather than ``tf``.

    Returns:
        ``pc_sampler(rng, x)``. It runs ``N`` corrector-then-predictor steps
        starting from ``x`` and returns the final ``x_mean``.
    """
    # NOTE(review): ``assert`` is stripped under ``python -O``; consider an
    # explicit exception if this check must always run.
    assert sde.approx
    predictor = get_predictor(predictor)(sde)
    corrector = get_corrector(corrector)(sde, snr, n_steps)

    def pc_sampler(rng, x):
        t0 = jnp.broadcast_to(sde.t0, x.shape[0])
        tf = jnp.broadcast_to(sde.tf, x.shape[0])
        # N grid points from t0 to tf - eps (row i holds the time of step i
        # for every sample). The step size spans the full interval (tf - t0).
        timesteps = jnp.linspace(start=t0, stop=tf - eps, num=N, endpoint=True)
        dt = (tf - t0) / N
        # Initial state, handed to the corrector on every step.
        x0 = x

        def loop_body(i, val):
            rng, x, x_mean = val
            t = timesteps[i]
            rng, step_rng = jax.random.split(rng)
            x, x_mean = corrector.update_fn(step_rng, x0, x, t)
            rng, step_rng = jax.random.split(rng)
            x, x_mean = predictor.update_fn(step_rng, x, t, dt)
            return rng, x, x_mean

        # The per-step history buffer was never returned, so it is not carried:
        # that avoids allocating and updating N copies of ``x``.
        _, x, x_mean = jax.lax.fori_loop(0, N, loop_body, (rng, x, x))
        return x_mean

    return pc_sampler
class EulerMaruyamaTwoWayPredictor:
    """Euler-Maruyama step mixing a forward and a backward bridge per sample.

    Samples whose ``mask`` entry is True are advanced with the bridge built by
    ``mix.bridge(xf)``. The others are advanced with the bridge built by
    ``mix.rev().bridge(x0)``. Steps are mapped onto the manifold with
    ``manifold.exp``.

    Args:
        mix: process object exposing ``manifold``, ``bridge(...)`` and
            ``rev()``.
        x0: start points; the target passed to the backward bridge.
        xf: end points; the target passed to the forward bridge.
        mask: boolean selector, True where the forward bridge is used.
            Presumably one entry per batch row — confirm against caller.
    """

    def __init__(self, mix, x0, xf, mask):
        self.mix = mix
        self.x0 = x0
        self.xf = xf
        self.mask = mask
        self.manifold = mix.manifold
        self.fsde = mix.bridge(xf)
        self.bsde = mix.rev().bridge(x0)

    def update_fn(self, rng, x, t, dt):
        """Advance ``x`` by one step of size ``dt`` from time ``t``.

        Returns the new state twice, as ``(x, x_mean)``.
        """
        z = self.manifold.random_normal_tangent(
            state=rng, base_point=x, n_samples=x.shape[0]
        )[1].reshape(x.shape[0], -1)
        fdrift, fdiff = self.fsde.coefficients(x, t)
        bdrift, bdiff = self.bsde.coefficients(x, t)

        # Select per sample rather than blending with ``a*mask + b*~mask``.
        # The blend turns an inf/NaN in the unselected branch into NaN
        # (0 * inf == NaN), and breaks for a Python-bool mask (~True == -2).
        mask = jnp.asarray(self.mask, dtype=bool)
        drift = jnp.where(mask[..., None], fdrift, bdrift)
        diffusion = jnp.where(mask, fdiff, bdiff)

        # sqrt(|dt|), not |sqrt(dt)|: the latter is NaN for negative dt.
        tangent_vector = jnp.einsum(
            "...,...i,...->...i",
            diffusion, z, jnp.sqrt(jnp.abs(dt))
        )
        tangent_vector = tangent_vector + jnp.einsum("...i,...->...i", drift, dt)
        x = self.manifold.exp(tangent_vec=tangent_vector, base_point=x)
        return x, x
def get_twoway_sampler(mix, N=10, eps=1.0e-3):
    """Build a sampler that simulates a bridge from the nearer endpoint.

    For each sample the returned ``sampler(rng, x0, xf, t)`` does one of two
    things:

    * ``t < 0.5``: start at ``x0`` and integrate the bridge from
      ``mix.bridge(xf)`` over ``[mix.t0, t]``.
    * otherwise: start at ``xf`` and integrate the bridge from
      ``mix.rev().bridge(x0)`` over ``[mix.t0, 1 - t]``.

    Each integration uses ``N`` Euler-Maruyama steps and returns the final
    ``x_mean``. ``eps`` is accepted but not used.
    """

    def sampler(rng, x0, xf, t):
        forward = t < 0.5
        predictor = EulerMaruyamaTwoWayPredictor(mix, x0, xf, forward)
        backward = ~forward

        start = (
            jnp.einsum("...i,...->...i", x0, forward)
            + jnp.einsum("...i,...->...i", xf, backward)
        )
        horizon = t * forward + (1. - t) * backward
        grid = jnp.linspace(start=mix.t0, stop=horizon, num=N, endpoint=True)
        step = (horizon - mix.t0) / N

        def body(k, carry):
            key, state, _ = carry
            key, sub_key = jax.random.split(key)
            state, state_mean = predictor.update_fn(sub_key, state, grid[k], step)
            return key, state, state_mean

        _, _, result = jax.lax.fori_loop(0, N, body, (rng, start, start))
        return result

    return sampler