Skip to content

Commit

Permalink
Merge pull request #410 from astronomy-commons/sean/dropna
Browse files Browse the repository at this point in the history
Wrap nested_dask dropna function
  • Loading branch information
smcguire-cmu authored Sep 12, 2024
2 parents 6efc1b5 + 90e0e72 commit a94b85b
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 0 deletions.
27 changes: 27 additions & 0 deletions src/lsdb/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
import pandas as pd
from hipscat.catalog.index.index_catalog import IndexCatalog as HCIndexCatalog
from hipscat.pixel_math.polygon_filter import SphericalCoordinates
from pandas._libs import lib
from pandas._typing import AnyAll, Axis, IndexLabel
from pandas.api.extensions import no_default

from lsdb.catalog.association_catalog import AssociationCatalog
from lsdb.catalog.dataset.healpix_dataset import HealpixDataset
Expand Down Expand Up @@ -554,3 +557,27 @@ def join_nested(
)
hc_catalog = hc.catalog.Catalog(new_catalog_info, alignment.pixel_tree)
return Catalog(ddf, ddf_map, hc_catalog)

def dropna(
    self,
    *,
    axis: Axis = 0,
    how: AnyAll | lib.NoDefault = no_default,
    thresh: int | lib.NoDefault = no_default,
    on_nested: bool | str = False,
    subset: IndexLabel | None = None,
    ignore_index: bool = False,
) -> Catalog:
    """Remove missing values from the catalog and, when present, from its margin.

    Every argument is forwarded unchanged to the parent class ``dropna`` and,
    if ``self.margin`` is not None, to the margin's ``dropna`` as well, so both
    are filtered with identical settings.

    Parameters
    ----------
    axis : {0 or 'index', 1 or 'columns'}, default 0
        Whether rows or columns containing missing values are removed.
    how : {'any', 'all'}, default 'any'
        Drop when any value is NA, or only when all values are NA.
    thresh : int, optional
        Require that many non-NA values. Cannot be combined with how.
    on_nested : str or bool, optional
        If not False, label of the nested column whose nested dataframe the
        call is applied to.
    subset : column label or sequence of labels, optional
        Labels along the other axis to consider.
    ignore_index : bool, default ``False``
        If ``True``, the resulting axis will be labeled 0, 1, …, n - 1.

    Returns
    -------
    Catalog
        The catalog returned by the parent ``dropna``, with its ``margin``
        replaced by the filtered margin when the original had one.
    """
    # Collect the options once so the main catalog and the margin cannot drift apart.
    dropna_kwargs = {
        "axis": axis,
        "how": how,
        "thresh": thresh,
        "on_nested": on_nested,
        "subset": subset,
        "ignore_index": ignore_index,
    }
    catalog = super().dropna(**dropna_kwargs)
    if self.margin is not None:
        catalog.margin = self.margin.dropna(**dropna_kwargs)
    return catalog
71 changes: 71 additions & 0 deletions src/lsdb/catalog/dataset/healpix_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from hipscat.inspection.visualize_catalog import get_projection_method
from hipscat.pixel_math import HealpixPixel
from hipscat.pixel_math.healpix_pixel_function import get_pixel_argsort
from pandas._libs import lib
from pandas._typing import AnyAll, Axis, IndexLabel
from pandas.api.extensions import no_default
from typing_extensions import Self

