Skip to content

Commit

Permalink
override Dataset __setattr__ function
Browse files Browse the repository at this point in the history
With this commit, it is possible to do Dataset.name = value instead of
the usual Dataset['name'] = (coords, value).

There may be a good reason that this is currently blocked by xarray,
but none immediately came to mind so I thought this would be worth
trying.

Another alternative to this is to do Dataset.name.data = value if your
'name' DataArray was called 'data'. This is already possible with the
default xarray Dataset object.
  • Loading branch information
taylorbell57 committed Apr 27, 2022
1 parent c04defe commit 55654db
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 2 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,6 @@ dmypy.json

# Pyre type checker
.pyre/

# VS Code workspaces
*.code-workspace
23 changes: 23 additions & 0 deletions src/astraeus/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import Any
import xarray as xr


class Dataset(xr.Dataset):
# Need to define this to avoid warning messages
__slots__ = ("_dataset",)

def __setattr__(self, name: str, value: Any) -> None:
"""Overwrite the setattr function to allow behaviour like Dataset.name = value instead of the usual Dataset['name'] = (coords, value).
"""
try:
object.__setattr__(self, name, value)
except AttributeError as e:
try:
if str(e) != "{!r} object has no attribute {!r}".format(
type(self).__name__, name
):
raise e
else:
self[name] = (list(self.coords), value)
except Exception as e2:
raise e2
5 changes: 3 additions & 2 deletions src/astraeus/xarrayIO.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import xarray as xr
from .dataset import Dataset


def writeXR(filename, ds, verbose=True, append=False):
Expand Down Expand Up @@ -67,7 +68,7 @@ def readXR(filename, verbose=True):
(filename.endswith(".h5") == False) and \
(filename.endswith(".nc") == False):
filename += ".h5"
ds = xr.open_dataset(filename, engine='h5netcdf')
ds = Dataset(xr.open_dataset(filename, engine='h5netcdf'))
if verbose:
print(f"Finished loading parameters from {filename}")
except Exception as e:
Expand Down Expand Up @@ -257,7 +258,7 @@ def makeDataset(dictionary=None):
ds: object
Xarray Dataset
"""
ds = xr.Dataset(dictionary)
ds = Dataset(dictionary)
return ds

def concat(datasets, dim='time', data_vars='minimal', coords='minimal', compat='override'):
Expand Down

0 comments on commit 55654db

Please sign in to comment.