From 55654db55ba50890fee8c7f92410da0966cb14b2 Mon Sep 17 00:00:00 2001 From: Taylor Bell Date: Wed, 27 Apr 2022 11:34:56 -0700 Subject: [PATCH] override Dataset __setattr__ function With this commit, it is possible to do Dataset.name = value instead of the usual Dataset['name'] = (coords, value). There may be a good reason that this is currently blocked by xarray, but none immediately came to mind so I thought this would be worth trying. Another alternative to this is to do Dataset.name.data = value if your 'name' DataArray was called 'data'. This is already possible with the default xarray Dataset object. --- .gitignore | 3 +++ src/astraeus/dataset.py | 23 +++++++++++++++++++++++ src/astraeus/xarrayIO.py | 5 +++-- 3 files changed, 29 insertions(+), 2 deletions(-) create mode 100644 src/astraeus/dataset.py diff --git a/.gitignore b/.gitignore index 4397233..a601e39 100644 --- a/.gitignore +++ b/.gitignore @@ -129,3 +129,6 @@ dmypy.json # Pyre type checker .pyre/ + +# VS Code workspaces +*.code-workspace diff --git a/src/astraeus/dataset.py b/src/astraeus/dataset.py new file mode 100644 index 0000000..ad1754e --- /dev/null +++ b/src/astraeus/dataset.py @@ -0,0 +1,23 @@ +from typing import Any +import xarray as xr + + +class Dataset(xr.Dataset): + # Need to define this to avoid warning messages + __slots__ = ("_dataset",) + + def __setattr__(self, name: str, value: Any) -> None: + """Overwrite the setattr function to allow behaviour like Dataset.name = value instead of the usual Dataset['name'] = (coords, value). + """ + try: + object.__setattr__(self, name, value) + except AttributeError as e: + try: + if str(e) != "{!r} object has no attribute {!r}".format( + type(self).__name__, name + ): + raise e + else: + self[name] = (list(self.coords), value) + except Exception as e2: + raise e2 diff --git a/src/astraeus/xarrayIO.py b/src/astraeus/xarrayIO.py index fb3c0d6..5f93016 100644 --- a/src/astraeus/xarrayIO.py +++ b/src/astraeus/xarrayIO.py @@ -1,5 +1,6 @@ import numpy as np import xarray as xr +from .dataset import Dataset def writeXR(filename, ds, verbose=True, append=False): @@ -67,7 +68,7 @@ def readXR(filename, verbose=True): (filename.endswith(".h5") == False) and \ (filename.endswith(".nc") == False): filename += ".h5" - ds = xr.open_dataset(filename, engine='h5netcdf') + ds = Dataset(xr.open_dataset(filename, engine='h5netcdf')) if verbose: print(f"Finished loading parameters from {filename}") except Exception as e: @@ -257,7 +258,7 @@ def makeDataset(dictionary=None): ds: object Xarray Dataset """ - ds = xr.Dataset(dictionary) + ds = Dataset(dictionary) return ds def concat(datasets, dim='time', data_vars='minimal', coords='minimal', compat='override'):