Skip to content

Commit

Permalink
Update forecast() to use xarray input/output
Browse files Browse the repository at this point in the history
  • Loading branch information
kysolvik committed Oct 3, 2024
1 parent f939e9e commit 77b369d
Showing 1 changed file with 2 additions and 21 deletions.
23 changes: 2 additions & 21 deletions dabench/model/_neuralgcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,24 +330,9 @@ def regrid_input(self, data, fill_nans=False):

return eval_data

def flat_to_xarray(self, flat, xr_template):
remap_dict = {}
coords_order = ['time','level','longitude','latitude']
for data_var in xr_template.data_vars:
remap_dict[data_var] = (coords_order,
(flat[self.flat_vars_indices[data_var]]).reshape(xr_template[data_var].shape))
return xr_template.update(remap_dict)

def xarray_to_flat(self, xr):
# Dim order before flattening goes: variable, level, latitude, longitude
xr_numpy = xr.transpose('time','level','latitude','longitude')[self.data_var_order].to_array().to_numpy()
return xr_numpy.flatten()

def forecast(self, state_vec, n_steps):
# Template forecast method to interface with DA
input_modelstate = self._model.inputs_from_xarray(
self.flat_to_xarray(state_vec.values, self.ics_eval.head(time=1)).isel(time=0)
)
input_modelstate = self._model.inputs_from_xarray(state_vec)
encoded = self._model.encode(input_modelstate, self.input_forcings_t0)
final_state, predictions = self._model.unroll(
encoded,
Expand All @@ -360,11 +345,7 @@ def forecast(self, state_vec, n_steps):
predictions,
times=self._model.sim_time_to_datetime64(predictions['sim_time'])
)
out_statevec = vector.StateVector(
values=self.xarray_to_flat(preds_xarray.drop_vars('sim_time')),
store_as_jax=True
)
return out_statevec
return preds_xarray

def postprocess_helper(self, out_state, forcings):
decoded = self._model.decode(out_state, forcings)
Expand Down

0 comments on commit 77b369d

Please sign in to comment.