Skip to content

Commit

Permalink
Fixed relation evaluate function for case with error bars
Browse files Browse the repository at this point in the history
  • Loading branch information
hover2pi committed Sep 20, 2024
1 parent ec55146 commit 539ef68
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 10 deletions.
21 changes: 13 additions & 8 deletions sedkit/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,14 @@ def evaluate(self, rel_name, x_val, xunits=None, yunits=None, fit_local=False, p

if plot:
plt = self.plot(rel_name, xunits=xunits, yunits=yunits)
plt.circle([x_val.value if hasattr(x_val, 'unit') else x_val], [y_val.value if hasattr(y_val, 'unit') else y_val], color='red', size=10, legend_label='{}({})'.format(rel['yparam'], x_val))
if y_upper:
plt.line([x_val, x_val], [y_val - y_lower, y_val + y_upper], color='red')
xv = x_val.value if hasattr(x_val, 'unit') else x_val
yv = y_val.value if hasattr(y_val, 'unit') else y_val
plt.scatter([xv], [yv], color='red', size=10, legend_label='{}({})'.format(rel['yparam'], x_val))
print(y_val, y_upper, y_lower)
if y_upper is not None:
yvl = y_lower.value if hasattr(y_lower, 'unit') else y_lower
yvu = y_upper.value if hasattr(y_upper, 'unit') else y_upper
plt.line([xv, xv], [yv - yvl, yv + yvu], color='red')
show(plt)

# Restore full relation
Expand Down Expand Up @@ -362,7 +367,7 @@ def plot(self, rel_name, xunits=None, yunits=None, **kwargs):
fig.yaxis.axis_label = '{}{}'.format(yparam, '[{}]'.format(yunits or rel['yunit']))

# Draw points
fig.circle(x * xu, y * yu, legend_label='Data', **kwargs)
fig.scatter(x * xu, y * yu, legend_label='Data', **kwargs)

return fig

Expand Down Expand Up @@ -502,10 +507,10 @@ def generate(self, orders):
# ====================================================================

# Get the data
cat1 = V.query_constraints('J/ApJ/810/158/table1')[0]
cat2 = V.query_constraints('J/ApJ/810/158/table9')[0]
cat1 = V.query_constraints(catalog='J/ApJ/810/158/table1')[0]
cat2 = V.query_constraints(catalog='J/ApJ/810/158/table9')[0]

# Join the tables to getthe spectral types and radii in one table
# Join the tables to get the spectral types and radii in one table
mlty_data = at.join(cat1, cat2, keys='ID', join_type='outer')

# Only keep field age
Expand Down Expand Up @@ -619,7 +624,7 @@ def plot(self, draw=False):

# Add the data
if n == 0:
fig.circle(data['data']['spt'], data['data']['radius'], size=8,
fig.scatter(data['data']['spt'], data['data']['radius'], size=8,
color=color, legend_label=data['ref'])
else:
fig.square(data['data']['spt'], data['data']['radius'], size=8,
Expand Down
4 changes: 2 additions & 2 deletions sedkit/tests/test_relations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import copy
from pkg_resources import resource_filename
import importlib_resources
import unittest

import astropy.units as q
Expand Down Expand Up @@ -50,7 +50,7 @@ class TestRelation(unittest.TestCase):
"""Tests for the Relation base class"""
def setUp(self):
# Set the file
self.file = resource_filename('sedkit', 'data/dwarf_sequence.txt')
self.file = str(importlib_resources.files('sedkit')/ 'data/dwarf_sequence.txt')

def test_init(self):
"""Test class initialization"""
Expand Down

0 comments on commit 539ef68

Please sign in to comment.