diff --git a/sedkit/isochrone.py b/sedkit/isochrone.py index 576c725..67241b2 100644 --- a/sedkit/isochrone.py +++ b/sedkit/isochrone.py @@ -20,6 +20,7 @@ from . import utilities as u from . import uncertainties as un + # A dictionary of all supported moving group ages from Bell et al. (2015) NYMG_AGES = {'AB Dor': (149 * q.Myr, 51 * q.Myr, '2015MNRAS.454..593B'), 'beta Pic': (24 * q.Myr, 3 * q.Myr, '2015MNRAS.454..593B'), @@ -35,7 +36,7 @@ EVO_MODELS = [os.path.basename(m).replace('.txt', '') for m in glob.glob(resource_filename('sedkit', 'data/models/evolutionary/*'))] # Fails RTD build for some reason except: - EVO_MODELS = ['COND03', 'dmestar_solar', 'DUSTY00', 'f2_solar_age', 'hybrid_solar_age', 'nc+03_age', 'nc-03_age', 'nc_solar_age', 'parsec12_solar', 'ATMO_NEQ_strong'] + EVO_MODELS = ['COND03', 'dmestar_solar', 'DUSTY00', 'f2_solar_age', 'hybrid_solar_age', 'nc+03_age','nc+0.0_age', 'nc-03_age', 'nc_solar_age','nc+0.5_age','nc-0.5_age', 'parsec12_solar','ATMO_NEQ_strong','ATMO_NEQ_strong_MIRI','chaubrier_2022.txt'] class Isochrone: """A class to handle model isochrones""" @@ -154,10 +155,8 @@ def evaluate(self, xval, age, xparam, yparam, plot=False): age = age[0], age[1].to(age[0].unit) # Check if the xval has an uncertainty - no_xerr = False if not isinstance(xval, (tuple, list)): xval = (xval, xval * 0) - no_xerr = True # Test the age range is inbounds if age[0] < self.ages.min() or age[0] > self.ages.max(): @@ -165,50 +164,65 @@ def evaluate(self, xval, age, xparam, yparam, plot=False): self.message('{}: age must be between {} and {} to infer {} from {} isochrones.'.format(*args)) return None - # x-errors required for MC routine - if no_xerr: - average = self.interpolate(xval[0], age[0], xparam, yparam) - error = average * 0 - # Monte Carlo Approach - else: - mu, sigma = xval[0], xval[1] # mean and standard deviation for values on the x-axis - mu_a, sigma_a = age[0].value, age[1].value # mean and standard deviation for the age range provided - xsample = np.random.normal(mu, sigma, 1000) - ysample = np.random.normal(mu_a, sigma_a, 1000) - values_list = [] - nan_counter = 0 - - for x, y in zip(xsample, ysample): - result = self.interpolate(x, y * age[0].unit, xparam, yparam) - if result is not None: - values_list.append(result.value if hasattr(result, 'unit') else result) - elif result is None: - print("Interpolate resulted in NaN") - nan_counter = nan_counter + 1 - - # Get final values - unit = self.data[yparam].unit or 1 - average = np.mean(np.array(values_list) * unit) - std = np.std(values_list) - error = std * unit + mu, sigma = xval[0], xval[1] # mean and standard deviation for values on the x-axis + mu_a, sigma_a = age[0].value, age[1].value # mean and standard deviation for the age range provided + + xsample = np.random.normal(mu, sigma, 10000) + ysample = np.random.uniform(mu_a-sigma_a, mu_a+sigma_a, 10000) + values_list = [] + nan_counter = 0 + + for x, y in zip(xsample, ysample): + result = self.interpolate(x, y * age[0].unit, xparam, yparam) + if result is not None: + values_list.append(result.value if hasattr(result, 'unit') else result) + elif result is None: + print("Interpolate resulted in NaN") + nan_counter = nan_counter + 1 + + # Get final values + unit = self.data[yparam].unit or 1 + + #Error Propagation + dist_hist, hist_edges = np.histogram(values_list) + + #Normalize Histogram + dist_hist = dist_hist/ np.sum(dist_hist) + # cumulative PDF + cum_PDF_hist_x = hist_edges[1:] # right edges of the bins + cum_PDF_hist = np.zeros(hist_edges.size - 1) + for i in range(hist_edges.size - 1): + cum_PDF_hist[i] = np.sum(dist_hist[:i + 1]) + + # interquartile range of the PDF + error_lower_per = 0.16 + error_upper_per = 0.84 + + Q1_hist = np.interp(error_lower_per, cum_PDF_hist, cum_PDF_hist_x) # lower error + Q2_hist = np.interp(0.50, cum_PDF_hist, cum_PDF_hist_x ) # median + Q3_hist = np.interp(error_upper_per, cum_PDF_hist, cum_PDF_hist_x) # upper error + + + lower_err = (Q2_hist - Q1_hist) * unit + upper_err = (Q3_hist - Q2_hist) * unit + average = Q2_hist * unit + # Plot the figure and evaluated point if plot: - val = average.value if hasattr(average, 'unit') else average - err = error.value if hasattr(error, 'unit') else error fig = self.plot(xparam, yparam) - legend = '{} = {:.3f} ({:.3f})'.format(yparam, val, err) - fig.circle(xval[0], val, color='red', legend_label=legend) - u.errorbars(fig, [xval[0]], [val], xerr=[xval[1]*2], yerr=[err], color='red') - + legend = '{} = {:.3f} ({:-.3f} {:+.3f})'.format(yparam, average.value, lower_err.value,upper_err.value) + fig.circle(xval[0], average.value, color='purple', legend_label=legend) + u.errorbars(fig, [xval[0]], [average.value], xerr=[xval[1]*2], yerr=[lower_err.value*-1,upper_err.value], color='purple') show(fig) - # Balk at nans + # # Balk at nans if np.isnan(average): + print("I got a nan from {} for some reason.".format(self.name)) raise ValueError("I got a nan from {} for some reason.".format(self.name)) + return average, lower_err, upper_err, self.name - return average, error, self.name def interpolate(self, xval, age, xparam, yparam): """Interpolate a value between two isochrones @@ -260,6 +274,7 @@ def interpolate(self, xval, age, xparam, yparam): result = np.average([lower_val, upper_val], weights=weights) unit = self.data[yparam].unit or 1 + # return result return result * unit def message(self, msg, pre='[sedkit]'): @@ -315,6 +330,7 @@ def plot(self, xparam, yparam, draw=False, labels=True, **kwargs): color_mapper = LinearColorMapper(palette='Viridis256', low=self.ages.min().value, high=self.ages.max().value) + color_bar = ColorBar(color_mapper=color_mapper, ticker=BasicTicker(), label_standoff=5, border_line_color=None, title='Age [{}]'.format(self.age_units), diff --git a/sedkit/sed.py b/sedkit/sed.py index ca65bcf..b6c33d3 100644 --- a/sedkit/sed.py +++ b/sedkit/sed.py @@ -8,13 +8,13 @@ Author: Joe Filippazzo, jfilippazzo@stsci.edu """ - from copy import copy import os import shutil import time import warnings +import itertools import astropy.table as at import astropy.units as q import astropy.io.ascii as ii @@ -513,7 +513,7 @@ def age(self, age): age: sequence The age and uncertainty in distance units """ - self._validate_and_set_param('age', age, q.Gyr, True, vmin=0 * q.Myr, vmax=13.8 * q.Gyr) + self._validate_and_set_param('age', age, q.Gyr, True, vmin=0 * q.Gyr, vmax=13.8 * q.Gyr) def _calculate_sed(self): """ @@ -596,6 +596,7 @@ def calculate_Lbol(self): Lbol_unc = Lbol * np.sqrt((self.fbol[1] / self.fbol[0]).value**2 + (2 * self.distance[1] / self.distance[0]).value**2) Lbol_sun_unc = round(abs(Lbol_unc / (Lbol * np.log(10))).value, 3) + # Update the attributes self.Lbol = Lbol, Lbol_unc, 'This Work' self.Lbol_sun = Lbol_sun, Lbol_sun_unc, 'This Work' @@ -711,9 +712,10 @@ def calculate_Teff(self): Teff_unc = None else: Teff_unc = (Teff * np.sqrt((self.Lbol[1] / self.Lbol[0]).value**2 + (2 * self.radius[1] / self.radius[0]).value**2) / 4.).astype(int) + Teff_unc_u = (Teff * np.sqrt((self.Lbol[1] / self.Lbol[0]).value**2 + (2 * self.radius[2] / self.radius[0]).value**2) / 4.).astype(int) # Update the attribute - self.Teff = Teff, Teff_unc, 'This Work' + self.Teff = Teff, Teff_unc,Teff_unc_u, 'This Work' def _calibrate_photometry(self, name='photometry'): """ @@ -750,19 +752,19 @@ def _calibrate_photometry(self, name='photometry'): # Calculate absolute mags and fluxes if self.distance is not None: - - for n, row in enumerate(table): - - # Calculate abs_mags - M, M_unc = u.flux_calibrate(row['app_magnitude'], self.distance[0], row['app_magnitude_unc'], self.distance[1]) - table['abs_magnitude'][n] = M - table['abs_magnitude_unc'][n] = M_unc - - # Calculate abs_flux values - abs_flux, abs_flux_unc = u.mag2flux(row['bandpass'], row['abs_magnitude'], sig_m=row['abs_magnitude_unc'], units=self.flux_units) - table['abs_flux'][n] = abs_flux - table['abs_flux_unc'][n] = abs_flux_unc - + + # Calculate abs_mags + for n,row in enumerate(table): + M, M_unc = u.flux_calibrate(row['app_magnitude'], self.distance[0], row['app_magnitude_unc'], self.distance[1]) + table['abs_magnitude'][n] = M + table['abs_magnitude_unc'][n] = M_unc + + # Calculate abs_flux values + for n, row in enumerate(table): + abs_flux, abs_flux_unc = u.mag2flux(row['bandpass'], row['abs_magnitude'], sig_m=row['abs_magnitude_unc'], units=self.flux_units) + table['abs_flux'][n] = abs_flux + table['abs_flux_unc'][n] = abs_flux_unc + # Store the table setattr(self, '_{}'.format(name), table) @@ -1103,13 +1105,14 @@ def export(self, parentdir='.', dirname=None, zipped=False): # Absolute spectral SED if self.abs_spec_SED is not None: specpath = os.path.join(dirpath, '{}_absolute_SED.txt'.format(name)) - header = '{} absolute spectrum (erg/s/cm2/A) as a function of wavelength (um)'.format(name) + header = '{} absolute spectrum (erg/s/cm2/A) at 10 pc as a function of wavelength (um)'.format(name) self.abs_spec_SED.export(specpath, header=header) # All photometry if self.photometry is not None: photpath = os.path.join(dirpath, '{}_photometry.txt'.format(name)) phot_table = copy(self.photometry) + # print(phot_table) for colname in phot_table.colnames: phot_table.rename_column(colname, colname.replace('/', '_')) phot_table.write(photpath, format='ipac') @@ -1820,8 +1823,7 @@ def infer_logg(self, plot=False): self.message("Could not calculate surface gravity.") # Store the value - self.logg = [logg[0].round(2), logg[1].round(2), logg[2]] if logg is not None else logg - + self.logg = [logg[0].round(2), logg[1].round(2),logg[2].round(2),logg[3]] if logg is not None else logg # No dice else: self.message('Could not calculate logg without Lbol and age') @@ -1856,7 +1858,7 @@ def infer_mass(self, mass_units=q.Msun, plot=False): mass = self.evo_model.evaluate(self.Lbol_sun, self.age, 'Lbol', 'mass', plot=plot) # Store the value - self.mass = [mass[0].round(3), mass[1].round(3), mass[2]] if mass is not None else mass + self.mass = [mass[0].round(0), mass[1].round(0), mass[2].round(0), mass[3]] if mass is not None else mass else: @@ -1883,7 +1885,7 @@ def infer_mass(self, mass_units=q.Msun, plot=False): self.message('Could not calculate mass without Lbol, M_2MASS.J, or M_2MASS.Ks') - def infer_radius(self, radius_units=q.Rsun, infer_from=None, plot=False): + def infer_radius(self, radius_units=q.Rsun, infer_from='Lbol', plot=False): """ Estimate the radius from model isochrones given an age and Lbol @@ -1940,7 +1942,7 @@ def infer_radius(self, radius_units=q.Rsun, infer_from=None, plot=False): radius = self.evo_model.evaluate(self.Lbol_sun, self.age, 'Lbol', 'radius', plot=plot) # Store the value - self.radius = [radius[0].round(3), radius[1].round(3), radius[2]] if radius is not None else radius + self.radius = [radius[0].round(3), radius[1].round(3), radius[2].round(3), radius[3]] if radius is not None else radius # Try radius(spt) relation elif infer_from == 'spt': @@ -2043,7 +2045,7 @@ def infer_Teff(self, teff_units=q.K, plot=False): teff = self.evo_model.evaluate(self.Lbol_sun, self.age, 'Lbol', 'teff', plot=plot) # Store the value - self.Teff = [teff[0].round(0), teff[1].round(0), teff[2]] if teff is not None else teff + self.Teff = [teff[0].round(0), teff[1].round(0), teff[2].round(0), teff[3]] if teff is not None else teff else: @@ -2497,9 +2499,9 @@ def photometry(self): self._photometry.sort('eff') return self._photometry - def plot(self, app=True, photometry=True, spectra=True, integral=True, synthetic_photometry=False, + def plot(self, app=False, photometry=True, spectra=True, integral=True, synthetic_photometry=False, best_fit=True, normalize=None, scale=['log', 'log'], output=False, fig=None, - color='#1f77b4', one_color=False, label=None, **kwargs): + color='#3a243b', one_color=False, label=None, **kwargs): """ Plot the SED @@ -2543,7 +2545,10 @@ def plot(self, app=True, photometry=True, spectra=True, integral=True, synthetic self.make_sed() # Distinguish between apparent and absolute magnitude - pre = 'app_' if app else 'abs_' + if self.distance is None: + pre = 'app_' + else: + pre = 'app_' if app else 'abs_' # Calculate reasonable axis limits full_SED = getattr(self, pre + 'SED') @@ -2586,7 +2591,8 @@ def plot(self, app=True, photometry=True, spectra=True, integral=True, synthetic # ...or make a new plot else: - TOOLS = ['pan', 'reset', 'box_zoom', 'wheel_zoom', 'save'] + # TOOLS = ['pan', 'reset', 'box_zoom', 'wheel_zoom', 'save'] + TOOLS = ['pan', 'reset', 'box_zoom', 'save'] xlab = 'Wavelength [{}]'.format(self.wave_units) ylab = 'Flux Density [{}]'.format(str(self.flux_units)) self.fig = figure(width=900, height=400, title=self.name, @@ -2597,14 +2603,12 @@ def plot(self, app=True, photometry=True, spectra=True, integral=True, synthetic # Plot spectra if spectra and len(self.spectra) > 0: - if spectra == 'all': - + if (spectra == 'all' and app is False): for n, spec in enumerate(self.spectra['spectrum']): - - # Normalize to the integral - norm_spec = spec.norm_to_spec(full_SED) - self.fig.line(norm_spec.wave, norm_spec.flux * const, color=next(COLORS), alpha=0.8, name='spectra', legend_label=spec.name) - + self.fig = spec.plot(fig=self.fig, components=True, const=(spec.flux_calibrate(self.distance).flux/spec.flux)) + elif (spectra == 'all' and app is True): + for n, spec in enumerate(self.spectra['spectrum']): + self.fig = spec.plot(fig=self.fig, components=True, const=const) else: self.fig.line(spec_SED.wave, spec_SED.flux * const, color=color, alpha=0.8, name='spectrum', legend_label='Spectrum') @@ -2613,29 +2617,36 @@ def plot(self, app=True, photometry=True, spectra=True, integral=True, synthetic # Set up hover tool phot_tips = [('Band', '@desc'), ('Wave', '@x'), ('Flux', '@y'), ('Unc', '@z')] - # hover = HoverTool(name=['photometry', 'nondetection'], tooltips=phot_tips) - hover = HoverTool(name='photometry', tooltips=phot_tips) - self.fig.add_tools(hover) # Plot points with errors pts = np.array([(bnd, wav, flx * const, err * const) for bnd, wav, flx, err in np.array(self.photometry['band', 'eff', pre + 'flux', pre + 'flux_unc']) if not any([(np.isnan(i) or i <= 0) for i in [wav, flx, err]])], dtype=[('desc', 'S20'), ('x', float), ('y', float), ('z', float)]) if len(pts) > 0: source = ColumnDataSource(data=dict(x=pts['x'], y=pts['y'], z=pts['z'], desc=[b.decode("utf-8") for b in pts['desc']])) - self.fig.circle('x', 'y', source=source, legend_label='Photometry', name='photometry', color=color, fill_alpha=0.7, size=8) + c1 = self.fig.circle('x', 'y', source=source, legend_label='Photometry', name='photometry', color='#a64d79', fill_alpha=0.7, size=8) + self.fig.circle('x', 'y', source=source, legend_label='Photometry', name='photometry', color='#a64d79', fill_alpha=0.7, size=8) self.fig = u.errorbars(self.fig, 'x', 'y', yerr='z', source=source, color=color) + hover = HoverTool(tooltips=phot_tips, renderers=[c1]) + self.fig.add_tools(hover) + + # Plot points without errors pts = np.array([(bnd, wav, flx * const, err * const) for bnd, wav, flx, err in np.array(self.photometry['band', 'eff', pre + 'flux', pre + 'flux_unc']) if (np.isnan(err) or err <= 0) and not np.isnan(flx)], dtype=[('desc', 'S20'), ('x', float), ('y', float), ('z', float)]) if len(pts) > 0: source = ColumnDataSource(data=dict(x=pts['x'], y=pts['y'], z=pts['z'], desc=[b.decode("utf-8") for b in pts['desc']])) + c2 = self.fig.circle('x', 'y', source=source, legend_label='Limit', name='nondetection', color=color,fill_alpha=0, size=8) self.fig.circle('x', 'y', source=source, legend_label='Limit', name='nondetection', color=color, fill_alpha=0, size=8) + hover1 = HoverTool(tooltips=phot_tips, renderers=[c2]) + self.fig.add_tools(hover1) + # Plot photometry if synthetic_photometry and len(self.synthetic_photometry) > 0: # Set up hover tool - phot_tips = [('Band', '@desc'), ('Wave', '@x'), ('Flux', '@y'), ('Unc', '@z')] - hover = HoverTool(name='synthetic photometry', tooltips=phot_tips) + phot_tips = [('Band', '@{desc}'), ('Wave', '@{x}'), ('Flux', '@{y}'), ('Unc', '@{z}')] + hover = HoverTool(name='synthetic photometry', tooltips=phot_tips, mode='vline') + self.fig.add_tools(hover) # Plot points with errors @@ -2657,9 +2668,9 @@ def plot(self, app=True, photometry=True, spectra=True, integral=True, synthetic mod = mod_fit['full_model'] mod.wave_units = self.wave_units if mod_fit['fit_to'] == 'phot': - self.fig.square(mod.wave, mod.flux, alpha=0.3, color=color if one_color else next(col_list), legend_label=mod_fit['label'], name='best_fit_phot', size=12) + self.fig.square(mod.wave, mod.flux, alpha=0.9, color=color if one_color else next(col_list), legend_label=mod_fit['label'], size=12) else: - self.fig.line(mod.wave, mod.flux, alpha=0.3, color=color if one_color else next(col_list), legend_label=mod_fit['label'], name='best_fit_spec', line_width=2) + self.fig.line(mod.wave, mod.flux, alpha=0.9, color=color if one_color else next(col_list), legend_label=mod_fit['label'], line_width=2) self.fig.legend.location = "top_right" self.fig.legend.click_policy = "hide" @@ -2737,25 +2748,42 @@ def results(self): # Get the params rows = [] for param in params: - # Get the values and format attr = getattr(self, param, None) - if attr is None: attr = '--' - + # error = '$ ^{+0.004}_{-0.004}$' if isinstance(attr, (tuple, list)): - val, unc = attr[:2] - unit = val.unit if hasattr(val, 'unit') else '--' - val = val.value if hasattr(val, 'unit') else val - unc = unc.value if hasattr(unc, 'unit') else unc - if val < 1E-3 or val > 1e5: - val = float('{:.2e}'.format(val)) - if unc is None: - unc = '--' + if ((param == 'logg') or (param =='mass') or (param == 'radius') or (param == 'Teff')): + val = attr[0] + lower_err = attr[1] + upper_err = attr[2] + unit = val.unit if hasattr(val, 'unit') else '--' + val = val.value if hasattr(val, 'unit') else val + unc_l = lower_err.value * -1 if hasattr(lower_err, 'unit') else lower_err * -1 + unc_u = upper_err.value if hasattr(upper_err, 'unit') else upper_err + if val < 1E-3 or val > 1e5: + if param == 'mass': + val = float('{:.0f}'.format(val)) + unc = str('{:+.0f}'.format(unc_l) + ' ' + '{:+.0f}'.format(unc_u)) + else: + val = float('{:.2e}'.format(val)) + unc = str('{:-.2e}'.format(unc_l) + ' ' + '{:+.2e}'.format(unc_u)) else: - unc = float('{:.2e}'.format(unc)) - rows.append([param, val, unc, unit]) + unc = str('{:-.2f}'.format(unc_l) + ' ' + '{:+.2f}'.format(unc_u)) + rows.append([param, val, unc, unit]) + else: + val, unc = attr[:2] + unit = val.unit if hasattr(val, 'unit') else '--' + val = val.value if hasattr(val, 'unit') else val + unc = unc.value if hasattr(unc, 'unit') else unc + if val < 1E-3 or val > 1e5: + val = float('{:.2e}'.format(val)) + if unc is None: + unc = '--' + else: + unc = float('{:.2e}'.format(unc)) + rows.append([param, val, unc, unit]) elif isinstance(attr, (str, float, bytes, int)): rows.append([param, attr, '--', '--']) @@ -3029,7 +3057,6 @@ def _validate_and_set_param(self, param, values, units, set_uncalculated=True, t self.message("Setting {} to 'None'".format(param)) else: - # If the last value is string, it's the reference if isinstance(values[-1], str): ref = values[-1] @@ -3038,7 +3065,7 @@ def _validate_and_set_param(self, param, values, units, set_uncalculated=True, t ref = None # Make sure it's a sequence - if not u.issequence(values, length=[2, 3]): + if not u.issequence(values, length=[2, 3, 4, 5]): raise TypeError("{} must be a sequence of (value, error) or (value, lower_error, upper_error).".format(param)) # Make sure it's in correct units @@ -3046,8 +3073,11 @@ def _validate_and_set_param(self, param, values, units, set_uncalculated=True, t raise TypeError("{} values must be {}".format(param, 'unitless' if units is None else "astropy.units.quantity.Quantity of the appropriate units , e.g. '{}'".format(units))) # Ensure valid range but don't throw error + vmin = vmin if vmin is not None else -np.inf * (1 if units is None else units) vmax = vmax if vmax is not None else np.inf * (1 if units is None else units) + + if (values[0] < vmin) or (values[0] > vmax): self.message("{}: {} value is not in valid range [{}, {}].".format(values, param, vmin, vmax)) @@ -3125,9 +3155,9 @@ def __init__(self, **kwargs): self.find_SDSS() self.find_2MASS() self.find_WISE() - self.radius = 2.818 * q.Rsun, 0.008 * q.Rsun, '2010ApJ...708...71Y' + self.radius = 2.818 * q.Rsun, 0.008 * q.Rsun, 0.008 * q.Rsun, '2010ApJ...708...71Y' self.age = 455 * q.Myr, 13 * q.Myr, '2010ApJ...708...71Y' - self.logg = 4.1, 0.1, '2006ApJ...645..664A' + self.logg = 4.1, 0.1, 0.1, '2006ApJ...645..664A' # Get the spectrum self.add_spectrum(sp.Vega(snr=100)) diff --git a/sedkit/spectrum.py b/sedkit/spectrum.py index 4bbc4ca..6b0421e 100644 --- a/sedkit/spectrum.py +++ b/sedkit/spectrum.py @@ -562,7 +562,7 @@ def flux_units(self, flux_units): self._flux_units = flux_units self._set_units() - def integrate(self, units=None, n_samples=10): + def integrate(self, units=None, n_samples=1000): """Calculate the area under the spectrum Parameters @@ -619,7 +619,7 @@ def integrate(self, units=None, n_samples=10): uvals.append(abs(err.value - val.value)) # Get 1-sigma of distribution - vunc = np.max(abs(np.asarray(uvals))) * units + vunc = np.median(abs(np.asarray(uvals))) * units return val, vunc @@ -1057,9 +1057,9 @@ def resamp(self, wave=None, resolution=None, name=None): wave = wave.value # Bin the spectrum - print(wave, self.wave) + binned = u.spectres(wave, self.wave, self.flux, self.unc) - print(binned[0]) + # Update the spectrum spectrum = [i * Q for i, Q in zip(binned, self.units)] diff --git a/sedkit/tests/test_isochrone.py b/sedkit/tests/test_isochrone.py index 2b9cd3e..aec4c10 100644 --- a/sedkit/tests/test_isochrone.py +++ b/sedkit/tests/test_isochrone.py @@ -1,36 +1,62 @@ """Series of unit tests for the isochrone.py module""" import unittest - +import pytest import astropy.units as q +import numpy as np from .. import isochrone as iso from .. import utilities as u +@pytest.mark.parametrize('xval,age,xparam,yparam,expected_result,expected_result_low,expected_result_up', [ + # Mass + ((-4, 0.1), (4 * q.Gyr, 0.1 * q.Gyr), 'Lbol', 'mass', 0.072, 0.072, 0.072), # With uncertainties + (-4, (4 * q.Gyr, 0.1 * q.Gyr), 'Lbol', 'mass', 0.072, 0.072, 0.072), # No xparam uncertainty + ((-4, 0.1), 4 * q.Gyr, 'Lbol', 'mass', 0.072, 0.072, 0.072), # No yparam uncertainty + (-4, 4 * q.Gyr, 'Lbol', 'mass', 0.020, 0, 0.058), # No xparam and yparam uncertainties + # Radius + ((-4, 0.1), (4 * q.Gyr, 0.1 * q.Gyr), 'Lbol', 'radius', 0.095, 0.095, 0.095), # With uncertainties + (-4, (4 * q.Gyr, 0.1 * q.Gyr), 'Lbol', 'radius', 0.095, 0.095, 0.095), # No xparam uncertainty + ((-4, 0.1), 4 * q.Gyr, 'Lbol', 'radius', 0.095, 0.095, 0.095), # No yparam uncertainty + (-4, 4 * q.Gyr, 'Lbol', 'radius', 0.045, 0.01, 0.080), # No xparam and yparam uncertainties + # Logg + ((-4, 0.1), (4 * q.Gyr, 0.1 * q.Gyr), 'Lbol', 'logg', 5.345, 5.34, 5.35), # With uncertainties + (-4, (4 * q.Gyr, 0.1 * q.Gyr), 'Lbol', 'logg', 5.345, 5.34, 5.35), # No xparam uncertainty + ((-4, 0.1), 4 * q.Gyr, 'Lbol', 'logg', 5.345, 5.34, 5.35), # No yparam uncertainty + (-4, 4 * q.Gyr, 'Lbol', 'logg', 5.395, 5.36, 5.43), # No xparam and yparam uncertaintiesd + # Young age + ((-4, 0.1), (0.4 * q.Gyr, 0.1 * q.Gyr), 'Lbol', 'mass', 0.0515, 0.045, 0.055), # mass with uncertainties + ((-4, 0.1), (0.4 * q.Gyr, 0.1 * q.Gyr), 'Lbol', 'logg', 5.08, 5.01, 5.15), # logg with uncertainties + ((-4, 0.1), (0.4 * q.Gyr, 0.1 * q.Gyr), 'Lbol', 'radius', 0.105, 0.10, 0.11), # radius with uncertainties + # Old age + ((-4, 0.1), (9 * q.Gyr, 0.1 * q.Gyr), 'Lbol', 'mass', 0.074, 0.070, 0.080), # mass with uncertainties + ((-4, 0.1), (9 * q.Gyr, 0.1 * q.Gyr), 'Lbol', 'logg', 5.345, 5.34, 5.35), # logg with uncertainties + ((-4, 0.1), (9 * q.Gyr, 0.1 * q.Gyr), 'Lbol', 'radius', 0.095, 0.09, 0.10) # radius with uncertainties +]) +def test_evaluate(xval, age, xparam, yparam, expected_result, expected_result_low, expected_result_up): + """Test the evaluate method""" + hsa = iso.Isochrone('hybrid_solar_age') + result = hsa.evaluate(xval, age, xparam, yparam) + average = result[0] # Average yparam value + lower = result[0] - result[1] # Lower yparam value + upper = result[0] + result[2] # Upper yparam value + assert (isinstance(result, tuple)) is True + if yparam == 'logg': + assert (np.isclose(average, expected_result, atol=0.005)) + assert (np.isclose(lower, expected_result_low, atol=0.01)) + assert (np.isclose(upper, expected_result_up, atol=0.01)) + else: + assert (np.isclose(average.value, expected_result, atol=0.005)) + assert (np.isclose(lower.value, expected_result_low, atol=0.01)) + assert (np.isclose(upper.value, expected_result_up, atol=0.01)) + + class TestIsochrone(unittest.TestCase): """Tests for the hybrid_solar_age model isochrones""" def setUp(self): # Make Spectrum class for testing self.hsa = iso.Isochrone('hybrid_solar_age') - def test_evaluate(self): - """Test the evaluate method""" - # With uncertainties - result = self.hsa.evaluate((-4, 0.1), (4*q.Gyr, 0.1*q.Gyr), 'Lbol', 'mass') - self.assertTrue(isinstance(result, tuple)) - - # No xparam uncertainty - result = self.hsa.evaluate(-4, (4*q.Gyr, 0.1*q.Gyr), 'Lbol', 'mass') - self.assertTrue(isinstance(result, tuple)) - - # No yparam uncertainty - result = self.hsa.evaluate((-4, 0.1), 4*q.Gyr, 'Lbol', 'mass') - self.assertTrue(isinstance(result, tuple)) - - # No xparam or yparam uncertainties - result = self.hsa.evaluate(-4, 4*q.Gyr, 'Lbol', 'mass') - self.assertTrue(isinstance(result, tuple) and result[1] == 0) - def test_interp(self): """Test that the model isochrone can be interpolated""" # Successful interpolation diff --git a/sedkit/tests/test_sed.py b/sedkit/tests/test_sed.py index de51af4..c08fdde 100644 --- a/sedkit/tests/test_sed.py +++ b/sedkit/tests/test_sed.py @@ -148,7 +148,7 @@ def test_no_spectra(self): """Test that a purely photometric SED can be creted""" s = copy.copy(self.sed) s.age = 455*q.Myr, 13*q.Myr - s.radius = 2.362*q.Rsun, 0.02*q.Rjup + s.radius = 2.362*q.Rsun, 0.02*q.Rjup, 0.02*q.Rjup s.parallax = 130.23*q.mas, 0.36*q.mas s.spectral_type = 'A0V' s.add_photometry('2MASS.J', -0.177, 0.206) @@ -185,10 +185,10 @@ def test_plot(self): fig = v.plot(integral=True, synthetic_photometry=True, best_fit=True) def test_no_photometry(self): - """Test that a purely photometric SED can be creted""" + """Test that a purely photometric SED can be created""" s = copy.copy(self.sed) s.age = 455*q.Myr, 13*q.Myr - s.radius = 2.362*q.Rsun, 0.02*q.Rjup + s.radius = 2.362*q.Rsun, 0.02*q.Rjup,0.02*q.Rjup s.parallax = 130.23*q.mas, 0.36*q.mas s.spectral_type = 'A0V' s.add_spectrum(self.spec1) diff --git a/sedkit/tests/test_utilities.py b/sedkit/tests/test_utilities.py index 2f438e6..3a54072 100644 --- a/sedkit/tests/test_utilities.py +++ b/sedkit/tests/test_utilities.py @@ -1,6 +1,6 @@ """A suite of tests for the utilities.py module""" import copy -from pkg_resources import resource_filename +import importlib_resources import pytest import unittest @@ -294,7 +294,7 @@ def test_group_spectra(): def test_spectrum_from_fits(): """Test spectrum_from_fits function""" # Get the file - f = resource_filename('sedkit', '/data/Gl752B_NIR.fits') + f = str(importlib_resources.files('sedkit')/'/data/Gl752B_NIR.fits') # Get the spectrum spec = u.spectrum_from_fits(f) diff --git a/sedkit/utilities.py b/sedkit/utilities.py index fea3b55..a3e181a 100644 --- a/sedkit/utilities.py +++ b/sedkit/utilities.py @@ -467,7 +467,6 @@ def flux2mag(flx, bandpass, photon=True): unit = flx.unit # Set uncertainty - unc = unc if unc is not None else np.nan * unit # Convert energy units to photon counts if photon: