Skip to content

Commit

Permalink
Merge pull request #73 from hover2pi/v1.2.2
Browse files Browse the repository at this point in the history
v1.2.2
  • Loading branch information
hover2pi authored Aug 11, 2022
2 parents 544226c + dff379c commit 4f174c3
Show file tree
Hide file tree
Showing 6 changed files with 344 additions and 193 deletions.
2 changes: 1 addition & 1 deletion sedkit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
__version_commit__ = ''
_regex_git_hash = re.compile(r'.*\+g(\w+)')

__version__ = '1.2.0'
__version__ = '1.2.2'

# from pkg_resources import get_distribution, DistributionNotFound
# try:
Expand Down
369 changes: 213 additions & 156 deletions sedkit/notebooks/Trappist-1.ipynb

Large diffs are not rendered by default.

70 changes: 56 additions & 14 deletions sedkit/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from astroquery.vizier import Vizier
from bokeh.plotting import figure, show
from scipy.optimize import least_squares
from scipy.interpolate import CubicSpline, UnivariateSpline
from bokeh.models.glyphs import Patch
from bokeh.models import ColumnDataSource
import numpy as np
Expand Down Expand Up @@ -116,7 +117,7 @@ def add_relation(self, rel_name, order, xrange=None, xunit=None, yunit=None, rej

# Set x range for fit
if xrange is not None:
idx = np.where(np.logical_and(rel['x'] > xrange[0], rel['x'] < xrange[1]))
idx = np.where(np.logical_and(rel['x'] > np.nanmin(xrange), rel['x'] < np.nanmax(xrange)))
rel['x'] = rel['x'][idx]
rel['y'] = rel['y'][idx]

Expand Down Expand Up @@ -200,7 +201,7 @@ def errFit(hess_inv, resVariance):
if plot:
show(self.plot(rel_name))

def evaluate(self, rel_name, x_val, plot=False):
def evaluate(self, rel_name, x_val, xunits=None, yunits=None, fit_local=False, plot=False):
"""
Evaluate the given relation at the given xval
Expand All @@ -210,6 +211,10 @@ def evaluate(self, rel_name, x_val, plot=False):
The relation name, i.e. 'yparam(xparam)'
x_val: float, int
The xvalue to evaluate
xunits: astropy.units.quantity.Quantity
The output x units
yunits: astropy.units.quantity.Quantity
The output y units
Returns
-------
Expand All @@ -230,33 +235,51 @@ def evaluate(self, rel_name, x_val, plot=False):
try:

# Get the relation
rel = self.relations[rel_name]
full_rel = self.relations[rel_name]
out_xunits = full_rel['xunit'].to(xunits) * xunits if xunits is not None else full_rel['xunit'] or 1
out_yunits = full_rel['yunit'].to(yunits) * yunits if yunits is not None else full_rel['yunit'] or 1

# Use local points for relation
if isinstance(fit_local, int):

# Trim relation data to nearby points, refit with low order polynomial, and evaluate
idx = np.argmin(np.abs(full_rel['x'] - (x_val[0] if isinstance(x_val, (list, tuple)) else x_val)))
x_min, x_max = full_rel['x'][max(0, idx - fit_local)], full_rel['x'][min(idx + fit_local, len(full_rel['x'])-1)]
self.add_relation('mass(Lbol)', 2, xrange=[x_min, x_max], yunit=full_rel['yunit'], reject_outliers=True, plot=False)
rel = self.relations[rel_name]

# ... or use the full relation
else:
rel = full_rel

# Evaluate the polynomial
if isinstance(x_val, (list, tuple)):

# With uncertainties
x = Unum(*x_val)
y = x.polyval(rel['coeffs'])
x_val = x.nominal
y_val = y.nominal * rel['yunit']
y_upper = y.upper * rel['yunit']
y_lower = y.lower * rel['yunit']
x_val = x.nominal * out_xunits
y_val = y.nominal * out_yunits
y_upper = y.upper * out_yunits
y_lower = y.lower * out_yunits

else:

# Without uncertainties
x_val = x_val.value if hasattr(x_val, 'unit') else x_val
y_val = np.polyval(rel['coeffs'], x_val) * rel['yunit']
x_val = x_val * out_xunits
y_val = np.polyval(rel['coeffs'], x_val) * out_yunits
y_lower = y_upper = None

if plot:
plt = self.plot(rel_name)
plt.circle([x_val], [y_val.value if hasattr(y_val, 'unit') else y_val], color='red', size=10, legend='{}({})'.format(rel['yparam'], x_val))
plt = self.plot(rel_name, xunits=xunits, yunits=yunits)
plt.circle([x_val.value if hasattr(x_val, 'unit') else x_val], [y_val.value if hasattr(y_val, 'unit') else y_val], color='red', size=10, legend='{}({})'.format(rel['yparam'], x_val))
if y_upper:
plt.line([x_val, x_val], [y_val - y_lower, y_val + y_upper], color='red')
show(plt)

# Restore full relation
self.relations[rel_name] = full_rel

if y_upper:
return y_val, y_upper, y_lower, self.ref
else:
Expand Down Expand Up @@ -292,9 +315,18 @@ def _parse_rel_name(self, rel_name):
"""
return rel_name.replace(')', '').split('(')[::-1]

def plot(self, rel_name, **kwargs):
def plot(self, rel_name, xunits=None, yunits=None, **kwargs):
"""
Plot the data for the given parameters
Parameters
----------
rel_name: str
The name of the relation
xunits: astropy.units.quantity.Quantity
The units to display
yunits: astropy.units.quantity.Quantity
The units to display
"""
# Get params
xparam, yparam = self._parse_rel_name(rel_name)
Expand All @@ -305,15 +337,18 @@ def plot(self, rel_name, **kwargs):
# Make the figure
fig = figure(x_axis_label=xparam, y_axis_label=yparam)
x, y, _ = self.validate_data(self.data[xparam], self.data[yparam])
fig.circle(x, y, legend='Data', **kwargs)

xu = 1
yu = 1
if rel_name in self.relations:

# Get the relation
rel = self.relations[rel_name]

# Plot polynomial values
fig.line(rel['x_fit'], rel['y_fit'], color='black', legend='Fit')
xu = rel['xunit'].to(xunits) if xunits is not None else rel['xunit'] or 1
yu = rel['yunit'].to(yunits) if yunits is not None else rel['yunit'] or 1
fig.line(rel['x_fit'] * xu, rel['y_fit'] * yu, color='black', legend='Fit')

# # Plot relation error
# xpat = np.hstack((rel['x_fit'], rel['x_fit'][::-1]))
Expand All @@ -322,6 +357,13 @@ def plot(self, rel_name, **kwargs):
# glyph = Patch(x='xaxis', y='yaxis', fill_color='black', line_color=None, fill_alpha=0.1)
# fig.add_glyph(err_source, glyph)

# Update axis labels
fig.xaxis.axis_label = '{}{}'.format(xparam, '[{}]'.format(xunits or rel['xunit']))
fig.yaxis.axis_label = '{}{}'.format(yparam, '[{}]'.format(yunits or rel['yunit']))

# Draw points
fig.circle(x * xu, y * yu, legend='Data', **kwargs)

return fig

def validate_data(self, X, Y, Y_unc=None):
Expand Down
28 changes: 18 additions & 10 deletions sedkit/sed.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,14 @@ def add_spectrum(self, spectrum, **kwargs):
else:
raise ValueError('Input spectrum must be [W,F] or [W,F,E].')

# also ok if specutils.spectra.spectrum1d.Spectrum1D
elif hasattr(spectrum, 'spectral_axis'):

if spectrum.uncertainty is None:
spec = sp.Spectrum(spectrum.spectral_axis, spectrum.flux, ** kwargs)
else:
spec = sp.Spectrum(spectrum.spectral_axis, spectrum.flux, unc=spectrum.uncertainty.array * spectrum.uncertainty.unit, **kwargs)

# or it's no good
else:
raise TypeError('Must enter [W,F], [W,F,E], or a Spectrum object')
Expand Down Expand Up @@ -1106,13 +1114,13 @@ def export(self, parentdir='.', dirname=None, zipped=False):
res_table.write(resultspath, format='ipac')

# The SED plot
if self.fig is not None:
try:
pltopath = os.path.join(dirpath, '{}_plot.png'.format(name))
export_png(self.fig, filename=pltopath)
except:
# Bokeh dropped support for PhantomJS so image saving is now browser dependent and fails occasionally
self.message("Could not export SED for {}".format(self.name))
# if self.fig is not None:
# try:
# pltopath = os.path.join(dirpath, '{}_plot.png'.format(name))
# export_png(self.fig, filename=pltopath)
# except Exception:
# # Bokeh dropped support for PhantomJS so image saving is now browser dependent and fails occasionally
# self.message("Could not export SED for {}".format(self.name))

# zip if desired
if zipped:
Expand Down Expand Up @@ -1835,19 +1843,19 @@ def infer_mass(self, mass_units=q.Msun, isochrone=False, plot=False):
if self.mass is None and self.Lbol_sun is not None:

# Infer from Dwarf Sequence
self.mass = self.mainsequence.evaluate('mass(Lbol)', self.Lbol_sun, plot=plot)
self.mass = self.mainsequence.evaluate('mass(Lbol)', self.Lbol_sun, fit_local=5, yunits=mass_units, plot=plot)

# Try mass(M_J) relation
elif self.mass is None and self.get_mag('2MASS.J', 'abs') is not None:

# Infer from Dwarf Sequence
self.mass = self.mainsequence.evaluate('mass(M_J)', self.get_mag('2MASS.J'), plot=plot)
self.mass = self.mainsequence.evaluate('mass(M_J)', self.get_mag('2MASS.J'), fit_local=5, yunits=mass_units, plot=plot)

# Try mass(M_Ks) relation
elif self.mass is None and self.get_mag('2MASS.Ks', 'abs') is not None:

# Infer from Dwarf Sequence
self.mass = self.mainsequence.evaluate('mass(M_J)', self.get_mag('2MASS.Ks'), plot=plot)
self.mass = self.mainsequence.evaluate('mass(M_J)', self.get_mag('2MASS.Ks'), fit_local=5, yunits=mass_units, plot=plot)

# No dice
else:
Expand Down
56 changes: 44 additions & 12 deletions sedkit/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from pandas import DataFrame
from svo_filters import Filter


from . import mcmc as mc
from . import utilities as u

Expand Down Expand Up @@ -95,7 +96,7 @@ def __init__(self, wave, flux, unc=None, snr=None, trim=None, const=1, phot=Fals
raise TypeError("Wavelength array must be in astropy.units.quantity.Quantity length units, e.g. 'um'")

# Check flux units are flux density
if not u.equivalent(flux, u.FLAM):
if not u.equivalent(flux, (u.FLAM, q.Jy)):
raise TypeError("Flux array must be in astropy.units.quantity.Quantity flux density units, e.g. 'erg/s/cm2/A'")

# Generate uncertainty array
Expand All @@ -104,7 +105,7 @@ def __init__(self, wave, flux, unc=None, snr=None, trim=None, const=1, phot=Fals

# Make sure the uncertainty array is in the correct units
if unc is not None:
if not u.equivalent(unc, u.FLAM):
if not u.equivalent(unc, (u.FLAM, q.Jy)):
raise TypeError("Uncertainty array must be in astropy.units.quantity.Quantity flux density units, e.g. 'erg/s/cm2/A'")

# Replace negatives, zeros, and infs with nans
Expand Down Expand Up @@ -548,25 +549,44 @@ def flux_units(self, flux_units):
The astropy units of the SED wavelength
"""
# Check the units
if not u.equivalent(flux_units, u.FLAM):
raise TypeError("flux_units must be in flux density units, e.g. 'erg/s/cm2/A'")
if not u.equivalent(flux_units, (u.FLAM, q.Jy)):
raise TypeError("flux_units must be in flux density units, e.g. 'erg/s/cm2/A' or 'Jy'")

# Update the flux and unc arrays
self._flux = self._flux * self.flux_units.to(flux_units)
if self.unc is not None:
self._unc = self._unc * self.flux_units.to(flux_units)
# Check if units are Fnu of Flam
if u.equivalent(flux_units, q.Jy) and u.equivalent(self.flux_units, u.FLAM):

# Convert native FLAM units to Jy
self._flux = self._flux * self.wave ** 2 * 3.34e-19
if self.unc is not None:
self._unc = self._unc * self.wave ** 2 * 3.34e-19

elif u.equivalent(self.flux_units, q.Jy) and u.equivalent(flux_units, u.FLAM):

# Convert native Jy units to FLAM
self._flux = self._flux * 3e18 / (self.wave ** 2)
if self.unc is not None:
self._unc = self._unc * 3e18 / (self.wave ** 2)

else:

# Update the flux and unc arrays
self._flux = self._flux * self.flux_units.to(flux_units)
if self.unc is not None:
self._unc = self._unc * self.flux_units.to(flux_units)

# Set the flux_units
self._flux_units = flux_units
self._set_units()

def integrate(self, units=q.erg / q.s / q.cm**2):
def integrate(self, units=q.erg / q.s / q.cm**2, n_samples=10):
"""Calculate the area under the spectrum
Parameters
----------
units: astropy.units.quantity.Quantity
The target units for the integral
n_samples: int
The number of samples to take for uncertainty calculations
Returns
-------
Expand All @@ -585,11 +605,23 @@ def integrate(self, units=q.erg / q.s / q.cm**2):
val = (np.trapz(spec[1], x=spec[0]) * m).to(units)

if self.unc is None:
unc = None
vunc = None
else:
unc = np.sqrt(np.nansum((spec[2] * np.gradient(spec[0]) * m)**2)).to(units)

return val, unc
# Sum approximation to integral
# vunc = np.sqrt(np.nansum((spec[2] * np.gradient(spec[0]) * m)**2)).to(units)

# Bootstrap errors
uvals = []
for n in range(n_samples):
usamp = np.random.normal(spec[1], spec[2])
err = (np.trapz(usamp, x=spec[0]) * m).to(units)
uvals.append(abs(err.value - val.value))

# Get 1-sigma of distribution
vunc = np.max(abs(np.asarray(uvals))) * units

return val, vunc

@copy_raw
def interpolate(self, wave):
Expand Down
12 changes: 12 additions & 0 deletions sedkit/tests/test_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,18 @@ def test_units(self):
self.assertEqual(s.spectrum[1].unit, fu)
self.assertEqual(s.spectrum[2].unit, fu)

def test_unit_conversion(self):
"""Test Fnu support works"""
s = copy.copy(self.spec)

# Change the flux units
fu = q.Jy
s.flux_units = fu

# Make sure the units are being updated
self.assertEqual(s.spectrum[1].unit, fu)
self.assertEqual(s.spectrum[2].unit, fu)

def test_model_fit(self):
"""Test that a model grid can be fit"""
# Empty fit results
Expand Down

0 comments on commit 4f174c3

Please sign in to comment.