Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add plotting to search filters #524

Merged
merged 3 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 68 additions & 1 deletion src/lsdb/core/search/abstract_search.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Tuple, Type

import astropy
import nested_pandas as npd
from astropy.coordinates import SkyCoord
from astropy.units import Quantity
from astropy.visualization.wcsaxes import WCSAxes
from astropy.visualization.wcsaxes.frame import BaseFrame
from hats.catalog import TableProperties
from hats.inspection.visualize_catalog import initialize_wcs_axes
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from mocpy import MOC

if TYPE_CHECKING:
Expand Down Expand Up @@ -32,3 +40,62 @@
@abstractmethod
def search_points(self, frame: npd.NestedFrame, metadata: TableProperties) -> npd.NestedFrame:
"""Determine the search results within a data frame"""

def plot(
self,
projection: str = "MOL",
title: str = "",
fov: Quantity | Tuple[Quantity, Quantity] | None = None,
center: SkyCoord | None = None,
wcs: astropy.wcs.WCS | None = None,
frame_class: Type[BaseFrame] | None = None,
ax: WCSAxes | None = None,
fig: Figure | None = None,
**kwargs,
):
"""Plot the search region

Args:
projection (str): The projection to use in the WCS. Available projections listed at
https://docs.astropy.org/en/stable/wcs/supported_projections.html
title (str): The title of the plot
fov (Quantity or Sequence[Quantity, Quantity] | None): The Field of View of the WCS. Must be an
astropy Quantity with an angular unit, or a tuple of quantities for different longitude and
latitude FOVs (Default covers the full sky)
center (SkyCoord | None): The center of the projection in the WCS (Default: SkyCoord(0, 0))
wcs (WCS | None): The WCS to specify the projection of the plot. If used, all other WCS parameters
are ignored and the parameters from the WCS object is used.
frame_class (Type[BaseFrame] | None): The class of the frame for the WCSAxes to be initialized
with. if the `ax` kwarg is used, this value is ignored (By Default uses EllipticalFrame for
full sky projection. If FOV is set, RectangularFrame is used)
ax (WCSAxes | None): The matplotlib axes to plot onto. If None, an axes will be created to be
used. If specified, the axes must be an astropy WCSAxes, and the `wcs` parameter must be set
with the WCS object used in the axes. (Default: None)
fig (Figure | None): The matplotlib figure to add the axes to. If None, one will be created,
unless ax is specified (Default: None)
**kwargs: Additional kwargs to pass to creating the matplotlib patch object for the search region

Returns:
Tuple[Figure, WCSAxes] - The figure and axes used for the plot
"""
fig, ax, wcs = initialize_wcs_axes(
projection=projection,
fov=fov,
center=center,
wcs=wcs,
frame_class=frame_class,
ax=ax,
fig=fig,
figsize=(9, 5),
)
self._perform_plot(ax, **kwargs)

plt.grid()
plt.ylabel("Dec")
plt.xlabel("RA")
plt.title(title)
return fig, ax

def _perform_plot(self, ax: WCSAxes, **kwargs):
"""Perform the plot of the search region on an initialized WCSAxes"""
raise NotImplementedError("Plotting has not been implemented for this search")

Check warning on line 101 in src/lsdb/core/search/abstract_search.py

View check run for this annotation

Codecov / codecov/patch

src/lsdb/core/search/abstract_search.py#L101

Added line #L101 was not covered by tests
11 changes: 11 additions & 0 deletions src/lsdb/core/search/cone_search.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import astropy.units as u
import nested_pandas as npd
from astropy.coordinates import SkyCoord
from astropy.visualization.wcsaxes import SphericalCircle, WCSAxes
from hats.catalog import TableProperties
from hats.pixel_math.validators import validate_declination_values, validate_radius
from mocpy import MOC
Expand Down Expand Up @@ -31,6 +33,15 @@ def search_points(self, frame: npd.NestedFrame, metadata: TableProperties) -> np
"""Determine the search results within a data frame"""
return cone_filter(frame, self.ra, self.dec, self.radius_arcsec, metadata)

def _perform_plot(self, ax: WCSAxes, **kwargs):
circle = SphericalCircle(
(self.ra * u.deg, self.dec * u.deg),
self.radius_arcsec * u.arcsec,
transform=ax.get_transform("icrs"),
**kwargs,
)
ax.add_patch(circle)


def cone_filter(data_frame: npd.NestedFrame, ra, dec, radius_arcsec, metadata: TableProperties):
"""Filters a dataframe to only include points within the specified cone
Expand Down
13 changes: 13 additions & 0 deletions tests/lsdb/catalog/test_cone_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
import pandas as pd
import pytest
from astropy.coordinates import SkyCoord
from astropy.visualization.wcsaxes import SphericalCircle
from hats.pixel_math.validators import ValidatorsErrors

from lsdb import ConeSearch


def test_cone_search_filters_correct_points(small_sky_order1_catalog, assert_divisions_are_correct):
ra = 0
Expand Down Expand Up @@ -124,3 +127,13 @@ def test_empty_cone_search_with_margin(small_sky_order1_source_with_margin):
cone = small_sky_order1_source_with_margin.cone_search(ra, dec, radius, fine=False)
assert len(cone._ddf_pixel_map) == 0
assert len(cone.margin._ddf_pixel_map) == 0


def test_cone_search_plot():
ra = 100
dec = 80
radius = 60
search = ConeSearch(ra, dec, radius)
_, ax = search.plot()
assert len(ax.patches) == 1
assert isinstance(ax.patches[0], SphericalCircle)
Loading