from lsdb import io
Expand Down Expand Up @@ -424,3 +427,71 @@ def to_hipscat(
**kwargs: Arguments to pass to the parquet write operations
"""
io.to_hipscat(self, base_catalog_path, catalog_name, overwrite, storage_options, **kwargs)

def dropna(
    self,
    *,
    axis: Axis = 0,
    how: AnyAll | lib.NoDefault = no_default,
    thresh: int | lib.NoDefault = no_default,
    on_nested: bool | str = False,
    subset: IndexLabel | None = None,
    ignore_index: bool = False,
) -> Self:  # type: ignore[name-defined] # noqa: F821
    """
    Remove missing values for one layer of nested columns in the catalog.

    Parameters
    ----------
    axis : {0 or 'index', 1 or 'columns'}, default 0
        Determine if rows or columns which contain missing values are
        removed.

        * 0, or 'index' : Drop rows which contain missing values.
        * 1, or 'columns' : Drop columns which contain missing values.

        Only a single axis is allowed.
    how : {'any', 'all'}, default 'any'
        Determine if row or column is removed from catalog, when we have
        at least one NA or all NA.

        * 'any' : If any NA values are present, drop that row or column.
        * 'all' : If all values are NA, drop that row or column.
    thresh : int, optional
        Require that many non-NA values. Cannot be combined with how.
    on_nested : str or bool, optional
        If not False, applies the call to the nested dataframe in the
        column with label equal to the provided string. If specified,
        the nested dataframe should align with any columns given in
        `subset`.
    subset : column label or sequence of labels, optional
        Labels along other axis to consider, e.g. if you are dropping rows
        these would be a list of columns to include.
        Access nested columns using `nested_df.nested_col` (where
        `nested_df` refers to a particular nested dataframe and
        `nested_col` is a column of that nested dataframe).
    ignore_index : bool, default ``False``
        If ``True``, the resulting axis will be labeled 0, 1, …, n - 1.

        .. versionadded:: 2.0.0

    Returns
    -------
    Self
        A new instance of the same class as this object, with NA entries
        dropped from it.

    Notes
    -----
    Operations that target a particular nested structure return a dataframe
    with rows of that particular nested structure affected.

    Values for `on_nested` and `subset` should be consistent in pointing
    to a single layer, multi-layer operations are not supported at this
    time.
    """
    ndf = self._ddf.dropna(
        axis=axis, how=how, thresh=thresh, on_nested=on_nested, subset=subset, ignore_index=ignore_index
    )
    # Rebuild through self.__class__ so subclasses get back their own type;
    # the pixel map and hc_structure are reused as they are.
    return self.__class__(ndf, self._ddf_pixel_map, self.hc_structure)
7 changes: 7 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,13 @@ def small_sky_order3_source_margin_catalog(test_data_dir):
return lsdb.read_hipscat(test_data_dir / SMALL_SKY_ORDER3_SOURCE_MARGIN_NAME)


@pytest.fixture
def small_sky_with_nested_sources(small_sky_order1_catalog, small_sky_order1_source_with_margin):
    """Join the source catalog onto the order-1 catalog as a nested ``sources`` column."""
    join_options = {
        "left_on": "id",
        "right_on": "object_id",
        "nested_column_name": "sources",
    }
    return small_sky_order1_catalog.join_nested(small_sky_order1_source_with_margin, **join_options)


@pytest.fixture
def small_sky_no_metadata_dir(test_data_dir):
    """Path of the ``SMALL_SKY_NO_METADATA`` entry inside the ``raw`` test data directory."""
    raw_dir = test_data_dir / "raw"
    return raw_dir / SMALL_SKY_NO_METADATA
Expand Down
37 changes: 37 additions & 0 deletions tests/lsdb/catalog/test_nested.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import nested_dask as nd
import numpy as np
import pandas as pd

from lsdb import Catalog


def test_dropna(small_sky_with_nested_sources):
    """``Catalog.dropna`` should match pandas ``dropna`` applied to the computed frame."""
    queried = small_sky_with_nested_sources.query("sources.mag < 15.1")
    cleaned = queried.dropna()

    # The wrapper must hand back a Catalog that is still backed by a NestedFrame.
    assert isinstance(cleaned, Catalog)
    assert isinstance(cleaned._ddf, nd.NestedFrame)

    cleaned_df = cleaned.compute()
    queried_df = queried.compute()
    expected_df = queried_df.dropna()

    # At least one row has to be dropped, otherwise the test proves nothing.
    assert len(queried_df) > len(cleaned_df)
    pd.testing.assert_frame_equal(cleaned_df, expected_df)


def test_dropna_on_nested(small_sky_with_nested_sources):
    """``dropna(on_nested=...)`` should drop rows inside the nested column only."""

    def add_na_values_nested(df):
        """Replace the first source_ra value of each nested df with NaN.

        NOTE(review): ``replace`` matches by value, so any other cell of the
        nested frame equal to that value is presumably replaced too - confirm.
        """
        for i in range(len(df)):
            first_ra_value = df.iloc[i]["sources"].iloc[0]["source_ra"]
            # Use np.nan: the capitalised np.NaN alias was removed in NumPy 2.0.
            df["sources"].array[i] = df["sources"].array[i].replace(first_ra_value, np.nan)
        return df

    filtered_cat = small_sky_with_nested_sources.map_partitions(add_na_values_nested)
    drop_na_cat = filtered_cat.dropna(on_nested="sources")
    assert isinstance(drop_na_cat, Catalog)
    assert isinstance(drop_na_cat._ddf, nd.NestedFrame)

    drop_na_sources_compute = drop_na_cat["sources"].compute()
    filtered_sources_compute = filtered_cat["sources"].compute()
    # The number of top-level rows is unchanged; only the nested rows shrink.
    assert len(drop_na_sources_compute) == len(filtered_sources_compute)
    assert sum(map(len, drop_na_sources_compute)) < sum(map(len, filtered_sources_compute))

    # The catalog wrapper must agree with calling dropna on the underlying frame.
    pd.testing.assert_frame_equal(
        drop_na_cat.compute(), filtered_cat._ddf.dropna(on_nested="sources").compute()
    )

0 comments on commit a94b85b

Please sign in to comment.