Skip to content

Commit

Permalink
Dacycler utils with type hints and updated docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
kysolvik committed Oct 24, 2024
1 parent 0fba72c commit e7374d9
Showing 1 changed file with 63 additions and 47 deletions.
110 changes: 63 additions & 47 deletions dabench/dacycler/_utils.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,31 @@
"""Utils for data assimilation cyclers"""

import jax.numpy as jnp
import jax
import numpy as np
import xarray as xr
import xarray_jax as xj


# For typing
ArrayLike = list | np.ndarray | jax.Array
XarrayDatasetLike = xr.Dataset | xj.XjDataset

def _get_all_times(
start_time,
analysis_window,
analysis_cycles,
):
start_time: float,
analysis_window: float,
analysis_cycles: int
) -> jax.Array:
"""Calculate times of the centers of all analysis windows.
Args:
start_time (float): Start time of DA experiment in model time units.
analysis_window (float): Length of analysis window, in model time
start_time: Start time of DA experiment in model time units.
analysis_window: Length of analysis window, in model time
units.
analysis_cycles (int): Number of analysis cycles to perform.
analysis_cycles: Number of analysis cycles to perform.
Returns:
array of all analysis window center-times.
Array of all analysis window center-times.
"""
Expand All @@ -32,26 +39,26 @@ def _get_all_times(


def _get_obs_indices(
analysis_times,
obs_times,
analysis_window,
start_inclusive=True,
end_inclusive=False
):
analysis_times: ArrayLike,
obs_times: ArrayLike,
analysis_window: float,
start_inclusive: bool = True,
end_inclusive: bool = False
) -> list:
"""Get indices of obs times for each analysis cycle to pass to jax.lax.scan
Args:
analysis_times (list): List of times for all analysis window, centered
analysis_times: List of times for all analysis window, centered
in middle of time window. Output of _get_all_times().
obs_times (list): List of times for all observations.
analysis_window (float): Length of analysis window.
start_inclusive (bool): Include obs times equal to beginning of
obs_times: List of times for all observations.
analysis_window: Length of analysis window.
start_inclusive: Include obs times equal to beginning of
analysis window. Default is True
end_inclusive (bool): Include obs times equal to end of
end_inclusive: Include obs times equal to end of
analysis window. Default is False.
Returns:
list with each element containing array of obs indices for the
List with each element containing array of obs indices for the
corresponding analysis cycle.
"""
# Get the obs vectors for each analysis window
Expand All @@ -77,60 +84,69 @@ def _get_obs_indices(
return all_filtered_idx


def _time_resize(
row: ArrayLike,
size: int,
add_one: bool
) -> np.ndarray:
new = np.array(row) + add_one
new.resize(size)
return new


def _pad_time_indices(
        obs_indices: ArrayLike,
        add_one: bool = True
) -> ArrayLike:
    """Pad observation indices for each analysis window.

    Args:
        obs_indices: List of arrays where each array contains
            obs indices for an analysis cycle. Result of _get_obs_indices.
        add_one: If True, will add one to all index values to encode
            indices to be masked out for DA (i.e. zeros represent indices to
            be masked out). Default is True.

    Returns:
        Array of padded obs_indices, with shape:
            (num_analysis_cycles, max_obs_per_cycle).
    """
    # Pad every row out to the length of the longest one so the rows stack
    # into a single rectangular array.
    row_length = max(len(row) for row in obs_indices)
    padded_indices = np.array([_time_resize(row, row_length, add_one)
                               for row in obs_indices])

    return padded_indices


def _pad_obs_locs(obs_vec):
def _obs_resize(
        row: ArrayLike,
        size: int
) -> np.ndarray:
    """Pad one (values, location-indices) pair out to a fixed length.

    Helper for _pad_obs_locs.

    Args:
        row: Sequence of two equal-length arrays: the obs values and
            their location indices for a single analysis cycle.
        size: Target length to pad each array to. Assumed to be
            >= len(row[0]) -- the in-place resize would truncate otherwise.

    Returns:
        Array of shape (size, 3) whose columns are (value, location, mask),
        where mask is 1 for a real observation and 0 for a padded entry.
    """
    # Fortran (column-major) copy so that the in-place resize below appends
    # zeros to the end of each row rather than interleaving them
    # (NOTE(review): relies on ndarray.resize preserving the F-layout).
    new_vals_locs = np.array(np.stack(row), order='F')
    # ndarray.resize zero-fills the newly added entries in place.
    new_vals_locs.resize((new_vals_locs.shape[0], size))
    # Binary mask: 1 = valid obs entry, 0 = padding.
    mask = np.ones_like(new_vals_locs[0]).astype(int)
    if size > len(row[0]):
        mask[-(size-len(row[0])):] = 0
    return np.vstack([new_vals_locs, mask]).T


def _pad_obs_locs(
obs_vec: XarrayDatasetLike
) -> tuple[ArrayLike, ArrayLike, ArrayLike]:
"""Pad observation location indices to equal spacing
Args:
obs_vec (dabench.vector.ObsVector): Observation vector
object containing times, locations, and values of obs.
obs_vec: Xarray containing times, locations, and values of obs.
Returns:
(vals, locs, masks): Tuple containing padded arrays of obs
Tuple containing padded arrays of obs
values and locations, and binary array masks where 1 is
a valid observation value/location and 0 is not.
"""

def resize(row, size):
new_vals_locs = np.array(np.stack(row), order='F')
new_vals_locs.resize((new_vals_locs.shape[0], size))
mask = np.ones_like(new_vals_locs[0]).astype(int)
if size > len(row[0]):
mask[-(size-len(row[0])):] = 0
return np.vstack([new_vals_locs, mask]).T

# Find longest row length
row_length = max(obs_vec.values, key=len).__len__()
padded_arrays_masks = np.array([resize(row, row_length) for row in
padded_arrays_masks = np.array([_obs_resize(row, row_length) for row in
np.stack([obs_vec.values,
obs_vec.location_indices],
axis=1)], dtype=float)
Expand Down

0 comments on commit e7374d9

Please sign in to comment.