Skip to content

Commit

Permalink
Add Support for Sorting Lightcurves by Time (#353)
Browse files Browse the repository at this point in the history
* Initial commit to sort batch by time

* lint fix

* Sort partitions by time

* Fix merge and linting

* Break lightcurve sorting out of batch

* Update comments

* Adds a by_band param to sort_ligthcurves

* Fix merge err with parquet_ensemble_without_client

* Lint fix
  • Loading branch information
wilsonbb authored Mar 19, 2024
1 parent 0c7a044 commit 86a0c45
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 1 deletion.
52 changes: 51 additions & 1 deletion src/tape/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,46 @@ def check_lightcurve_cohesion(self):
return False
return True

def sort_lightcurves(self, by_band=True):
"""Sorts each Source partition first by the indexed ID column and then by
the time column, each in ascending order.
This allows for efficient access of lightcurves by their indexed object ID
while still giving easy access to the sorted time series.
Note that if the lightcurves are split across multiple partitions, this operation
only sorts on a per-partition basis, and the table will not be globally sorted.
You can check that no lightcurves are not split across multiple partitions by
seeing if `Ensemble.check_lightcurve_cohesion()` is `True`.
Parameters
----------
by_band: `bool`, optional
If True, the lightcurves are still sorted first by the indexed ID column,
but then by band and then by timestamp, all in ascending order.
Returns
-------
Ensemble
"""
self._lazy_sync_tables(table="source")

# Dask lacks support for multi-column sorting and indices, but if we have
# lightcurve cohesion, we can sort each partition individually since
# each lightcurve should only be in a single partition. We sort the Source
# table first by its indexed ID column and then by the timestamp.
id_col, time_col = self._id_col, self._time_col # save column names for scoping for the lambda
if not by_band:
self.source.map_partitions(lambda x: x.sort_values([id_col, time_col])).update_ensemble()
else:
band_col = self._band_col
self.source.map_partitions(
lambda x: x.sort_values([id_col, band_col, time_col])
).update_ensemble()

return self

def compute(self, table=None, **kwargs):
"""Wrapper for dask.dataframe.DataFrame.compute()
Expand Down Expand Up @@ -1001,7 +1041,17 @@ def bin_sources(
self.source.set_dirty(True)
return self

def batch(self, func, *args, meta=None, by_band=False, use_map=True, on=None, label="", **kwargs):
def batch(
self,
func,
*args,
meta=None,
by_band=False,
use_map=True,
on=None,
label="",
**kwargs,
):
"""Run a function from tape.TimeSeries on the available ids
Parameters
Expand Down
89 changes: 89 additions & 0 deletions tests/tape_tests/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1965,6 +1965,95 @@ def test_batch(data_fixture, request, use_map, on):
assert pytest.approx(result.values[1]["r"], 0.001) == -0.49639028


@pytest.mark.parametrize(
"data_fixture",
[
"parquet_ensemble",
"parquet_ensemble_with_divisions",
],
)
@pytest.mark.parametrize("sort_by_band", [True, False])
def test_sort_lightcurves(data_fixture, request, sort_by_band):
"""
Test that we can have the ensemble sort its lightcurves by timestamp.
"""
parquet_ensemble = request.getfixturevalue(data_fixture)

# filter NaNs from the source table
parquet_ensemble = parquet_ensemble.prune(10).dropna(table="source")

# To check that all columns are rearranged when sorting the time column,
# we create a duplicate time column which should be sorted as well.
parquet_ensemble.source.assign(
dup_time=parquet_ensemble.source[parquet_ensemble._time_col]
).update_ensemble()

# Validate the Ensemble is sorted by ID
assert parquet_ensemble.check_sorted("source")

bands = parquet_ensemble.source[parquet_ensemble._band_col].unique().compute()

# A trivial function that raises an Exception if the data is not temporally sorted
def my_mean(flux, time, dup_time, band):
if not sort_by_band:
# Check that the time column is sorted
if not np.all(time[:-1] <= time[1:]):
raise ValueError("The time column was not sorted in ascending order")
else:
# Check that the band column is sorted
if not np.all(band[:-1] <= band[1:]):
raise ValueError("The bands column was not sorted in ascending order")
# Check that the time column is sorted for each band
for curr_band in bands:
# Get a mask for the current band
mask = band == curr_band
if not np.all(time[mask][:-1] <= time[mask][1:]):
raise ValueError(f"The time column was not sorted in ascending order for band {band}")
# Check that the other columns were rearranged to preserve the dataframe's rows
# We can use the duplicate time column as an easy check.
if not np.array_equal(time, dup_time):
raise ValueError("The dataframe's time column was sorted but isn't aligned with other columns")
return np.mean(flux)

band = parquet_ensemble._band_col if sort_by_band else None

# Validate that our custom function throws an Exception on the unsorted data to
# ensure that we actually sort when requested.
with pytest.raises(ValueError):
parquet_ensemble.batch(
my_mean,
parquet_ensemble._flux_col,
parquet_ensemble._time_col,
"dup_time",
parquet_ensemble._band_col,
by_band=False,
).compute()

parquet_ensemble.sort_lightcurves(by_band=sort_by_band)

result = parquet_ensemble.batch(
my_mean,
parquet_ensemble._flux_col,
parquet_ensemble._time_col,
"dup_time",
parquet_ensemble._band_col,
by_band=False,
)

# Validate that the result is non-empty
assert len(result.compute()) > 0

# Make sure that divisions information was propagated if known
if parquet_ensemble.source.known_divisions and parquet_ensemble.object.known_divisions:
assert result.known_divisions

# Check that the dataframe is still sorted by the ID column
assert parquet_ensemble.check_sorted("source")

# Verify that we preserved lightcurve cohesion
assert parquet_ensemble.check_lightcurve_cohesion()


@pytest.mark.parametrize("on", [None, ["ps1_objid", "filterName"], ["filterName", "ps1_objid"]])
@pytest.mark.parametrize("func_label", ["mean", "bounds"])
def test_batch_by_band(parquet_ensemble, func_label, on):
Expand Down

0 comments on commit 86a0c45

Please sign in to comment.