diff --git a/mosviz/viewers/mos_viewer.py b/mosviz/viewers/mos_viewer.py index ebf0956..2ecdda3 100644 --- a/mosviz/viewers/mos_viewer.py +++ b/mosviz/viewers/mos_viewer.py @@ -19,6 +19,9 @@ from specutils.core.generic import Spectrum1DRef from astropy.nddata.nduncertainty import StdDevUncertainty +from astropy.units import Unit +from astropy.wcs import WCS + try: from specviz.external.glue.data_viewer import SpecVizViewer @@ -47,7 +50,7 @@ def load_ui(self): """ Setup the MOSView viewer interface. """ - self.central_widget = QWidget() + self.central_widget = QWidget(self) path = os.path.abspath( os.path.join(os.path.dirname(__file__), @@ -246,7 +249,6 @@ def _unpack_selection(self, data): # Clear the table self.catalog = Table() - self.catalog.meta = data.meta col_names = data.components for att in col_names: @@ -273,22 +275,31 @@ def _unpack_selection(self, data): self.catalog[str(att)] = comp_data - # Update gui elements - self._update_navigation() - self._set_navigation(0) + if len(self.catalog) > 0: + # Load the first source in the catalog + self.load_selection(self.catalog[0]) - # Load the first source in the catalog - self.load_selection(self.catalog[0]) + # Update gui elements + self._update_navigation() + self._set_navigation(0) def _update_navigation(self): """ Updates the :class:`qtpy.QtWidgets.QComboBox` widget with the appropriate source `id`s from the MOS catalog. """ + if self.toolbar is None: + return + self.toolbar.source_select.clear() - self.toolbar.source_select.addItems(self.catalog['id'][:]) + + if len(self.catalog) > 0 and 'id' in self.catalog.colnames: + self.toolbar.source_select.addItems(self.catalog['id'][:]) def _set_navigation(self, index): + if len(self.catalog) < index: + return + if 0 <= index < self.toolbar.source_select.count(): self.toolbar.source_select.setCurrentIndex(index) self.load_selection(self.catalog[index]) @@ -304,7 +315,7 @@ def _set_navigation(self, index): self.toolbar.cycle_next_action.setDisabled(False) def _get_loaders(self): - loaders = self.catalog.meta.get("loaders", []) + loaders = self.catalog.meta.get("loaders", {}) # if loader is specified if "spec1d" in loaders: spectrum1d_loader = next((x.function for x in config.data_factory.members if x.label == self.catalog.meta["loaders"]["spec1d"]), @@ -326,9 +337,8 @@ def _get_loaders(self): return spectrum1d_loader, spectrum2d_loader, cutout_loader def _open_in_specviz(self): - if self._specviz_instance is None: - self._specviz_instance = self.session.application.new_data_viewer( - SpecVizViewer) + _specviz_instance = self.session.application.new_data_viewer( + SpecVizViewer) spec1d_data = self._loaded_data['spec1d'] @@ -336,9 +346,10 @@ def _open_in_specviz(self): data=spec1d_data.get_component(spec1d_data.id['Flux']).data, dispersion=spec1d_data.get_component(spec1d_data.id['Wavelength']).data, uncertainty=StdDevUncertainty(spec1d_data.get_component(spec1d_data.id['Uncertainty']).data), - unit="", name=self.current_row['id']) + unit="", name=self.current_row['id'], + wcs=WCS(spec1d_data.header)) - self._specviz_instance.open_data(spec_data) + _specviz_instance.open_data(spec_data) def load_selection(self, row): """ @@ -398,8 +409,22 @@ def render_data(self, row, spec1d_data=None, spec2d_data=None, y=spec1d_data.get_component(spec1d_data.id['Flux']).data, yerr=spec1d_data.get_component(spec1d_data.id['Uncertainty']).data) - self.spectrum1d_widget.axes.set_xlabel("Wavelength") - self.spectrum1d_widget.axes.set_ylabel("Flux") + # Try to retrieve the wcs information + try: + flux_unit = spec1d_data.header.get('BUNIT', 'Jy').lower() + flux_unit = flux_unit.replace('counts', 'count') + flux_unit = Unit(flux_unit) + except ValueError: + flux_unit = Unit("Jy") + + try: + disp_unit = spec1d_data.header.get('CUNIT1', 'Angstrom').lower() + disp_unit = Unit(disp_unit) + except ValueError: + disp_unit = Unit("Angstrom") + + self.spectrum1d_widget.axes.set_xlabel("Wavelength [{}]".format(disp_unit)) + self.spectrum1d_widget.axes.set_ylabel("Flux [{}]".format(flux_unit)) if spec2d_data is not None: wcs = spec2d_data.coords.wcs @@ -407,7 +432,8 @@ def render_data(self, row, spec1d_data=None, spec2d_data=None, self.spectrum2d_widget.set_image( image=spec2d_data.get_component( spec2d_data.id['Flux']).data, - wcs=wcs, interpolation='none', aspect='auto') + wcs=wcs, interpolation='none', aspect='auto', + header=spec2d_data.header) self.spectrum2d_widget.axes.set_xlabel("Wavelength") self.spectrum2d_widget.axes.set_ylabel("Spatial Y") diff --git a/mosviz/widgets/plots.py b/mosviz/widgets/plots.py index 0a405e8..f2e0839 100644 --- a/mosviz/widgets/plots.py +++ b/mosviz/widgets/plots.py @@ -9,8 +9,11 @@ from glue.viewers.common.qt.toolbar import BasicToolbar from glue.viewers.common.qt.mpl_toolbar import MatplotlibViewerToolbar +import numpy as np +from astropy.wcs import WCS, WCSSUB_SPECTRAL + from matplotlib import rcParams -rcParams.update({'figure.autolayout': True}) +# rcParams.update({'figure.autolayout': True}) __all__ = ['Line1DWidget', 'ShareableAxesImageWidget', 'DrawableImageWidget'] @@ -63,8 +66,37 @@ def set_data(self, x, y, yerr=None): def _redraw(self): self.central_widget.canvas.draw() + def set_status(self): + pass + + +class MOSImageWidget(StandaloneImageWidget): + def __init__(self, *args, **kwargs): + super(MOSImageWidget, self).__init__(*args, **kwargs) + + def set_image(self, image=None, wcs=None, header=None, **kwargs): + super(MOSImageWidget, self).set_image(image, wcs, **kwargs) + + if header is not None: + hwcs = WCS(header) + + # Try to reference the spectral axis + hwcs_spec = hwcs.sub([WCSSUB_SPECTRAL]) + + # Check to see if it actually is a real coordinate description + if hwcs_spec.naxis == 0: + # It's not real, so attempt to get the spectral axis by + # specifying axis by integer + hwcs_spec = hwcs.sub([hwcs.naxis]) + + # Construct the dispersion array + dispersion = hwcs_spec.all_pix2world( + np.arange(image.shape[0]), 0)[0] + + self.axes.set_xticklabels(["{}".format(x) for x in dispersion]) + -class ShareableAxesImageWidget(StandaloneImageWidget): +class ShareableAxesImageWidget(MOSImageWidget): def __init__(self, *args, **kwargs): super(ShareableAxesImageWidget, self).__init__(*args, **kwargs) @@ -96,11 +128,11 @@ def set_status(self): pass -class DrawableImageWidget(StandaloneImageWidget): +class DrawableImageWidget(MOSImageWidget): def __init__(self, *args, **kwargs): super(DrawableImageWidget, self).__init__(*args, **kwargs) self._slit_patch = None def draw_shapes(self, x=0, y=0, width=100, length=100): self._slit_patch = plt.Rectangle((x-length, y-width), width, length, fc='r') - self.axes.add_patch(self._slit_patch) + # self.axes.add_patch(self._slit_patch)