diff --git a/src/lsdb/catalog/catalog.py b/src/lsdb/catalog/catalog.py index aa3f339f..2aa86be5 100644 --- a/src/lsdb/catalog/catalog.py +++ b/src/lsdb/catalog/catalog.py @@ -247,6 +247,60 @@ def cone_search(self, ra: float, dec: float, radius: float): ddf_partition_map = {pixel: i for i, pixel in enumerate(pixels_in_cone)} return Catalog(cone_search_ddf, ddf_partition_map, filtered_hc_structure) + def merge( + self, + other: Catalog, + how: str = "inner", + on: str | List | None = None, + left_on: str | List | None = None, + right_on: str | List | None = None, + left_index: bool = False, + right_index: bool = False, + suffixes: Tuple[str, str] | None = None, + ) -> dd.DataFrame: + """Performs the merge of two catalog Dataframes + + More information about pandas merge is available + `here <https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.merge.html>`__. + + Args: + other (Catalog): The right catalog to merge with. + how (str): How to handle the merge of the two catalogs. + One of {'left', 'right', 'outer', 'inner'}, defaults to 'inner'. + on (str | List): Column or index names to join on. Defaults to the + intersection of columns in both Dataframes if on is None and not + merging on indexes. + left_on (str | List): Column to join on the left Dataframe. Lists are + supported if their length is one. + right_on (str | List): Column to join on the right Dataframe. Lists are + supported if their length is one. + left_index (bool): Use the index of the left Dataframe as the join key. + Defaults to False. + right_index (bool): Use the index of the right Dataframe as the join key. + Defaults to False. + suffixes (Tuple[str, str]): A pair of suffixes to be appended to the + end of each column name when they are joined. Defaults to using the + name of the catalog for the suffix. + + Returns: + A new Dask Dataframe containing the data points that result from the merge + of the two catalogs. 
+ """ + if suffixes is None: + suffixes = (f"_{self.name}", f"_{other.name}") + if len(suffixes) != 2: + raise ValueError("`suffixes` must be a tuple with two strings") + return self._ddf.merge( + other._ddf, + how=how, + on=on, + left_on=left_on, + right_on=right_on, + left_index=left_index, + right_index=right_index, + suffixes=suffixes, + ) + def to_hipscat( self, base_catalog_path: str, diff --git a/tests/lsdb/catalog/test_merge.py b/tests/lsdb/catalog/test_merge.py new file mode 100644 index 00000000..17e81799 --- /dev/null +++ b/tests/lsdb/catalog/test_merge.py @@ -0,0 +1,78 @@ +import dask.dataframe as dd +import pandas as pd +import pytest + + +@pytest.mark.parametrize("how", ["left", "right", "inner", "outer"]) +def test_catalog_merge_on_indices(small_sky_catalog, small_sky_order1_catalog, how): + kwargs = { + "how": how, + "left_index": True, + "right_index": True, + "suffixes": ("_left", "_right") + } + # Setting the object "id" for index on both catalogs + small_sky_catalog._ddf = small_sky_catalog._ddf.set_index("id") + small_sky_order1_catalog._ddf = small_sky_order1_catalog._ddf.set_index("id") + # The wrapper outputs the same result as the underlying pandas merge + merged_ddf = small_sky_catalog.merge(small_sky_order1_catalog, **kwargs) + assert isinstance(merged_ddf, dd.DataFrame) + expected_df = small_sky_catalog._ddf.merge(small_sky_order1_catalog._ddf, **kwargs) + pd.testing.assert_frame_equal(expected_df.compute(), merged_ddf.compute()) + + +@pytest.mark.parametrize("how", ["left", "right", "inner", "outer"]) +def test_catalog_merge_on_columns(small_sky_catalog, small_sky_order1_catalog, how): + kwargs = { + "how": how, + "on": "id", + "suffixes": ("_left", "_right") + } + # Make sure none of the test catalogs have "id" for index + small_sky_catalog._ddf = small_sky_catalog._ddf.reset_index() + small_sky_order1_catalog._ddf = small_sky_order1_catalog._ddf.reset_index() + # The wrapper outputs the same result as the underlying pandas merge + 
merged_ddf = small_sky_catalog.merge(small_sky_order1_catalog, **kwargs) + assert isinstance(merged_ddf, dd.DataFrame) + expected_df = small_sky_catalog._ddf.merge(small_sky_order1_catalog._ddf, **kwargs) + pd.testing.assert_frame_equal(expected_df.compute(), merged_ddf.compute()) + + +@pytest.mark.parametrize("how", ["left", "right", "inner", "outer"]) +def test_catalog_merge_on_index_and_column(small_sky_catalog, small_sky_order1_catalog, how): + kwargs = { + "how": how, + "left_index": True, + "right_on": "id", + "suffixes": ("_left", "_right") + } + # Setting the object "id" for index on the left catalog + small_sky_catalog._ddf = small_sky_catalog._ddf.set_index("id") + # Make sure the right catalog does not have "id" for index + small_sky_order1_catalog._ddf = small_sky_order1_catalog._ddf.reset_index() + # The wrapper outputs the same result as the underlying pandas merge + merged_ddf = small_sky_catalog.merge(small_sky_order1_catalog, **kwargs) + assert isinstance(merged_ddf, dd.DataFrame) + expected_df = small_sky_catalog._ddf.merge(small_sky_order1_catalog._ddf, **kwargs) + pd.testing.assert_frame_equal(expected_df.compute(), merged_ddf.compute()) + + +def test_catalog_merge_invalid_suffixes(small_sky_catalog, small_sky_order1_catalog): + with pytest.raises(ValueError, match="`suffixes` must be a tuple with two strings"): + small_sky_catalog.merge( + small_sky_order1_catalog, how="inner", on="id", suffixes=("_left", "_middle", "_right") + ) + + +def test_catalog_merge_no_suffixes(small_sky_catalog, small_sky_order1_catalog): + merged_ddf = small_sky_catalog.merge(small_sky_order1_catalog, how="inner", on="id") + assert isinstance(merged_ddf, dd.DataFrame) + # Get the columns with the same name in both catalogs + non_join_columns_left = small_sky_catalog._ddf.columns.drop("id") + non_join_columns_right = small_sky_order1_catalog._ddf.columns.drop("id") + intersected_cols = list(set(non_join_columns_left) & set(non_join_columns_right)) + # The suffixes of 
these columns in the dataframe include the catalog names + suffixes = [f"_{small_sky_catalog.name}", f"_{small_sky_order1_catalog.name}"] + for column in intersected_cols: + for suffix in suffixes: + assert f"{column}{suffix}" in merged_ddf.columns