Skip to content

Commit

Permalink
add zero_point as float input
Browse files Browse the repository at this point in the history
:q
q
  • Loading branch information
dougbrn committed Sep 15, 2023
1 parent 897a5cf commit 3b269d3
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 14 deletions.
28 changes: 19 additions & 9 deletions src/tape/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1247,9 +1247,10 @@ def convert_flux_to_mag(self, flux_col, zero_point, err_col=None, zp_form="mag",
----------
flux_col: 'str'
The name of the ensemble flux column to convert into magnitudes.
zero_point: 'str'
zero_point: 'str' or 'float'
The name of the ensemble column containing the zero point
information for column transformation.
information for column transformation. Alternatively, a float zero
point value to apply to all fluxes.
err_col: 'str', optional
The name of the ensemble column containing the errors to propagate.
Errors are propagated using the following approximation:
Expand All @@ -1276,15 +1277,24 @@ def convert_flux_to_mag(self, flux_col, zero_point, err_col=None, zp_form="mag",
out_col_name = flux_col + "_mag"

if zp_form == "flux": # mag = -2.5*np.log10(flux/zp)
self._source = self._source.assign(
**{out_col_name: lambda x: -2.5 * np.log10(x[flux_col] / x[zero_point])}
)
if isinstance(zero_point, str):
self._source = self._source.assign(
**{out_col_name: lambda x: -2.5 * np.log10(x[flux_col] / x[zero_point])}
)
else:
self._source = self._source.assign(
**{out_col_name: lambda x: -2.5 * np.log10(x[flux_col] / zero_point)}
)

elif zp_form == "magnitude" or zp_form == "mag": # mag = -2.5*np.log10(flux) + zp
self._source = self._source.assign(
**{out_col_name: lambda x: -2.5 * np.log10(x[flux_col]) + x[zero_point]}
)

if isinstance(zero_point, str):
self._source = self._source.assign(
**{out_col_name: lambda x: -2.5 * np.log10(x[flux_col]) + x[zero_point]}
)
else:
self._source = self._source.assign(
**{out_col_name: lambda x: -2.5 * np.log10(x[flux_col]) + zero_point}
)
else:
raise ValueError(f"{zp_form} is not a valid zero_point format.")

Expand Down
11 changes: 6 additions & 5 deletions tests/tape_tests/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def test_from_parquet(data_fixture, request):
# Check to make sure the critical quantity labels are bound to real columns
assert parquet_ensemble._source[col] is not None


@pytest.mark.parametrize(
"data_fixture",
[
Expand Down Expand Up @@ -103,7 +104,6 @@ def test_from_dataframe(data_fixture, request):
assert ens._source[col] is not None



def test_available_datasets(dask_client):
"""
Test that the ensemble is able to successfully read in the list of available TAPE datasets
Expand Down Expand Up @@ -744,10 +744,11 @@ def test_coalesce(dask_client, drop_inputs):
assert col in ens._source.columns


@pytest.mark.parametrize("zero_point", [("zp_mag", "zp_flux"), (25.0, 10**10)])
@pytest.mark.parametrize("zp_form", ["flux", "mag", "magnitude", "lincc"])
@pytest.mark.parametrize("err_col", [None, "error"])
@pytest.mark.parametrize("out_col_name", [None, "mag"])
def test_convert_flux_to_mag(dask_client, zp_form, err_col, out_col_name):
def test_convert_flux_to_mag(dask_client, zero_point, zp_form, err_col, out_col_name):
ens = Ensemble(client=dask_client)

source_dict = {
Expand All @@ -770,7 +771,7 @@ def test_convert_flux_to_mag(dask_client, zp_form, err_col, out_col_name):
ens.from_source_dict(source_dict, column_mapper=col_map)

if zp_form == "flux":
ens.convert_flux_to_mag("flux", "zp_flux", err_col, zp_form, out_col_name)
ens.convert_flux_to_mag("flux", zero_point[1], err_col, zp_form, out_col_name)

res_mag = ens._source.compute()[output_column].to_list()[0]
assert pytest.approx(res_mag, 0.001) == 21.28925
Expand All @@ -782,7 +783,7 @@ def test_convert_flux_to_mag(dask_client, zp_form, err_col, out_col_name):
assert output_column + "_err" not in ens._source.columns

elif zp_form == "mag" or zp_form == "magnitude":
ens.convert_flux_to_mag("flux", "zp_mag", err_col, zp_form, out_col_name)
ens.convert_flux_to_mag("flux", zero_point[0], err_col, zp_form, out_col_name)

res_mag = ens._source.compute()[output_column].to_list()[0]
assert pytest.approx(res_mag, 0.001) == 21.28925
Expand All @@ -795,7 +796,7 @@ def test_convert_flux_to_mag(dask_client, zp_form, err_col, out_col_name):

else:
with pytest.raises(ValueError):
ens.convert_flux_to_mag("flux", "zp_mag", err_col, zp_form, "mag")
ens.convert_flux_to_mag("flux", zero_point[0], err_col, zp_form, "mag")


def test_find_day_gap_offset(dask_client):
Expand Down

0 comments on commit 3b269d3

Please sign in to comment.