Skip to content

Commit

Permalink
add ensemble default cols
Browse files Browse the repository at this point in the history
  • Loading branch information
dougbrn committed Sep 15, 2023
1 parent 3b269d3 commit 16e509e
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 25 deletions.
29 changes: 19 additions & 10 deletions src/tape/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1240,22 +1240,15 @@ def from_source_dict(self, source_dict, column_mapper=None, npartitions=1, **kwa
**kwargs,
)

def convert_flux_to_mag(self, flux_col, zero_point, err_col=None, zp_form="mag", out_col_name=None):
def convert_flux_to_mag(self, zero_point, zp_form="mag", out_col_name=None, flux_col=None, err_col=None):
"""Converts a flux column into a magnitude column.
Parameters
----------
flux_col: 'str'
The name of the ensemble flux column to convert into magnitudes.
zero_point: 'str' or 'float'
The name of the ensemble column containing the zero point
information for column transformation. Alternatively, a float zero
point value to apply to all fluxes.
err_col: 'str', optional
The name of the ensemble column containing the errors to propagate.
Errors are propagated using the following approximation:
Err= (2.5/log(10))*(flux_error/flux), which holds mainly when the
error in flux is much smaller than the flux.
information for column transformation. Alternatively, a single
float number to apply for all fluxes.
zp_form: `str`, optional
The form of the zero point column, either "flux" or
"magnitude"/"mag". Determines how the zero point (zp) is applied in
Expand All @@ -1266,13 +1259,29 @@ def convert_flux_to_mag(self, flux_col, zero_point, err_col=None, zp_form="mag",
The name of the output magnitude column, if None then the output
is just the flux column name + "_mag". The error column is also
generated as the out_col_name + "_err".
flux_col: 'str', optional
The name of the ensemble flux column to convert into magnitudes.
Uses the Ensemble mapped flux column if not specified.
err_col: 'str', optional
The name of the ensemble column containing the errors to propagate.
Errors are propagated using the following approximation:
Err= (2.5/log(10))*(flux_error/flux), which holds mainly when the
error in flux is much smaller than the flux. Uses the Ensemble
mapped error column if not specified.
Returns
----------
ensemble: `tape.ensemble.Ensemble`
The ensemble object with a new magnitude (and error) column.
"""

# Assign Ensemble cols if not provided
if flux_col is None:
flux_col = self._flux_col
if err_col is None:
err_col = self._err_col

if out_col_name is None:
out_col_name = flux_col + "_mag"

Expand Down
23 changes: 8 additions & 15 deletions tests/tape_tests/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,9 +746,8 @@ def test_coalesce(dask_client, drop_inputs):

@pytest.mark.parametrize("zero_point", [("zp_mag", "zp_flux"), (25.0, 10**10)])
@pytest.mark.parametrize("zp_form", ["flux", "mag", "magnitude", "lincc"])
@pytest.mark.parametrize("err_col", [None, "error"])
@pytest.mark.parametrize("out_col_name", [None, "mag"])
def test_convert_flux_to_mag(dask_client, zero_point, zp_form, err_col, out_col_name):
def test_convert_flux_to_mag(dask_client, zero_point, zp_form, out_col_name):
ens = Ensemble(client=dask_client)

source_dict = {
Expand All @@ -771,32 +770,26 @@ def test_convert_flux_to_mag(dask_client, zero_point, zp_form, err_col, out_col_
ens.from_source_dict(source_dict, column_mapper=col_map)

if zp_form == "flux":
ens.convert_flux_to_mag("flux", zero_point[1], err_col, zp_form, out_col_name)
ens.convert_flux_to_mag(zero_point[1], zp_form, out_col_name)

res_mag = ens._source.compute()[output_column].to_list()[0]
assert pytest.approx(res_mag, 0.001) == 21.28925

if err_col is not None:
res_err = ens._source.compute()[output_column + "_err"].to_list()[0]
assert pytest.approx(res_err, 0.001) == 0.355979
else:
assert output_column + "_err" not in ens._source.columns
res_err = ens._source.compute()[output_column + "_err"].to_list()[0]
assert pytest.approx(res_err, 0.001) == 0.355979

elif zp_form == "mag" or zp_form == "magnitude":
ens.convert_flux_to_mag("flux", zero_point[0], err_col, zp_form, out_col_name)
ens.convert_flux_to_mag(zero_point[0], zp_form, out_col_name)

res_mag = ens._source.compute()[output_column].to_list()[0]
assert pytest.approx(res_mag, 0.001) == 21.28925

if err_col is not None:
res_err = ens._source.compute()[output_column + "_err"].to_list()[0]
assert pytest.approx(res_err, 0.001) == 0.355979
else:
assert output_column + "_err" not in ens._source.columns
res_err = ens._source.compute()[output_column + "_err"].to_list()[0]
assert pytest.approx(res_err, 0.001) == 0.355979

else:
with pytest.raises(ValueError):
ens.convert_flux_to_mag("flux", zero_point[0], err_col, zp_form, "mag")
ens.convert_flux_to_mag(zero_point[0], zp_form, "mag")


def test_find_day_gap_offset(dask_client):
Expand Down

0 comments on commit 16e509e

Please sign in to comment.