Skip to content

Commit

Permalink
388 timeseries init (#429)
Browse files Browse the repository at this point in the history
  • Loading branch information
vhirtham authored Jul 16, 2021
1 parent d29a2de commit a142185
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 19 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@

### changes

- `TimeSeries.__init__` accepts `xarray.DataArray` as `data`
parameter [[#429]](https://github.com/BAMWelDX/weldx/pull/429)

### fixes

### documentation
Expand Down
59 changes: 40 additions & 19 deletions weldx/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def __init__(
self._units = None
self._interp_counter = 0

if isinstance(data, pint.Quantity):
if isinstance(data, (pint.Quantity, xr.DataArray)):
self._initialize_discrete(data, time, interpolation)
elif isinstance(data, MathematicalExpression):
self._init_expression(data)
Expand Down Expand Up @@ -346,9 +346,25 @@ def __repr__(self):
)
return representation + f"Units:\n\t{self.units}\n"

@staticmethod
def _check_data_array(data_array: xr.DataArray):
"""Raise an exception if the 'DataArray' can't be used as 'self._data'."""
try:
ut.xr_check_coords(data_array, dict(time={"dtype": ["timedelta64[ns]"]}))
except (KeyError, TypeError, ValueError) as e:
raise type(e)(
"The provided 'DataArray' does not match the required pattern. It "
"needs to have a dimension called 'time' with coordinates of type "
"'timedelta64[ns]'. The error reported by the comparison function was:"
f"\n{e}"
)

if not isinstance(data_array.data, pint.Quantity):
raise TypeError("The data of the 'DataArray' must be a 'pint.Quantity'.")

def _initialize_discrete(
self,
data: pint.Quantity,
data: Union[pint.Quantity, xr.DataArray],
time: Union[None, pd.TimedeltaIndex, pint.Quantity],
interpolation: str,
):
Expand All @@ -357,24 +373,29 @@ def _initialize_discrete(
if interpolation is None:
interpolation = "step"

# expand dim for scalar input
data = Q_(data)
if not np.iterable(data):
data = np.expand_dims(data, 0)

# constant value case
if time is None:
time = pd.TimedeltaIndex([0])

if isinstance(time, pint.Quantity):
time = ut.to_pandas_time_index(time)
if not isinstance(time, pd.TimedeltaIndex):
raise ValueError(
'"time" must be a time quantity or a "pandas.TimedeltaIndex".'
)
if isinstance(data, xr.DataArray):
self._check_data_array(data)
data = data.transpose("time", ...)
self._data = data
else:
# expand dim for scalar input
data = Q_(data)
if not np.iterable(data):
data = np.expand_dims(data, 0)

# constant value case
if time is None:
time = pd.TimedeltaIndex([0])

if isinstance(time, pint.Quantity):
time = ut.to_pandas_time_index(time)
if not isinstance(time, pd.TimedeltaIndex):
raise ValueError(
'"time" must be a time quantity or a "pandas.TimedeltaIndex".'
)

dax = xr.DataArray(data=data)
self._data = dax.rename({"dim_0": "time"}).assign_coords({"time": time})
dax = xr.DataArray(data=data)
self._data = dax.rename({"dim_0": "time"}).assign_coords({"time": time})
self.interpolation = interpolation

def _init_expression(self, data):
Expand Down
25 changes: 25 additions & 0 deletions weldx/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pandas as pd
import pint
import pytest
import xarray as xr

import weldx.util as ut
from weldx.constants import WELDX_QUANTITY as Q_
Expand Down Expand Up @@ -292,6 +293,30 @@ def test_construction_expression(data, shape_exp, unit_exp):
assert ts.data_array is None
assert Q_(1, unit_exp).check(UREG.get_dimensionality(ts.units))

# test_init_data_array -------------------------------------------------------------

@staticmethod
@pytest.mark.parametrize(
"data, dims, coords, exception_type",
[
(Q_([1, 2, 3], "m"), "time", dict(time=TDI([1, 2, 3])), None),
(Q_([1, 2, 3], "m"), "a", dict(a=TDI([1, 2, 3])), KeyError),
(Q_([[1, 2]], "m"), ("a", "time"), dict(a=[2], time=TDI([1, 2])), None),
(Q_([1, 2, 3], "m"), "time", None, KeyError),
(Q_([1, 2, 3], "m"), "time", dict(time=[1, 2, 3]), TypeError),
([1, 2, 3], "time", dict(time=TDI([1, 2, 3])), TypeError),
],
)
def test_init_data_array(data, dims, coords, exception_type):
"""Test the `__init__` method with an xarray as data parameter."""
da = xr.DataArray(data=data, dims=dims, coords=coords)
if exception_type is not None:
with pytest.raises(exception_type):
TimeSeries(da)
else:
ts = TimeSeries(da)
assert ts.data_array.dims[0] == "time"

# test_construction_exceptions -----------------------------------------------------

values_def = Q_([5, 7, 3, 6, 8], "m")
Expand Down

0 comments on commit a142185

Please sign in to comment.