Skip to content

Commit

Permalink
qgs, updated 4DBp, updated 4DVar (#43)
Browse files Browse the repository at this point in the history
* qgs basic functionality

* Run all qgs calculations in numpy

* Remove extra methods from qgs.py

* Add qgs to testing workflow

* Set model params based on arg if provided

* QGS tests

* Fix jnp concatenate for sigma_bias

* Updates to 4dbackprop from experiments

* var4d bp with single grad calc

* 4dBP with jacobian computed wrt obs. Works, but very slow

* 4dvar backprop, integrated inner/outer loop

* 4DVar added to qgs branch, updates to 4dvar bp

* Can specify per-variable SD and Bias in observer

* Add callback raise value error for when loss value becomes nan. Exits scan early so that it doesn't hang

* Remove 4dbp jax config update (wasn't working for detecting NA loss values)

* Add check for loss growth, raise error if over 1000 times initial loss (can be tweaked)

* Lower loss growth limit, prevents hanging lax.scan more aggressively

* Use scane for outerloops over 4dvar

* Remove old outerloop code (w/o scan)

* Add pyqg-jax to data generators

* Attempt to implement parameterization, but probably not worth it

* Working IC creation and model run generation with pyqg_jax

* Decrease default num epochs to 3

* Adding hessian to backprop 4dvar - improves accuracy but computation takes forever

* Var4D BP with approximate hessian - fast and works well

* Remove jax hessian computation code

* Instead of using time sel matrix, loop over obs window indices. Same speed, same results, simpler to set up experiments

* Replace epochs with iters

* Adapt etkf to include equal to start of window using isclose

* Include losses in var4d

* Remove loss calculation from 4dvar (not using anyways)

* Exact hessian version of 4DVar-Backprop, as separate module for now

* Reformatting _utils

* Use np.rng for initial conditions and jax for times for pyqg_jax

* Add optax to environment

* Add chex version to environment.yml

* Skip aws tests

* Remove aws from observer test
  • Loading branch information
kysolvik authored Jun 4, 2024
1 parent 7e40a7d commit 0e3cbc5
Show file tree
Hide file tree
Showing 13 changed files with 1,157 additions and 96 deletions.
2 changes: 2 additions & 0 deletions dabench/dacycler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from ._var3d import Var3D
from ._etkf import ETKF
from ._var4d_backprop import Var4DBackprop
from ._var4d import Var4D

__all__ = [
'DACycler',
'Var3D',
'ETKF',
'Var4DBackprop',
'Var4D',
]
14 changes: 9 additions & 5 deletions dabench/dacycler/_etkf.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,12 +259,16 @@ def cycle(self,
self.analysis_window = analysis_window
all_times = (jnp.repeat(start_time, timesteps)
+ (jnp.arange(0, timesteps)*self.delta_t))
# Get the obs vectors for each analysis window
all_filtered_idx = jnp.stack([jnp.where(
(obs_vector.times
>= jnp.round(cur_time - self.analysis_window/2, 3))
* (obs_vector.times
< jnp.round(cur_time + self.analysis_window/2, 3)))[0]
for cur_time in all_times])
# Greater than start of window
(obs_vector.times > cur_time - analysis_window/2)
# AND Less than end of window
* (obs_vector.times < cur_time + analysis_window/2)
# OR Equal to start of window
+ jnp.isclose(obs_vector.times, cur_time - analysis_window/2,
rtol=0)
)[0] for cur_time in all_times])
cur_state, all_values = jax.lax.scan(
self._cycle_and_forecast,
(input_state.values, obs_vector.values, obs_vector.times,
Expand Down
351 changes: 351 additions & 0 deletions dabench/dacycler/_var4d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,351 @@
"""Class for Var 4D Data Assimilation Cycler object"""

import inspect

import numpy as np
import jax.numpy as jnp
import jax.scipy as jscipy
from jax import grad
import jax
from jax.scipy.sparse.linalg import bicgstab
from copy import deepcopy
from functools import partial

from dabench import dacycler, vector


class Var4D(dacycler.DACycler):
"""Class for building 4D DA Cycler
Attributes:
system_dim (int): System dimension.
delta_t (float): The timestep of the model (assumed uniform)
model_obj (dabench.Model): Forecast model object.
in_4d (bool): True for 4D data assimilation techniques (e.g. 4DVar).
Always True for Var4D.
ensemble (bool): True for ensemble-based data assimilation techniques
(ETKF). Always False for Var4D.
B (ndarray): Initial / static background error covariance. Shape:
(system_dim, system_dim). If not provided, will be calculated
automatically.
R (ndarray): Observation error covariance matrix. Shape
(obs_dim, obs_dim). If not provided, will be calculated
automatically.
H (ndarray): Observation operator with shape: (obs_dim, system_dim).
If not provided will be calculated automatically.
h (function): Optional observation operator as function. More flexible
(allows for more complex observation operator). Default is None.
solver (str): Name of solver to use. Default is 'bicgstab'.
n_outer_loops (int): Number of times to run through outer loop over
4DVar. Increasing this may result in higher accuracy but slower
performance. Default is 1.
steps_per_window (int): Number of timesteps per analysis window.
obs_window_indices (list): Timestep indices where observations fall
within each analysis window. For example, if analysis window is
0 - 0.05 with delta_t = 0.01 and observations fall at 0, 0.01,
0.02, 0.03, 0.04, and 0.05, obs_window_indices =
[0, 1, 2, 3, 4, 5].
"""

def __init__(self,
system_dim=None,
delta_t=None,
model_obj=None,
B=None,
R=None,
H=None,
h=None,
solver='bicgstab',
n_outer_loops=1,
steps_per_window=1,
obs_window_indices=[0],
**kwargs
):

self.steps_per_window = steps_per_window
self.obs_window_indices = obs_window_indices
self.n_outer_loops = n_outer_loops
self.solver = solver

super().__init__(system_dim=system_dim,
delta_t=delta_t,
model_obj=model_obj,
in_4d=True,
ensemble=False,
B=B, R=R, H=H, h=h)

def _calc_default_H(self, obs_values, obs_loc_indices):
H = jnp.zeros((obs_values[0].shape[0], self.system_dim))
H = H.at[jnp.arange(H.shape[0]), obs_loc_indices[0]
].set(1)
return H

def _calc_default_R(self, obs_values, obs_error_sd):
return jnp.identity(obs_values.flatten().shape[0])*(obs_error_sd**2)

def _calc_default_B(self):
return jnp.identity(self.system_dim)

def _make_outerloop_4d(self, xb0, H, B, Rinv,
obs_values, obs_window_indices,
n_steps):

def _outerloop_4d(x0, _):
# Get TLM and current forecast trajectory
# Based on current best guess for x0
M, x = self.model_obj.compute_tlm(
n_steps=n_steps,
state_vec=vector.StateVector(values=x0,
store_as_jax=True)
)

# 4D-Var inner loop
x0 = self._innerloop_4d(self.system_dim, x, xb0,
obs_values, H, B,
Rinv, M, obs_window_indices)

return x0, x0

return _outerloop_4d

def _cycle_obsop(self, xb0, obs_values, obs_loc_indices,
obs_error_sd, H=None, h=None, R=None, B=None,
obs_window_indices=None, n_steps=1):
if H is None and h is None:
if self.H is None:
if self.h is None:
H = self._calc_default_H(obs_values, obs_loc_indices)
else:
h = self.h
else:
H = self.H
if R is None:
if self.R is None:
R = self._calc_default_R(obs_values, obs_error_sd)
else:
R = self.R
if B is None:
if self.B is None:
B = self._calc_default_B()
else:
B = self.B

# Static Variables
Rinv = jscipy.linalg.inv(R)

# Best guess for x0 starts as background
x0 = deepcopy(xb0)

outerloop_4d_func = self._make_outerloop_4d(
xb0, H, B, Rinv, obs_values, obs_window_indices, n_steps)

x0, all_x0s = jax.lax.scan(outerloop_4d_func, init=x0,
xs=None, length=self.n_outer_loops)

# forecast
x = self.step_forecast(
n_steps=n_steps,
x0=vector.StateVector(values=x0, store_as_jax=True)
).values

return x, None

def step_cycle(self, x0, yo, obs_window_indices=[0],
H=None, h=None, R=None, B=None,
n_steps=1):
"""Perform one step of DA Cycle"""
if H is not None or h is None:
return self._cycle_obsop(
x0.values, yo.values, yo.location_indices, yo.error_sd,
H, R, B, obs_window_indices=obs_window_indices,
n_steps=n_steps)
else:
return self._cycle_obsop(
x0.values, yo.values, yo.location_indices, yo.error_sd, h,
R, B, obs_window_indices=obs_window_indices,
n_steps=n_steps)

def step_forecast(self, x0, n_steps=1):
"""Perform forecast using model object"""
if 'n_steps' in inspect.getfullargspec(self.model_obj.forecast).args:
return self.model_obj.forecast(x0, n_steps=n_steps)
else:
if n_steps == 1:
return self.model_obj.forecast(x0)
else:
out = [x0]
xi = x0
for s in range(n_steps):
xi = self.model.forecast(xi)
out.append(xi)
return vector.StateVector(jnp.vstack(xi), store_as_jax=True)

@partial(jax.jit, static_argnums=[0,1])
def _innerloop_4d(self, system_dim, x, xb0, y, H, B, Rinv, M,
obs_window_indices=[0]):
"""4DVar innerloop
Args:
system_dim (int): The dimension of the system state.
x (ndarray): Current best guess for trajectory. Updated each outer
loop. (time_dim, system_dim)
xb0 (ndarray): Initial background estimate for initial conditions.
Stays constant throughout analysis cycle. Shape: (system_dim,)
y (ndarray): Time array of observation. Shape: (num_obs, obs_dim)
H (ndarray): Observation operator matrix. Shape:
(obs_dim, system_dim)
B (ndarray): Background/forecast error covariance matrix. Shape:
(system_dim, system_dim)
Rinv (ndarray): Inverted observation error covariance matrix. Shape:
(obs_dim, obs_dim)]
M (ndarray): List of TLMs for each model timestep. Shape:
(time_dim,system_dim, system_dim)
obs_window_indices (ndarray): Indices of observations w.r.t. model
timesteps in analysis window.
Returns:
xa0 (ndarray): inner loop estimate of optimal initial conditions.
Shape: (system_dim,)
"""
x0_last = x[0]

# Set up Variables
SumMtHtRinvD = jnp.zeros((system_dim, 1)) # b input
SumMtHtRinvHM = jnp.zeros_like(B) # A input

# Loop over observations
for i, j in enumerate(obs_window_indices):
# The Jb Term (A)
HM = H @ M[j, :, :]
MtHtRinv = HM.T @ Rinv
SumMtHtRinvHM += MtHtRinv @ HM

# The Jo Term (b)
D = y[i] - (H @ x[j])
SumMtHtRinvD += MtHtRinv @ D[:, None]

# Compute initial departure
db0 = xb0 - x0_last

# Solve Ax=b for the initial perturbation
dx0 = self._solve(db0, SumMtHtRinvHM, SumMtHtRinvD, B)

# New x0 guess is the last guess plus the analyzed delta
x0_new = x0_last + dx0.ravel()

return x0_new

@partial(jax.jit, static_argnums=0)
def _solve(self, db0, SumMtHtRinvHM, SumMtHtRinvD, B):
"""Solve the 4D-Var linear optimization
Notes:
Solves Ax=b for x when:
A = B^{-1} + SumMtHtRinvHM
b = SumMtHtRinvD + db0[:,None]
"""

# Initialize b array
system_dim = B.shape[1]

# Set identity matrix
I_mat = jnp.identity(B.shape[0])

# Solve 4D-Var cost function
if self.solver == 'bicgstab':
# Compute A,b inputs to linear minimizer
b1 = B @ SumMtHtRinvD + db0[:, None]

A = I_mat + B @ SumMtHtRinvHM

dx0, _ = bicgstab(A, b1, x0=jnp.zeros((system_dim, 1)), tol=1e-05)

else:
raise ValueError("Solver not recognized. Options: 'bicgstab'")

return dx0

@partial(jax.jit, static_argnums=0)
def _cycle_and_forecast(self, cur_state_vals, filtered_idx):
obs_vals = self._obs_vector.values
obs_loc_indices = self._obs_vector.location_indices
obs_error_sd = self._obs_error_sd

cur_obs_vals = jax.lax.dynamic_slice_in_dim(obs_vals, filtered_idx[0],
len(filtered_idx))
cur_obs_loc_indices = jax.lax.dynamic_slice_in_dim(obs_loc_indices,
filtered_idx[0],
len(filtered_idx))
analysis, kh = self.step_cycle(
vector.StateVector(values=cur_state_vals, store_as_jax=True),
vector.ObsVector(values=cur_obs_vals,
location_indices=cur_obs_loc_indices,
error_sd=obs_error_sd,
store_as_jax=True),
n_steps=self.steps_per_window,
obs_window_indices=self.obs_window_indices)


return analysis[-1], analysis[:-1]

def cycle(self,
input_state,
start_time,
obs_vector,
obs_error_sd,
timesteps,
analysis_window,
analysis_time_in_window=None):
"""Perform DA cycle repeatedly, including analysis and forecast
Args:
input_state (vector.StateVector): Input state.
obs_vector (vector.ObsVector): Observations vector.
obs_error_sd (float): Standard deviation of observation error.
Typically not known, so provide a best-guess.
start_time (float or datetime-like): Starting time.
timesteps (int): Number of timesteps, in model time.
analysis_window (float): Time window from which to gather
observations for DA Cycle.
analysis_time_in_window (float): Where within analysis_window
to perform analysis. For example, 0.0 is the start of the
window. Default is None, which selects the middle of the
window.
Returns:
vector.StateVector of analyses and times.
"""

if analysis_time_in_window is None:
analysis_time_in_window = analysis_window/2

# Set up for jax.lax.scan, which is very fast
all_times = (
jnp.repeat(start_time + analysis_time_in_window, timesteps)
+ jnp.arange(0, timesteps*analysis_window,
analysis_window)
)
# Get the obs vectors for each analysis window
all_filtered_idx = jnp.stack([jnp.where(
# Greater than start of window
(obs_vector.times > cur_time - analysis_window/2)
# AND Less than end of window
* (obs_vector.times < cur_time + analysis_window/2)
# OR Equal to start of window end
+ jnp.isclose(obs_vector.times, cur_time - analysis_window/2,
rtol=0)
# OR Equal to end of window
+ jnp.isclose(obs_vector.times, cur_time + analysis_window/2,
rtol=0)
)[0] for cur_time in all_times])

self._obs_vector = obs_vector
self._obs_error_sd = obs_error_sd
cur_state, all_values = jax.lax.scan(
self._cycle_and_forecast,
init=input_state.values,
xs=all_filtered_idx)

return vector.StateVector(values=jnp.vstack(all_values),
store_as_jax=True)
Loading

0 comments on commit 0e3cbc5

Please sign in to comment.