Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrapper for catalog merge #48

Merged
merged 4 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions src/lsdb/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,3 +216,57 @@ def cone_search(self, ra: float, dec: float, radius: float):
cone_search_ddf = cast(dd.DataFrame, cone_search_ddf)
ddf_partition_map = {pixel: i for i, pixel in enumerate(pixels_in_cone)}
return Catalog(cone_search_ddf, ddf_partition_map, filtered_hc_structure)

def merge(
    self,
    other: Catalog,
    how: str = "inner",
    on: str | List | None = None,
    left_on: str | List | None = None,
    right_on: str | List | None = None,
    left_index: bool = False,
    right_index: bool = False,
    suffixes: Tuple[str, str] | None = None,
) -> dd.DataFrame:
    """Merge the Dataframe of this catalog with the Dataframe of another catalog.

    This is a thin wrapper: every argument is forwarded unchanged to the
    ``merge`` method of the underlying ``_ddf`` Dataframe. The arguments follow
    the pandas merge semantics, described
    `here <https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.merge.html>`__.

    Args:
        other (Catalog): The catalog on the right-hand side of the merge.
        how (str): The type of merge to perform. One of
            {'left', 'right', 'outer', 'inner'}, defaults to 'inner'.
        on (str | List): Column or index names to join on. When None, and the
            merge is not on indexes, the columns common to both Dataframes
            are used.
        left_on (str | List): Join column of the left Dataframe. Lists are
            supported if their length is one.
        right_on (str | List): Join column of the right Dataframe. Lists are
            supported if their length is one.
        left_index (bool): Whether the index of the left Dataframe is the join
            key. Defaults to False.
        right_index (bool): Whether the index of the right Dataframe is the
            join key. Defaults to False.
        suffixes (Tuple[str, str]): Pair of suffixes appended to the names of
            the columns that are joined. When None, the suffixes are built from
            the catalog names as ``_{self.name}`` and ``_{other.name}``.

    Returns:
        A new Dask Dataframe holding the result of merging the two catalogs.

    Raises:
        ValueError: If `suffixes` is provided and does not have exactly two
            elements.
    """
    if suffixes is None:
        # Fall back to the catalog names so overlapping columns stay distinguishable
        suffixes = (f"_{self.name}", f"_{other.name}")
    elif len(suffixes) != 2:
        raise ValueError("`suffixes` must be a tuple with two strings")
    merge_options = {
        "how": how,
        "on": on,
        "left_on": left_on,
        "right_on": right_on,
        "left_index": left_index,
        "right_index": right_index,
        "suffixes": suffixes,
    }
    return self._ddf.merge(other._ddf, **merge_options)
78 changes: 78 additions & 0 deletions tests/lsdb/catalog/test_merge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import dask.dataframe as dd
import pandas as pd
import pytest


@pytest.mark.parametrize("how", ["left", "right", "inner", "outer"])
def test_catalog_merge_on_indices(small_sky_catalog, small_sky_order1_catalog, how):
    """Merging on both indexes via the wrapper equals merging the `_ddf` frames directly."""
    # Use the object "id" as the index of both catalogs, so the join keys are the indexes
    small_sky_catalog._ddf = small_sky_catalog._ddf.set_index("id")
    small_sky_order1_catalog._ddf = small_sky_order1_catalog._ddf.set_index("id")
    merge_args = dict(
        how=how,
        left_index=True,
        right_index=True,
        suffixes=("_left", "_right"),
    )
    # Result produced by the catalog wrapper
    actual_ddf = small_sky_catalog.merge(small_sky_order1_catalog, **merge_args)
    assert isinstance(actual_ddf, dd.DataFrame)
    # Reference result from a direct merge of the underlying frames
    reference_ddf = small_sky_catalog._ddf.merge(small_sky_order1_catalog._ddf, **merge_args)
    pd.testing.assert_frame_equal(reference_ddf.compute(), actual_ddf.compute())


@pytest.mark.parametrize("how", ["left", "right", "inner", "outer"])
def test_catalog_merge_on_columns(small_sky_catalog, small_sky_order1_catalog, how):
    """Merging on the "id" column via the wrapper equals merging the `_ddf` frames directly."""
    # Reset the indexes so that "id" is not the index of either test catalog
    small_sky_catalog._ddf = small_sky_catalog._ddf.reset_index()
    small_sky_order1_catalog._ddf = small_sky_order1_catalog._ddf.reset_index()
    merge_args = dict(how=how, on="id", suffixes=("_left", "_right"))
    # Result produced by the catalog wrapper
    actual_ddf = small_sky_catalog.merge(small_sky_order1_catalog, **merge_args)
    assert isinstance(actual_ddf, dd.DataFrame)
    # Reference result from a direct merge of the underlying frames
    reference_ddf = small_sky_catalog._ddf.merge(small_sky_order1_catalog._ddf, **merge_args)
    pd.testing.assert_frame_equal(reference_ddf.compute(), actual_ddf.compute())


@pytest.mark.parametrize("how", ["left", "right", "inner", "outer"])
def test_catalog_merge_on_index_and_column(small_sky_catalog, small_sky_order1_catalog, how):
    """Merging a left index with a right column via the wrapper equals the direct merge."""
    # Left catalog: the object "id" becomes the index
    small_sky_catalog._ddf = small_sky_catalog._ddf.set_index("id")
    # Right catalog: reset the index so that "id" is not the index
    small_sky_order1_catalog._ddf = small_sky_order1_catalog._ddf.reset_index()
    merge_args = dict(
        how=how,
        left_index=True,
        right_on="id",
        suffixes=("_left", "_right"),
    )
    # Result produced by the catalog wrapper
    actual_ddf = small_sky_catalog.merge(small_sky_order1_catalog, **merge_args)
    assert isinstance(actual_ddf, dd.DataFrame)
    # Reference result from a direct merge of the underlying frames
    reference_ddf = small_sky_catalog._ddf.merge(small_sky_order1_catalog._ddf, **merge_args)
    pd.testing.assert_frame_equal(reference_ddf.compute(), actual_ddf.compute())


def test_catalog_merge_invalid_suffixes(small_sky_catalog, small_sky_order1_catalog):
    """A `suffixes` tuple with three entries makes the merge raise a ValueError."""
    three_suffixes = ("_left", "_middle", "_right")
    expected_message = "`suffixes` must be a tuple with two strings"
    with pytest.raises(ValueError, match=expected_message):
        small_sky_catalog.merge(
            small_sky_order1_catalog,
            how="inner",
            on="id",
            suffixes=three_suffixes,
        )


def test_catalog_merge_no_suffixes(small_sky_catalog, small_sky_order1_catalog):
    """Without explicit suffixes, shared columns are suffixed with the catalog names."""
    merged_ddf = small_sky_catalog.merge(small_sky_order1_catalog, how="inner", on="id")
    assert isinstance(merged_ddf, dd.DataFrame)
    # Columns, other than the join key, that appear under the same name in both catalogs
    left_columns = set(small_sky_catalog._ddf.columns.drop("id"))
    right_columns = set(small_sky_order1_catalog._ddf.columns.drop("id"))
    shared_columns = list(left_columns.intersection(right_columns))
    # Each shared column must show up once per catalog, suffixed with that catalog's name
    name_suffixes = [f"_{small_sky_catalog.name}", f"_{small_sky_order1_catalog.name}"]
    expected_names = [
        f"{column}{suffix}" for column in shared_columns for suffix in name_suffixes
    ]
    for expected_name in expected_names:
        assert expected_name in merged_ddf.columns