diff --git a/src/tape/ensemble.py b/src/tape/ensemble.py index 9bf4c6c9..0ba059b8 100644 --- a/src/tape/ensemble.py +++ b/src/tape/ensemble.py @@ -1240,21 +1240,15 @@ def from_source_dict(self, source_dict, column_mapper=None, npartitions=1, **kwa **kwargs, ) - def convert_flux_to_mag(self, flux_col, zero_point, err_col=None, zp_form="mag", out_col_name=None): + def convert_flux_to_mag(self, zero_point, zp_form="mag", out_col_name=None, flux_col=None, err_col=None): """Converts a flux column into a magnitude column. Parameters ---------- - flux_col: 'str' - The name of the ensemble flux column to convert into magnitudes. - zero_point: 'str' + zero_point: 'str' or 'float' The name of the ensemble column containing the zero point - information for column transformation. - err_col: 'str', optional - The name of the ensemble column containing the errors to propagate. - Errors are propagated using the following approximation: - Err= (2.5/log(10))*(flux_error/flux), which holds mainly when the - error in flux is much smaller than the flux. + information for column transformation. Alternatively, a single + float number to apply for all fluxes. zp_form: `str`, optional The form of the zero point column, either "flux" or "magnitude"/"mag". Determines how the zero point (zp) is applied in @@ -1265,6 +1259,15 @@ def convert_flux_to_mag(self, flux_col, zero_point, err_col=None, zp_form="mag", The name of the output magnitude column, if None then the output is just the flux column name + "_mag". The error column is also generated as the out_col_name + "_err". + flux_col: 'str', optional + The name of the ensemble flux column to convert into magnitudes. + Uses the Ensemble mapped flux column if not specified. + err_col: 'str', optional + The name of the ensemble column containing the errors to propagate. + Errors are propagated using the following approximation: + Err= (2.5/log(10))*(flux_error/flux), which holds mainly when the + error in flux is much smaller than the flux. Uses the Ensemble + mapped error column if not specified. Returns ---------- @@ -1272,19 +1275,35 @@ def convert_flux_to_mag(self, flux_col, zero_point, err_col=None, zp_form="mag", The ensemble object with a new magnitude (and error) column. """ + + # Assign Ensemble cols if not provided + if flux_col is None: + flux_col = self._flux_col + if err_col is None: + err_col = self._err_col + if out_col_name is None: out_col_name = flux_col + "_mag" if zp_form == "flux": # mag = -2.5*np.log10(flux/zp) - self._source = self._source.assign( - **{out_col_name: lambda x: -2.5 * np.log10(x[flux_col] / x[zero_point])} - ) + if isinstance(zero_point, str): + self._source = self._source.assign( + **{out_col_name: lambda x: -2.5 * np.log10(x[flux_col] / x[zero_point])} + ) + else: + self._source = self._source.assign( + **{out_col_name: lambda x: -2.5 * np.log10(x[flux_col] / zero_point)} + ) elif zp_form == "magnitude" or zp_form == "mag": # mag = -2.5*np.log10(flux) + zp - self._source = self._source.assign( - **{out_col_name: lambda x: -2.5 * np.log10(x[flux_col]) + x[zero_point]} - ) - + if isinstance(zero_point, str): + self._source = self._source.assign( + **{out_col_name: lambda x: -2.5 * np.log10(x[flux_col]) + x[zero_point]} + ) + else: + self._source = self._source.assign( + **{out_col_name: lambda x: -2.5 * np.log10(x[flux_col]) + zero_point} + ) else: raise ValueError(f"{zp_form} is not a valid zero_point format.") diff --git a/tests/tape_tests/test_ensemble.py b/tests/tape_tests/test_ensemble.py index ba330f19..5ea9d225 100644 --- a/tests/tape_tests/test_ensemble.py +++ b/tests/tape_tests/test_ensemble.py @@ -67,6 +67,7 @@ def test_from_parquet(data_fixture, request): # Check to make sure the critical quantity labels are bound to real columns assert parquet_ensemble._source[col] is not None + @pytest.mark.parametrize( "data_fixture", [ @@ -108,7 +109,6 @@ def test_from_dataframe(data_fixture, request): amplitude = ens.batch(calc_stetson_J) assert len(amplitude) == 5 - def test_available_datasets(dask_client): """ Test that the ensemble is able to successfully read in the list of available TAPE datasets @@ -749,10 +749,10 @@ def test_coalesce(dask_client, drop_inputs): assert col in ens._source.columns +@pytest.mark.parametrize("zero_point", [("zp_mag", "zp_flux"), (25.0, 10**10)]) @pytest.mark.parametrize("zp_form", ["flux", "mag", "magnitude", "lincc"]) -@pytest.mark.parametrize("err_col", [None, "error"]) @pytest.mark.parametrize("out_col_name", [None, "mag"]) -def test_convert_flux_to_mag(dask_client, zp_form, err_col, out_col_name): +def test_convert_flux_to_mag(dask_client, zero_point, zp_form, out_col_name): ens = Ensemble(client=dask_client) source_dict = { @@ -775,32 +775,26 @@ def test_convert_flux_to_mag(dask_client, zp_form, err_col, out_col_name): ens.from_source_dict(source_dict, column_mapper=col_map) if zp_form == "flux": - ens.convert_flux_to_mag("flux", "zp_flux", err_col, zp_form, out_col_name) + ens.convert_flux_to_mag(zero_point[1], zp_form, out_col_name) res_mag = ens._source.compute()[output_column].to_list()[0] assert pytest.approx(res_mag, 0.001) == 21.28925 - if err_col is not None: - res_err = ens._source.compute()[output_column + "_err"].to_list()[0] - assert pytest.approx(res_err, 0.001) == 0.355979 - else: - assert output_column + "_err" not in ens._source.columns + res_err = ens._source.compute()[output_column + "_err"].to_list()[0] + assert pytest.approx(res_err, 0.001) == 0.355979 elif zp_form == "mag" or zp_form == "magnitude": - ens.convert_flux_to_mag("flux", "zp_mag", err_col, zp_form, out_col_name) + ens.convert_flux_to_mag(zero_point[0], zp_form, out_col_name) res_mag = ens._source.compute()[output_column].to_list()[0] assert pytest.approx(res_mag, 0.001) == 21.28925 - if err_col is not None: - res_err = ens._source.compute()[output_column + "_err"].to_list()[0] - assert pytest.approx(res_err, 0.001) == 0.355979 - else: - assert output_column + "_err" not in ens._source.columns + res_err = ens._source.compute()[output_column + "_err"].to_list()[0] + assert pytest.approx(res_err, 0.001) == 0.355979 else: with pytest.raises(ValueError): - ens.convert_flux_to_mag("flux", "zp_mag", err_col, zp_form, "mag") + ens.convert_flux_to_mag(zero_point[0], zp_form, "mag") def test_find_day_gap_offset(dask_client):