diff --git a/src/tape/ensemble.py b/src/tape/ensemble.py index 47fc01b7..5f17a0ab 100644 --- a/src/tape/ensemble.py +++ b/src/tape/ensemble.py @@ -1247,9 +1247,10 @@ def convert_flux_to_mag(self, flux_col, zero_point, err_col=None, zp_form="mag", ---------- flux_col: 'str' The name of the ensemble flux column to convert into magnitudes. - zero_point: 'str' + zero_point: 'str' or 'float' The name of the ensemble column containing the zero point - information for column transformation. + information for column transformation. Alternatively, a float zero + point value to apply to all fluxes. err_col: 'str', optional The name of the ensemble column containing the errors to propagate. Errors are propagated using the following approximation: @@ -1276,15 +1277,24 @@ def convert_flux_to_mag(self, flux_col, zero_point, err_col=None, zp_form="mag", out_col_name = flux_col + "_mag" if zp_form == "flux": # mag = -2.5*np.log10(flux/zp) - self._source = self._source.assign( - **{out_col_name: lambda x: -2.5 * np.log10(x[flux_col] / x[zero_point])} - ) + if isinstance(zero_point, str): + self._source = self._source.assign( + **{out_col_name: lambda x: -2.5 * np.log10(x[flux_col] / x[zero_point])} + ) + else: + self._source = self._source.assign( + **{out_col_name: lambda x: -2.5 * np.log10(x[flux_col] / zero_point)} + ) elif zp_form == "magnitude" or zp_form == "mag": # mag = -2.5*np.log10(flux) + zp - self._source = self._source.assign( - **{out_col_name: lambda x: -2.5 * np.log10(x[flux_col]) + x[zero_point]} - ) - + if isinstance(zero_point, str): + self._source = self._source.assign( + **{out_col_name: lambda x: -2.5 * np.log10(x[flux_col]) + x[zero_point]} + ) + else: + self._source = self._source.assign( + **{out_col_name: lambda x: -2.5 * np.log10(x[flux_col]) + zero_point} + ) else: raise ValueError(f"{zp_form} is not a valid zero_point format.") diff --git a/tests/tape_tests/test_ensemble.py b/tests/tape_tests/test_ensemble.py index 65d8b2a4..0ee55f7a 100644 --- a/tests/tape_tests/test_ensemble.py +++ b/tests/tape_tests/test_ensemble.py @@ -67,6 +67,7 @@ def test_from_parquet(data_fixture, request): # Check to make sure the critical quantity labels are bound to real columns assert parquet_ensemble._source[col] is not None + @pytest.mark.parametrize( "data_fixture", [ @@ -103,7 +104,6 @@ def test_from_dataframe(data_fixture, request): assert ens._source[col] is not None - def test_available_datasets(dask_client): """ Test that the ensemble is able to successfully read in the list of available TAPE datasets @@ -744,10 +744,11 @@ def test_coalesce(dask_client, drop_inputs): assert col in ens._source.columns +@pytest.mark.parametrize("zero_point", [("zp_mag", "zp_flux"), (25.0, 10**10)]) @pytest.mark.parametrize("zp_form", ["flux", "mag", "magnitude", "lincc"]) @pytest.mark.parametrize("err_col", [None, "error"]) @pytest.mark.parametrize("out_col_name", [None, "mag"]) -def test_convert_flux_to_mag(dask_client, zp_form, err_col, out_col_name): +def test_convert_flux_to_mag(dask_client, zero_point, zp_form, err_col, out_col_name): ens = Ensemble(client=dask_client) source_dict = { @@ -770,7 +771,7 @@ def test_convert_flux_to_mag(dask_client, zp_form, err_col, out_col_name): ens.from_source_dict(source_dict, column_mapper=col_map) if zp_form == "flux": - ens.convert_flux_to_mag("flux", "zp_flux", err_col, zp_form, out_col_name) + ens.convert_flux_to_mag("flux", zero_point[1], err_col, zp_form, out_col_name) res_mag = ens._source.compute()[output_column].to_list()[0] assert pytest.approx(res_mag, 0.001) == 21.28925 @@ -782,7 +783,7 @@ def test_convert_flux_to_mag(dask_client, zp_form, err_col, out_col_name): assert output_column + "_err" not in ens._source.columns elif zp_form == "mag" or zp_form == "magnitude": - ens.convert_flux_to_mag("flux", "zp_mag", err_col, zp_form, out_col_name) + ens.convert_flux_to_mag("flux", zero_point[0], err_col, zp_form, out_col_name) res_mag = ens._source.compute()[output_column].to_list()[0] assert pytest.approx(res_mag, 0.001) == 21.28925 @@ -795,7 +796,7 @@ def test_convert_flux_to_mag(dask_client, zp_form, err_col, out_col_name): else: with pytest.raises(ValueError): - ens.convert_flux_to_mag("flux", "zp_mag", err_col, zp_form, "mag") + ens.convert_flux_to_mag("flux", zero_point[0], err_col, zp_form, "mag") def test_find_day_gap_offset(dask_client):