Dev/xarray base (#56)
* Data().generate() now returns an Xarray 'state vector'

* Remove vector class import for now, replacing vector objects with simple xarrays

* Update sqgturb to work with xarray. Returns in real space, not spectral.

* Updated observer with xarray object, but only basic functionality is working

* 3DVar working with xarray

* ETKF working with xarray, initial commit

* Sort observed locations

* Data times are stored as base numpy arrays, since xarray coords cannot be jax

* Reinserting netcdf utils into data class

* gcp with xarray

* ETKF can handle irregular obs now, but doesn't have proper indices info

* Use system_index variable for H

* Observer adds integer indices for easier H calculation

* Allow datavars to not match observed vars

* Updated system index and remove sort from the observer

* Making some dacycler methods part of parent class for ease of maintenance

* Initial version of xarray var4dBP, needs testing and cleaning

* Cleaned var4dbp using xarray

* All dacyclers working with xarray, but possibly some accuracy issues with 4dvar and 4dvarBP

* State Vec has delta_t attribute and M is provided as xarray

* Working RC Model with xarray

* Fixed generator extra step rounding error

* Observer can accept random_time_density now

* Add permissible xarray jax to pyproject toml

* Remove all vector module imports

* Rename i to index for toy data generators

* GCP system dim now specified

* Observer can accept list of error_sds, and now samples with replacement when there is more than 1 dimension to sample along

* Fixed issue with missing time offset for 4dvar and 4dvarBP

* Reassign coords to match input state within dacycler

* Apply coord reassign for 4dvar cycler too

* Remove unnecessary print from 3dvar

* Reassign coords for outer loop carry instead of dropping time

* Updated gcp to properly assign system_dim

* Assign system_dim as attr, not coord

* Fix typo for xarray_jax git repo

* XArray accessors for helper methods

* Update xarray accessor methods: ds.dab.flatten() and da.dab.unflatten()

* NeuralGCM model with configuration YAML

* Neuralgcm forecast returns last step and full forecast tuple

* Remove old import_xarray

* Add date and variable filtering to load netcdf

* Add option for data split in fraction

* Updated load netcdf with filtering and data base test

* Updated LE calcs for xarray

* QGS with xarray output

* All off-line data tests working with xarray

* Partially working DA tests (3dVar working)

* Store error sd in obs vec

* ETKF and var4dBP tests passing

* Updated vals for 4dvar testing (based on previous stable version, were out of date)

* Updated presaved 4dvar test vals with consistent model delta_t (previously was 0.01 for nature run, 0.05 for forecast model)

* 4dvar method needs a different calc_default_R method (single observation timestep instead of full flattened observations)

* Test without specifying R

* Updated enso_indices (and tests) for xarray

* Removing vector class tests

* Updated gcp tests

* Smaller GCP tests to speed up downloads

* Smaller GCP tests, all passing

* Updated model base test for xarray, passing

* Observer tests updated and sqgturb system dim set in real space

* Skipping obsop tests (temporary, not used in any examples)

* Update README with xarray, added observer example

* Correct system_dim in sqgturb test
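
Several items above concern the linear observation operator H: the observer now stores integer indices (`system_index`), and the cycler builds H from them when none is supplied. The following is a minimal NumPy sketch of that idea. It mirrors `_calc_default_H` and `_calc_default_R` from this commit, but the function names `default_H` and `default_R` and the sample values are illustrative assumptions, and the actual code uses `jax.numpy`.

```python
import numpy as np

def default_H(obs_loc_indices, system_dim):
    """Selection matrix: row i has a single 1 at column obs_loc_indices[i]."""
    idx = np.asarray(obs_loc_indices).ravel()
    H = np.zeros((idx.shape[0], system_dim))
    H[np.arange(idx.shape[0]), idx] = 1.0
    return H

def default_R(n_obs, obs_error_sd):
    """Diagonal observation error covariance with sd**2 on the diagonal."""
    return np.identity(n_obs) * obs_error_sd**2

state = np.array([10.0, 11.0, 12.0, 13.0, 14.0])  # system_dim = 5
obs_loc_indices = [0, 3]                          # observe elements 0 and 3

H = default_H(obs_loc_indices, system_dim=5)
R = default_R(n_obs=2, obs_error_sd=1.2)
hx = H @ state                                    # state mapped into observation space
```

Because each row of H contains exactly one nonzero entry, `H @ state` simply picks out the observed elements, which is why integer indices are enough to define it.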
kysolvik authored Oct 15, 2024
1 parent 4790660 commit 0ec0b80
Showing 34 changed files with 1,242 additions and 2,185 deletions.
29 changes: 23 additions & 6 deletions README.md
@@ -71,11 +71,28 @@ All of the data objects are set up with reasonable defaults. Generating data is

```python
l96_obj = dab.data.Lorenz96() # Create data generator object
l96_obj.generate(n_steps=1000) # Generate Lorenz96 simulation data
l96_obj.values # View the output values
ds_l96 = l96_obj.generate(n_steps=1000) # Generate Lorenz96 simulation data as Xarray Dataset
ds_l96.dab.flatten().values # View output values flattened along time dimension
```
This example is for a Lorenz96 model, but all of the data objects work in a similar way.

#### Sampling observations

Now that we have a generated dataset, we can easily generate noisy observations from it like so:

```python
obs = dab.observer.Observer(
    ds_l96,  # Our generated Dataset object
    random_time_density=0.4,  # Randomly sample at ~40% of times
    random_location_density=0.3,  # Randomly sample ~30% of variables
    # random_location_count = 10,  # Alternatively, can specify number of locations to sample
    error_sd=1.2  # Add Gaussian noise with SD = 1.2
)
obs_vec = obs.observe() # Run observe() method to generate observations
obs_vec
```

The Observer class is very flexible, allowing users to provide specific times and locations or randomly generate them. You can also choose to use "stationary" or "nonstationary" observers, indicating whether to sample the same locations at each observation time step or to sample different ones (default is "stationary").
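
To make the stationary/nonstationary distinction concrete, here is a small NumPy sketch. It does not use the `Observer` class itself; `sample_locations` and its parameters are hypothetical names chosen for illustration, and sampling without replacement is an assumption.

```python
import numpy as np

def sample_locations(n_times, system_dim, n_locs, stationary=True, seed=0):
    """Return an (n_times, n_locs) array of observed state indices."""
    rng = np.random.default_rng(seed)
    if stationary:
        # Draw one set of locations and reuse it at every observation time
        locs = np.sort(rng.choice(system_dim, size=n_locs, replace=False))
        return np.tile(locs, (n_times, 1))
    # Draw a fresh set of locations for each observation time
    return np.stack([
        np.sort(rng.choice(system_dim, size=n_locs, replace=False))
        for _ in range(n_times)
    ])

# Toy "truth": 4 time steps of a 10-variable system
truth = np.arange(40, dtype=float).reshape(4, 10)
rng = np.random.default_rng(1)

stationary_locs = sample_locations(4, 10, n_locs=3, stationary=True)
moving_locs = sample_locations(4, 10, n_locs=3, stationary=False)

# Pick the observed values at each time and add Gaussian noise with SD = 1.2
rows = np.arange(4)[:, None]
noisy_obs = truth[rows, moving_locs] + rng.normal(0.0, 1.2, size=moving_locs.shape)
```

In the stationary case every row of the index array is identical; in the nonstationary case each row is drawn independently, so the set of observed variables can change from one time step to the next.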

#### Customizing generation options

@@ -94,8 +111,8 @@ l96_options = {'forcing_term': 7.5,
               'system_dim': 5,
               'delta_t': 0.05}
l96_obj = dab.data.Lorenz96(**l96_options) # Create data generator object
l96_obj.generate(n_steps=1000) # Generate Lorenz96 simulation data
l96_obj.values # View the output values
ds_l96 = l96_obj.generate(n_steps=1000) # Generate Lorenz96 simulation data
ds_l96 # View the output values
```

- For example, for the Google Cloud (GCP) ERA5 data-downloader, we can select our variables and time period like this:
@@ -105,6 +122,6 @@ gcp_options = {'variables': ['2m_temperature', 'sea_surface_temperature'],
               'date_start': '2020-06-01',
               'date_end': '2020-06-07'}
gcp_obj = dab.data.GCP(**gcp_options) # Create data generator object
gcp_obj.load() # Loads data. Can also use gcp_obj.generate()
gcp_obj.values # View the output values
ds_gcp = gcp_obj.load() # Loads data. Can also use gcp_obj.generate()
ds_gcp # View the output values
```
2 changes: 1 addition & 1 deletion dabench/__init__.py
@@ -1 +1 @@
from . import data, vector, model, observer, obsop, dacycler, _suppl_data
from . import data, model, observer, obsop, dacycler, _suppl_data
230 changes: 184 additions & 46 deletions dabench/dacycler/_dacycler.py
@@ -1,8 +1,12 @@
"""Base class for Data Assimilation Cycler object (DACycler)"""

from dabench import vector
import numpy as np
import jax.numpy as jnp
import jax
import xarray as xr
import xarray_jax as xj

import dabench.dacycler._utils as dac_utils

class DACycler():
    """Base class for DACycler object
@@ -37,6 +41,7 @@ def __init__(self,
                 R=None,
                 H=None,
                 h=None,
                 analysis_time_in_window=None
                 ):

        self.h = h
@@ -48,15 +53,122 @@ def __init__(self,
        self.system_dim = system_dim
        self.delta_t = delta_t
        self.model_obj = model_obj
        self.analysis_time_in_window = analysis_time_in_window


    def _calc_default_H(self, obs_values, obs_loc_indices):
        H = jnp.zeros((obs_values.flatten().shape[0], self.system_dim))
        H = H.at[jnp.arange(H.shape[0]),
                 obs_loc_indices.flatten(),
                 ].set(1)
        return H

    def _calc_default_R(self, obs_values, obs_error_sd):
        return jnp.identity(obs_values.flatten().shape[0])*(obs_error_sd**2)

    def _calc_default_B(self):
        """If B is not provided, identity matrix with shape (system_dim, system_dim)."""
        return jnp.identity(self.system_dim)

    def _step_forecast(self, xa, n_steps=1):
        """Perform forecast using model object"""
        return self.model_obj.forecast(xa, n_steps=n_steps)

    def _step_cycle(self, xb, obs_vals, obs_locs, obs_time_mask, obs_loc_mask,
                    H=None, h=None, R=None, B=None, **kwargs):
        if H is not None or h is None:
            vals = self._cycle_obsop(
                    xb, obs_vals, obs_locs, obs_time_mask,
                    obs_loc_mask, H, R, B, **kwargs)
            return vals
        else:
            raise ValueError(
                'Only linear obs operators (H) are supported right now.')
            vals = self._cycle_general_obsop(
                    xb, obs_vals, obs_locs, obs_time_mask,
                    obs_loc_mask, h, R, B, **kwargs)
            return vals

    def _cycle_and_forecast(self, cur_state, filtered_idx):
        # 1. Get data
        # 1-b. Calculate obs_time_mask and restore filtered_idx to original values
        cur_state = cur_state.to_xarray()
        cur_time = cur_state['_cur_time'].data
        cur_state = cur_state.drop_vars(['_cur_time'])
        obs_time_mask = filtered_idx > 0
        filtered_idx = filtered_idx - 1

        # 2. Calculate analysis
        cur_obs_vals = jnp.array(self._obs_vector[self._observed_vars].to_array().data).at[:, filtered_idx].get()
        cur_obs_loc_indices = jnp.array(self._obs_vector.system_index.data).at[:, filtered_idx].get()
        cur_obs_loc_mask = jnp.array(self._obs_loc_masks).at[:, filtered_idx].get().astype(bool)
        cur_obs_time_mask = jnp.repeat(obs_time_mask, cur_obs_vals.shape[-1])
        analysis = self._step_cycle(
                cur_state,
                cur_obs_vals,
                cur_obs_loc_indices,
                obs_loc_mask=cur_obs_loc_mask,
                obs_time_mask=cur_obs_time_mask
                )
        # 3. Forecast next timestep
        next_state, forecast_states = self._step_forecast(analysis, n_steps=self.steps_per_window)
        next_state = next_state.assign(
                _cur_time = cur_time + self.analysis_window
            ).assign_coords(
                cur_state.coords)

        return xj.from_xarray(next_state), forecast_states

    def _cycle_and_forecast_4d(self, cur_state, filtered_idx):
        # 1. Get data
        # 1-b. Calculate obs_time_mask and restore filtered_idx to original values
        cur_state = cur_state.to_xarray()
        cur_time = cur_state['_cur_time'].data
        cur_state = cur_state.drop_vars(['_cur_time'])
        obs_time_mask = filtered_idx > 0
        filtered_idx = filtered_idx - 1

        cur_obs_vals = jnp.array(self._obs_vector[self._observed_vars].to_stacked_array('system',['time']).data).at[filtered_idx].get()
        cur_obs_times = jnp.array(self._obs_vector.time.data).at[filtered_idx].get()
        cur_obs_loc_indices = jnp.array(self._obs_vector.system_index.data).at[:, filtered_idx].get().reshape(filtered_idx.shape[0], -1)
        cur_obs_loc_mask = jnp.array(self._obs_loc_masks).at[:, filtered_idx].get().astype(bool).reshape(filtered_idx.shape[0], -1)

        # Calculate obs window indices: closest model timesteps that match obs
        obs_window_indices = jnp.array([
            jnp.argmin(
                jnp.abs(obs_time - (cur_time + self._model_timesteps))
            ) for obs_time in cur_obs_times
        ])

        # 2. Calculate analysis
        analysis = self._step_cycle(
                cur_state,
                cur_obs_vals,
                cur_obs_loc_indices,
                obs_loc_mask=cur_obs_loc_mask,
                obs_time_mask=obs_time_mask,
                obs_window_indices=obs_window_indices
                )

        # 3. Forecast forward
        next_state, forecast_states = self._step_forecast(analysis, n_steps=self.steps_per_window)
        next_state = next_state.assign(
                _cur_time = cur_time + self.analysis_window
            ).assign_coords(
                cur_state.coords)

        return xj.from_xarray(next_state), forecast_states

def cycle(self,
input_state,
start_time,
obs_vector,
n_cycles,
analysis_window,
obs_error_sd=None,
analysis_window=0.2,
analysis_time_in_window=None,
return_forecast=False):
return_forecast=False
):
"""Perform DA cycle repeatedly, including analysis and forecast
Args:
Expand All @@ -79,52 +191,78 @@ def cycle(self,
vector.StateVector of analyses and times.
"""

# These could be different if observer doesn't observe all variables
# For now, making them the same
self._observed_vars = obs_vector['variable'].values
self._data_vars = list(input_state.data_vars)

if obs_error_sd is None:
obs_error_sd = obs_vector.error_sd

self.analysis_window = analysis_window

# If don't specify analysis_time_in_window, is assumed to be middle
if analysis_time_in_window is None:
analysis_time_in_window = analysis_window/2
if self.analysis_time_in_window is None and analysis_time_in_window is None:
analysis_time_in_window = self.analysis_window/2
else:
analysis_time_in_window = self.analysis_time_in_window

# Steps per window + 1 to include start
self.steps_per_window = round(analysis_window/self.delta_t) + 1
self._model_timesteps = jnp.arange(self.steps_per_window)*self.delta_t

# Time offset from middle of time window, for gathering observations
_time_offset = (analysis_window/2) - analysis_time_in_window

# Number of model steps to run per window
steps_per_window = round(analysis_window/self.delta_t) + 1

# For storing outputs
all_output_states = []
all_times = []
cur_time = start_time
cur_state = input_state

for i in range(n_cycles):
# 1. Filter observations to inside analysis window
window_middle = cur_time + _time_offset
window_start = window_middle - analysis_window/2
window_end = window_middle + analysis_window/2
obs_vec_timefilt = obs_vector.filter_times(
window_start, window_end
)

if obs_vec_timefilt.values.shape[0] > 0:
# 2. Calculate analysis
analysis, kh = self._step_cycle(cur_state, obs_vec_timefilt)
# 3. Forecast through analysis window
forecast_states = self._step_forecast(analysis,
n_steps=steps_per_window)
# 4. Save outputs
if return_forecast:
# Append forecast to current state, excluding last step
all_output_states.append(forecast_states.values[:-1])
all_times.append(
np.arange(steps_per_window-1)*self.delta_t + cur_time
)
else:
all_output_states.append(analysis.values[np.newaxis])
all_times.append([cur_time])

# Starting point for next cycle is last step of forecast
cur_state = forecast_states[-1]
cur_time += analysis_window

return vector.StateVector(values=np.concatenate(all_output_states),
times=np.concatenate(all_times))
# Set up for jax.lax.scan, which is very fast
all_times = dac_utils._get_all_times(
start_time,
analysis_window,
n_cycles)


if self.steps_per_window is None:
self.steps_per_window = round(analysis_window/self.delta_t) + 1
self._model_timesteps = jnp.arange(self.steps_per_window)*self.delta_t
# Get the obs vectors for each analysis window
all_filtered_idx = dac_utils._get_obs_indices(
obs_times=jnp.array(obs_vector.time.values),
analysis_times=all_times+_time_offset,
start_inclusive=True,
end_inclusive=self.in_4d,
analysis_window=analysis_window
)
input_state = input_state.assign(_cur_time=start_time)

all_filtered_padded = dac_utils._pad_time_indices(all_filtered_idx, add_one=True)
self._obs_vector=obs_vector
self.obs_error_sd = obs_error_sd
if obs_vector.stationary_observers:
self._obs_loc_masks = jnp.ones(
obs_vector[self._observed_vars].to_array().shape, dtype=bool)
else:
self._obs_loc_masks = ~np.isnan(
obs_vector[self._observed_vars].to_array().data)
self._obs_vector=self._obs_vector.fillna(0)

if self.in_4d:
cur_state, all_values = jax.lax.scan(
self._cycle_and_forecast_4d,
xj.from_xarray(input_state),
all_filtered_padded)
else:
cur_state, all_values = jax.lax.scan(
self._cycle_and_forecast,
xj.from_xarray(input_state),
all_filtered_padded)

all_vals_xr = xr.Dataset(
{var: (('cycle',) + tuple(all_values[var].dims),
all_values[var].data)
for var in all_values.data_vars}
).rename_dims({'time': 'cycle_timestep'})

if return_forecast:
return all_vals_xr.drop_isel(cycle_timestep=-1)
else:
return all_vals_xr.isel(cycle_timestep=0)
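
`jax.lax.scan` needs inputs of a fixed shape, while each analysis window may contain a different number of observations. The `add_one=True` padding and the `filtered_idx > 0` / `filtered_idx - 1` lines in the cycler suggest the following trick, sketched here in NumPy. The helpers `obs_indices_per_window` and `pad_indices` are hypothetical stand-ins: the real `dac_utils._get_obs_indices` and `dac_utils._pad_time_indices` are not shown in this diff and may behave differently.

```python
import numpy as np

def obs_indices_per_window(obs_times, window_starts, analysis_window):
    """For each window, indices of obs with start <= t < start + analysis_window."""
    return [
        np.where((obs_times >= s) & (obs_times < s + analysis_window))[0]
        for s in window_starts
    ]

def pad_indices(index_lists):
    """Shift indices by +1 and right-pad with 0 so every row has equal length."""
    width = max(len(ix) for ix in index_lists)
    padded = np.zeros((len(index_lists), width), dtype=int)
    for row, ix in enumerate(index_lists):
        padded[row, :len(ix)] = ix + 1
    return padded

obs_times = np.array([0.05, 0.10, 0.25, 0.30, 0.35, 0.55])
window_starts = np.array([0.0, 0.2, 0.4])
padded = pad_indices(obs_indices_per_window(obs_times, window_starts, 0.2))

# Inside the scanned function: recover the validity mask and the real indices
mask = padded > 0   # True where a real observation exists
idx = padded - 1    # original indices; padded slots become -1
```

After the shift, padded slots hold index -1, which still points at a valid (last) element, so any values gathered through them must be excluded using the mask.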