Dev/xarray base #56

Merged on Oct 15, 2024 (65 commits)
Changes from 40 commits
Commits
5ecbef2
Data().generate() now returns an Xarray 'state vector'
kysolvik Sep 18, 2024
f5f1bf4
Remove vector class import for now, replacing vector objects with sim…
kysolvik Sep 18, 2024
8f60436
Update sqgturb to work with xarray. Returns in real space, not spectral.
kysolvik Sep 18, 2024
9c33a8c
Updated observer with xarray object, but only basic functionality is …
kysolvik Sep 20, 2024
8264485
3DVar working with xarray:
kysolvik Sep 20, 2024
b728551
ETKF working with xarray, initial commit
kysolvik Sep 23, 2024
4ccbeba
Sort observed locations
kysolvik Sep 23, 2024
e41e60d
Data times are stored as base numpy arrays, since xarray coords canno…
kysolvik Sep 23, 2024
4e8b970
Reinserting netcdf utils into data class
kysolvik Sep 23, 2024
99b1b3c
gcp with xarray
kysolvik Sep 23, 2024
9eddcf1
ETKF can handle irregular obs now, but doesn't have proper indices info
kysolvik Sep 23, 2024
bffcfbc
Use system_index variable for H
kysolvik Sep 24, 2024
0cbd940
Observer adds integer indices for easier H calculation
kysolvik Sep 24, 2024
e023257
Allow datavars to not match observed vars
kysolvik Sep 25, 2024
6e9de1a
Updated system index and remove sort from the observer
kysolvik Sep 25, 2024
2f87b45
Making some dacycler methods part of parent class for ease of mainten…
kysolvik Sep 26, 2024
ccf7e03
Initial version of xarray var4dBP, needs testing and cleaning
kysolvik Sep 26, 2024
3c09ef3
Cleaned var4dbp using xarray
kysolvik Sep 26, 2024
cbc308f
All dacyclers working with xarray, but possibly some accuracy issues …
kysolvik Sep 27, 2024
1690992
State Vec has delta_t attribute and M is provided as xarray
kysolvik Sep 27, 2024
8b81d55
Working RC Model with xarray
kysolvik Sep 27, 2024
e3cd320
Fixed generator extra step rounding error
kysolvik Sep 27, 2024
19572e3
Observer can accept random_time_density now
kysolvik Sep 27, 2024
e5bfe0e
Add permissible xarray jax to pyproject toml
kysolvik Sep 27, 2024
789019a
Remove all vector module imports
kysolvik Sep 27, 2024
f0f982d
Rename i to index for toy data generators
kysolvik Sep 27, 2024
e8bd897
GCP system dim now specified
kysolvik Sep 27, 2024
17cab8f
Observer can accept list of error_sds, and now samples with replaceme…
kysolvik Sep 27, 2024
43c9d2b
Fixed issue with missing time offset for 4dvar and 4dvarBP
kysolvik Sep 27, 2024
5a916f5
Reassign coords to match input state within dacycler
kysolvik Sep 30, 2024
6a2e44e
Apply coord reassign for 4dvar cycler too
kysolvik Sep 30, 2024
2758b82
Remove unnecessary print from 3dvar
kysolvik Sep 30, 2024
a515768
Reassign coords for outer loop carry instead of dropping time
kysolvik Sep 30, 2024
96721df
Updated gcp to properly assign system_dim
kysolvik Sep 30, 2024
2c7686d
Assign system_dim as attr, not coord
kysolvik Sep 30, 2024
974613c
Fix typo for xarray_jax git repo
kysolvik Sep 30, 2024
3fdfabd
XArray accessors for helper methods
kysolvik Oct 2, 2024
25bb39b
Update xarray accessor methods: ds.dab.flatten() and da.dab.unflatten()
kysolvik Oct 3, 2024
0d99873
NeuralGCM model with configuration YAML
kysolvik Oct 3, 2024
2623ce7
Neuralgcm forecast returns last step and full forecast tuple
kysolvik Oct 3, 2024
f185b63
Remove old import_xarray
kysolvik Oct 4, 2024
04cb49c
Add date and variable filtering to load netcdf
kysolvik Oct 4, 2024
5f506ab
Add option for data split in fraction
kysolvik Oct 4, 2024
143d922
Updated load netcdf with filtering and data base test
kysolvik Oct 4, 2024
45d0988
Updated LE calcs for xarray
kysolvik Oct 4, 2024
ebe0154
QGS with xarray output
kysolvik Oct 4, 2024
0a36581
All off-line data tests working with xarray
kysolvik Oct 4, 2024
d7db346
Partially working DA tests (3dVar working)
kysolvik Oct 4, 2024
5151378
Store error sd in obs vec
kysolvik Oct 4, 2024
8cb35fc
ETKF and var4dBP tests passing
kysolvik Oct 6, 2024
4385a5d
Updated vals for 4dvar testing (based on previous stable version, were …
kysolvik Oct 7, 2024
4fd0f74
Updated presaved 4dvar test vals with consistent model delta_t (previ…
kysolvik Oct 7, 2024
06cab3d
4dvar method need different calc_default_R method (single observation…
kysolvik Oct 7, 2024
7bfc9e0
Test without specifying R
kysolvik Oct 7, 2024
d055622
Updated enso_indices (and tests) for xarray
kysolvik Oct 7, 2024
0068c7f
Removing vector class tests
kysolvik Oct 7, 2024
34cd96c
Updated gcp tests
kysolvik Oct 7, 2024
9317353
Smaller GCP tests to speed up downloads
kysolvik Oct 7, 2024
5ac3737
Smaller GCP tests, all passing
kysolvik Oct 7, 2024
107744b
Updated model base test for xarray, passing
kysolvik Oct 7, 2024
502d549
Observer tests updated and sqgturb system dim set in real space
kysolvik Oct 8, 2024
d7f2196
Skipping obsop tests (temporary, not used in any examples)
kysolvik Oct 8, 2024
a542566
Update README with xarray, added observer example
kysolvik Oct 8, 2024
d87581c
Correct system_dim in sqgturb test
kysolvik Oct 8, 2024
f7bf221
Merge branch 'main' into dev/xarray_base
kysolvik Oct 8, 2024
2 changes: 1 addition & 1 deletion dabench/__init__.py
@@ -1 +1 @@
-from . import data, vector, model, observer, obsop, dacycler, _suppl_data
+from . import data, model, observer, obsop, dacycler, _suppl_data
230 changes: 184 additions & 46 deletions dabench/dacycler/_dacycler.py
@@ -1,8 +1,12 @@
"""Base class for Data Assimilation Cycler object (DACycler)"""

from dabench import vector
import numpy as np
import jax.numpy as jnp
import jax
import xarray as xr
import xarray_jax as xj

import dabench.dacycler._utils as dac_utils

class DACycler():
"""Base class for DACycler object
@@ -37,6 +41,7 @@ def __init__(self,
R=None,
H=None,
h=None,
analysis_time_in_window=None
):

self.h = h
@@ -48,15 +53,122 @@ def __init__(self,
self.system_dim = system_dim
self.delta_t = delta_t
self.model_obj = model_obj
self.analysis_time_in_window = analysis_time_in_window


def _calc_default_H(self, obs_values, obs_loc_indices):
H = jnp.zeros((obs_values.flatten().shape[0], self.system_dim))
H = H.at[jnp.arange(H.shape[0]),
obs_loc_indices.flatten(),
].set(1)
return H

def _calc_default_R(self, obs_values, obs_error_sd):
return jnp.identity(obs_values.flatten().shape[0])*(obs_error_sd**2)

def _calc_default_B(self):
"""If B is not provided, use an identity matrix with shape (system_dim, system_dim)."""
return jnp.identity(self.system_dim)

def _step_forecast(self, xa, n_steps=1):
"""Perform forecast using model object"""
return self.model_obj.forecast(xa, n_steps=n_steps)

def _step_cycle(self, xb, obs_vals, obs_locs, obs_time_mask, obs_loc_mask,
H=None, h=None, R=None, B=None, **kwargs):
if H is not None or h is None:
vals = self._cycle_obsop(
xb, obs_vals, obs_locs, obs_time_mask,
obs_loc_mask, H, R, B, **kwargs)
return vals
else:
raise ValueError(
'Only linear obs operators (H) are supported right now.')
vals = self._cycle_general_obsop(
xb, obs_vals, obs_locs, obs_time_mask,
obs_loc_mask, h, R, B, **kwargs)
return vals

def _cycle_and_forecast(self, cur_state, filtered_idx):
# 1. Get data
# 1-b. Calculate obs_time_mask and restore filtered_idx to original values
cur_state = cur_state.to_xarray()
cur_time = cur_state['_cur_time'].data
cur_state = cur_state.drop_vars(['_cur_time'])
obs_time_mask = filtered_idx > 0
filtered_idx = filtered_idx - 1

# 2. Calculate analysis
cur_obs_vals = jnp.array(self._obs_vector[self._observed_vars].to_array().data).at[:, filtered_idx].get()
cur_obs_loc_indices = jnp.array(self._obs_vector.system_index.data).at[:, filtered_idx].get()
cur_obs_loc_mask = jnp.array(self._obs_loc_masks).at[:, filtered_idx].get().astype(bool)
cur_obs_time_mask = jnp.repeat(obs_time_mask, cur_obs_vals.shape[-1])
analysis = self._step_cycle(
cur_state,
cur_obs_vals,
cur_obs_loc_indices,
obs_loc_mask=cur_obs_loc_mask,
obs_time_mask=cur_obs_time_mask
)
# 3. Forecast next timestep
next_state, forecast_states = self._step_forecast(analysis, n_steps=self.steps_per_window)
next_state = next_state.assign(
_cur_time = cur_time + self.analysis_window
).assign_coords(
cur_state.coords)

return xj.from_xarray(next_state), forecast_states

def _cycle_and_forecast_4d(self, cur_state, filtered_idx):
# 1. Get data
# 1-b. Calculate obs_time_mask and restore filtered_idx to original values
cur_state = cur_state.to_xarray()
cur_time = cur_state['_cur_time'].data
cur_state = cur_state.drop_vars(['_cur_time'])
obs_time_mask = filtered_idx > 0
filtered_idx = filtered_idx - 1

cur_obs_vals = jnp.array(self._obs_vector[self._observed_vars].to_stacked_array('system',['time']).data).at[filtered_idx].get()
cur_obs_times = jnp.array(self._obs_vector.time.data).at[filtered_idx].get()
cur_obs_loc_indices = jnp.array(self._obs_vector.system_index.data).at[:, filtered_idx].get().reshape(filtered_idx.shape[0], -1)
cur_obs_loc_mask = jnp.array(self._obs_loc_masks).at[:, filtered_idx].get().astype(bool).reshape(filtered_idx.shape[0], -1)

# Calculate obs window indices: closest model timesteps that match obs
obs_window_indices =jnp.array([
jnp.argmin(
jnp.abs(obs_time - (cur_time + self._model_timesteps))
) for obs_time in cur_obs_times
])

# 2. Calculate analysis
analysis = self._step_cycle(
cur_state,
cur_obs_vals,
cur_obs_loc_indices,
obs_loc_mask=cur_obs_loc_mask,
obs_time_mask=obs_time_mask,
obs_window_indices=obs_window_indices
)

# 3. Forecast forward
next_state, forecast_states = self._step_forecast(analysis, n_steps=self.steps_per_window)
next_state = next_state.assign(
_cur_time = cur_time + self.analysis_window
).assign_coords(
cur_state.coords)

return xj.from_xarray(next_state), forecast_states

def cycle(self,
input_state,
start_time,
obs_vector,
n_cycles,
-analysis_window,
+obs_error_sd=None,
+analysis_window=0.2,
analysis_time_in_window=None,
-return_forecast=False):
+return_forecast=False
+):
"""Perform DA cycle repeatedly, including analysis and forecast

Args:
@@ -79,52 +191,78 @@ def cycle(self,
vector.StateVector of analyses and times.
"""

# These could be different if observer doesn't observe all variables
# For now, making them the same
self._observed_vars = obs_vector['variable'].values
self._data_vars = list(input_state.data_vars)

if obs_error_sd is None:
obs_error_sd = obs_vector.error_sd

self.analysis_window = analysis_window

# If analysis_time_in_window is not specified, it is assumed to be the middle of the window
if analysis_time_in_window is None:
analysis_time_in_window = analysis_window/2
if self.analysis_time_in_window is None and analysis_time_in_window is None:
analysis_time_in_window = self.analysis_window/2
else:
analysis_time_in_window = self.analysis_time_in_window

# Steps per window + 1 to include start
self.steps_per_window = round(analysis_window/self.delta_t) + 1
self._model_timesteps = jnp.arange(self.steps_per_window)*self.delta_t

# Time offset from middle of time window, for gathering observations
_time_offset = (analysis_window/2) - analysis_time_in_window

# Number of model steps to run per window
steps_per_window = round(analysis_window/self.delta_t) + 1

# For storing outputs
all_output_states = []
all_times = []
cur_time = start_time
cur_state = input_state

for i in range(n_cycles):
# 1. Filter observations to inside analysis window
window_middle = cur_time + _time_offset
window_start = window_middle - analysis_window/2
window_end = window_middle + analysis_window/2
obs_vec_timefilt = obs_vector.filter_times(
window_start, window_end
)

if obs_vec_timefilt.values.shape[0] > 0:
# 2. Calculate analysis
analysis, kh = self._step_cycle(cur_state, obs_vec_timefilt)
# 3. Forecast through analysis window
forecast_states = self._step_forecast(analysis,
n_steps=steps_per_window)
# 4. Save outputs
if return_forecast:
# Append forecast to current state, excluding last step
all_output_states.append(forecast_states.values[:-1])
all_times.append(
np.arange(steps_per_window-1)*self.delta_t + cur_time
)
else:
all_output_states.append(analysis.values[np.newaxis])
all_times.append([cur_time])

# Starting point for next cycle is last step of forecast
cur_state = forecast_states[-1]
cur_time += analysis_window

return vector.StateVector(values=np.concatenate(all_output_states),
times=np.concatenate(all_times))
# Set up for jax.lax.scan, which is very fast
all_times = dac_utils._get_all_times(
start_time,
analysis_window,
n_cycles)


if self.steps_per_window is None:
self.steps_per_window = round(analysis_window/self.delta_t) + 1
self._model_timesteps = jnp.arange(self.steps_per_window)*self.delta_t
# Get the obs vectors for each analysis window
all_filtered_idx = dac_utils._get_obs_indices(
obs_times=jnp.array(obs_vector.time.values),
analysis_times=all_times+_time_offset,
start_inclusive=True,
end_inclusive=self.in_4d,
analysis_window=analysis_window
)
input_state = input_state.assign(_cur_time=start_time)

all_filtered_padded = dac_utils._pad_time_indices(all_filtered_idx, add_one=True)
self._obs_vector=obs_vector
self.obs_error_sd = obs_error_sd
if obs_vector.stationary_observers:
self._obs_loc_masks = jnp.ones(
obs_vector[self._observed_vars].to_array().shape, dtype=bool)
else:
self._obs_loc_masks = ~np.isnan(
obs_vector[self._observed_vars].to_array().data)
self._obs_vector=self._obs_vector.fillna(0)

if self.in_4d:
cur_state, all_values = jax.lax.scan(
self._cycle_and_forecast_4d,
xj.from_xarray(input_state),
all_filtered_padded)
else:
cur_state, all_values = jax.lax.scan(
self._cycle_and_forecast,
xj.from_xarray(input_state),
all_filtered_padded)

all_vals_xr = xr.Dataset(
{var: (('cycle',) + tuple(all_values[var].dims),
all_values[var].data)
for var in all_values.data_vars}
).rename_dims({'time': 'cycle_timestep'})

if return_forecast:
return all_vals_xr.drop_isel(cycle_timestep=-1)
else:
return all_vals_xr.isel(cycle_timestep=0)
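
The padding trick that `cycle()` relies on before calling `jax.lax.scan` can be illustrated in isolation. The sketch below is not the code from this PR: `pad_time_indices` is a hypothetical stand-in for `dac_utils._pad_time_indices(..., add_one=True)`, whose implementation is not shown in this diff, and `obs_vals`, `windows`, and `step` are made-up names. It assumes the helper shifts every index up by one and right-pads with zeros, which would be consistent with `_cycle_and_forecast` computing `obs_time_mask = filtered_idx > 0` and then `filtered_idx - 1`.

```python
import jax
import jax.numpy as jnp
import numpy as np


def pad_time_indices(index_lists):
    """Hypothetical stand-in for dac_utils._pad_time_indices(add_one=True).

    Shifts every index up by one and right-pads each row with zeros, so
    that 0 means "padding" and all rows share one length.
    """
    max_len = max(len(idx) for idx in index_lists)
    padded = np.zeros((len(index_lists), max_len), dtype=int)
    for row, idx in enumerate(index_lists):
        padded[row, :len(idx)] = np.asarray(idx, dtype=int) + 1
    return jnp.asarray(padded)


# Toy observations and the (ragged) indices that fall in each window.
obs_vals = jnp.array([10.0, 20.0, 30.0, 40.0, 50.0])
windows = [[0, 1], [2], [3, 4]]
padded_idx = pad_time_indices(windows)  # shape (3, 2)


def step(carry, filtered_idx):
    mask = filtered_idx > 0   # True only for real observations
    idx = filtered_idx - 1    # restore original indices (padding becomes -1)
    vals = obs_vals[idx]      # padded slots read obs_vals[-1]; masked below
    window_sum = jnp.sum(jnp.where(mask, vals, 0.0))
    return carry + window_sum, window_sum


# scan requires every per-step input to have the same shape, hence padding.
total, per_window = jax.lax.scan(step, jnp.zeros(()), padded_idx)
```

Because every row of `padded_idx` has the same length, `jax.lax.scan` can trace `step` once and reuse it for all windows; the mask keeps padded slots from contributing to the result.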