feat(magstats-step): calculate deltajd #307

Merged: 2 commits, Nov 30, 2023
1 change: 1 addition & 0 deletions libs/db-plugins/db_plugins/db/mongo/models.py
@@ -45,6 +45,7 @@ class Object(BaseModel):
stellar = Field()
firstmjd = Field()
lastmjd = Field()
deltajd = Field()
ndet = Field()
meanra = Field()
sigmara = Field()
2 changes: 2 additions & 0 deletions libs/db-plugins/tests/unittest/db/test_helpers.py
@@ -41,6 +41,7 @@ def create_2_objects(self):
stellar=False,
lastmjd="lastmjd",
firstmjd="firstmjd",
deltajd=1.0,
meanra=100.0,
sigmara=0.1,
meandec=50.0,
@@ -88,6 +89,7 @@ def create_2_objects(self):
sigmadec=0.1,
lastmjd="lastmjd",
firstmjd="firstmjd",
deltajd=1.0,
meanra=100.0,
meandec=50.0,
ndet=5,
15 changes: 10 additions & 5 deletions libs/db-plugins/tests/unittest/db/test_mongo_models.py
@@ -11,10 +11,11 @@ def test_object_creates(self):
sid="sid",
lastmjd="lastmjd",
firstmjd="firstmjd",
deltajd=1,
corrected=True,
stellar=True,
sigmara=.1,
sigmadec=.2,
sigmara=0.1,
sigmadec=0.2,
meanra=100.0,
meandec=50.0,
ndet="ndet",
@@ -67,7 +68,9 @@ def test_detection_creates(self):
self.assertEqual(d["aid"], "aid")

def test_detection_fails_creation(self):
with self.assertRaisesRegex(AttributeError, "Detection model needs .+? attribute"):
with self.assertRaisesRegex(
AttributeError, "Detection model needs .+? attribute"
):
models.Detection()

def test_detection_with_extra_fields(self):
@@ -152,7 +155,7 @@ def test_forced_photometry_creates(self):
parent_candid="parent_candid",
has_stamp="has_stamp",
rbversion="rbversion",
extra_fields={},
extra_fields={},
)
self.assertIsInstance(fp, models.ForcedPhotometry)
self.assertIsInstance(fp, dict)
@@ -174,7 +177,9 @@ def test_non_detection_creates(self):
self.assertEqual(o["aid"], "aid")

def test_non_detection_fails_creation(self):
with self.assertRaisesRegex(AttributeError, "NonDetection model needs .+ attribute"):
with self.assertRaisesRegex(
AttributeError, "NonDetection model needs .+ attribute"
):
models.NonDetection()
# self.assertEqual(str(e.exception), "NonDetection model needs aid attribute")

15 changes: 9 additions & 6 deletions libs/db-plugins/tests/unittest/db/test_probabilities.py
@@ -23,10 +23,11 @@ def create_2_objects(self):
sid="sid",
corrected=True,
stellar=True,
sigmara=.1,
sigmadec=.1,
sigmara=0.1,
sigmadec=0.1,
lastmjd="lastmjd",
firstmjd="firstmjd",
deltajd=1,
meanra=100.0,
meandec=50.0,
ndet=2,
@@ -54,10 +55,11 @@ def create_2_objects(self):
tid="tid2",
corrected=True,
stellar=True,
sigmara=.1,
sigmadec=.1,
sigmara=0.1,
sigmadec=0.1,
lastmjd="lastmjd",
firstmjd="firstmjd",
deltajd=1,
meanra=100.0,
meandec=50.0,
ndet=5,
@@ -88,10 +90,11 @@ def create_simple_object(self):
tid="tid3",
corrected=True,
stellar=True,
sigmara=.1,
sigmadec=.1,
sigmara=0.1,
sigmadec=0.1,
lastmjd="lastmjd",
firstmjd="firstmjd",
deltajd=1,
meanra=100.0,
meandec=50.0,
ndet=5,
57 changes: 45 additions & 12 deletions magstats_step/magstats_step/core/_base.py
@@ -15,15 +15,21 @@ class BaseStatistics(abc.ABC):

def __init__(self, detections: List[dict]):
try:
self._detections = pd.DataFrame.from_records(detections, exclude=["extra_fields"])
self._detections = pd.DataFrame.from_records(
detections, exclude=["extra_fields"]
)
except KeyError: # extra_fields is not present
self._detections = pd.DataFrame.from_records(detections)
self._detections = self._detections.drop_duplicates("candid").set_index("candid")
self._detections = self._detections.drop_duplicates(
"candid"
).set_index("candid")
# Select only non-forced detections
self._detections = self._detections[~self._detections["forced"]]

@classmethod
def _group(cls, df: Union[pd.DataFrame, pd.Series]) -> Union[DataFrameGroupBy, SeriesGroupBy]:
def _group(
cls, df: Union[pd.DataFrame, pd.Series]
) -> Union[DataFrameGroupBy, SeriesGroupBy]:
return df.groupby(cls._JOIN)

@lru_cache(10)
@@ -40,8 +46,14 @@ def _surveys_mask(self, surveys: Tuple[str] = None) -> pd.Series:
return pd.Series(True, index=self._detections.index)

@lru_cache(6)
def _select_detections(self, *, surveys: Tuple[str] = None, corrected: bool = False) -> pd.Series:
mask = self._detections["corrected"] if corrected else pd.Series(True, index=self._detections.index)
def _select_detections(
self, *, surveys: Tuple[str] = None, corrected: bool = False
) -> pd.Series:
mask = (
self._detections["corrected"]
if corrected
else pd.Series(True, index=self._detections.index)
)
return self._detections[self._surveys_mask(surveys) & mask]

@lru_cache(12)
@@ -58,7 +70,9 @@ def _grouped_index(
function = "idxmax"
else:
raise ValueError(f"Unrecognized value for 'which': {which}")
return self._grouped_detections(surveys=surveys, corrected=corrected)["mjd"].agg(function)
return self._grouped_detections(surveys=surveys, corrected=corrected)[
"mjd"
].agg(function)

@lru_cache(36)
def _grouped_value(
@@ -69,24 +83,43 @@ def _grouped_value(
surveys: Tuple[str] = None,
corrected: bool = False,
) -> pd.Series:
idx = self._grouped_index(which=which, surveys=surveys, corrected=corrected)
idx = self._grouped_index(
which=which, surveys=surveys, corrected=corrected
)
df = self._select_detections(surveys=surveys, corrected=corrected)
return df[column][idx].set_axis(idx.index)

@lru_cache(6)
def _grouped_detections(self, *, surveys: Tuple[str] = None, corrected: bool = False) -> DataFrameGroupBy:
return self._group(self._select_detections(surveys=surveys, corrected=corrected))
def _grouped_detections(
self, *, surveys: Tuple[str] = None, corrected: bool = False
) -> DataFrameGroupBy:
return self._group(
self._select_detections(surveys=surveys, corrected=corrected)
)

def calculate_ndet(self) -> pd.DataFrame:
return pd.DataFrame({"ndet": self._detections.value_counts(subset=self._JOIN, sort=False)})
return pd.DataFrame(
{
"ndet": self._detections.value_counts(
subset=self._JOIN, sort=False
)
}
)

def generate_statistics(self, exclude: Set[str] = None) -> pd.DataFrame:
exclude = exclude or set() # Empty default
# Add prefix to exclude, unless already provided
exclude = {name if name.startswith(self._PREFIX) else f"{self._PREFIX}{name}" for name in exclude}
exclude = {
name if name.startswith(self._PREFIX) else f"{self._PREFIX}{name}"
for name in exclude
}

# Select all methods that start with prefix unless excluded
methods = {name for name in dir(self) if name.startswith(self._PREFIX) and name not in exclude}
methods = {
name
for name in dir(self)
if name.startswith(self._PREFIX) and name not in exclude
}

# Compute all statistics and join into single dataframe
stats = [getattr(self, method)() for method in methods]
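For context on how these statistics reach the output: generate_statistics above collects every method whose name starts with the statistics prefix and joins their results into a single dataframe, so exposing a new per-object column only requires defining one more calculate_* method. A minimal, self-contained sketch of that discovery-and-join pattern, assuming a "calculate_" prefix and toy data (class name, object ids, and values here are illustrative, not part of the PR):

import pandas as pd


class StatsSketch:
    """Toy stand-in illustrating the method-discovery pattern of BaseStatistics."""

    _PREFIX = "calculate_"  # assumed prefix; the real value is defined outside this diff

    def calculate_firstmjd(self) -> pd.DataFrame:
        return pd.DataFrame({"firstmjd": [59000.0, 59005.0]}, index=["oid1", "oid2"])

    def calculate_lastmjd(self) -> pd.DataFrame:
        return pd.DataFrame({"lastmjd": [59010.0, 59007.5]}, index=["oid1", "oid2"])

    def generate_statistics(self) -> pd.DataFrame:
        # Discover every calculate_* method, call it, and join the frames on their index
        methods = [name for name in dir(self) if name.startswith(self._PREFIX)]
        stats = [getattr(self, name)() for name in methods]
        return pd.concat(stats, axis=1, join="outer")


print(StatsSketch().generate_statistics())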
78 changes: 59 additions & 19 deletions magstats_step/magstats_step/core/magstats.py
@@ -11,10 +11,14 @@ class MagnitudeStatistics(BaseStatistics):
# Saturation threshold for each survey (only applies to corrected magnitudes)
_THRESHOLD = {"ZTF": 13.2}

def __init__(self, detections: List[dict], non_detections: List[dict] = None):
def __init__(
self, detections: List[dict], non_detections: List[dict] = None
):
super().__init__(detections)
if non_detections:
self._non_detections = pd.DataFrame.from_records(non_detections).drop_duplicates(["oid", "fid", "mjd"])
self._non_detections = pd.DataFrame.from_records(
non_detections
).drop_duplicates(["oid", "fid", "mjd"])
else:
self._non_detections = pd.DataFrame()

@@ -29,44 +33,68 @@ def _calculate_stats(self, corrected: bool = False) -> pd.DataFrame:
stats = grouped[in_label].agg(**functions)
# Pandas std requires an additional kwarg, so it has to be added separately
return stats.join(
grouped[in_label].agg("std", ddof=0).rename(out_label.format("sigma")),
grouped[in_label]
.agg("std", ddof=0)
.rename(out_label.format("sigma")),
how="outer",
)

def _calculate_stats_over_time(self, corrected: bool = False) -> pd.DataFrame:
def _calculate_stats_over_time(
self, corrected: bool = False
) -> pd.DataFrame:
suffix = "_corr" if corrected else ""
in_label, out_label = f"mag{suffix}", f"mag{{}}{suffix}"

first = self._grouped_value(in_label, which="first", corrected=corrected)
first = self._grouped_value(
in_label, which="first", corrected=corrected
)
last = self._grouped_value(in_label, which="last", corrected=corrected)
return pd.DataFrame({out_label.format("first"): first, out_label.format("last"): last})
return pd.DataFrame(
{out_label.format("first"): first, out_label.format("last"): last}
)

def calculate_statistics(self) -> pd.DataFrame:
stats = self._calculate_stats(corrected=False)
stats = stats.join(self._calculate_stats_over_time(corrected=False), how="outer")
stats = stats.join(
self._calculate_stats_over_time(corrected=False), how="outer"
)
stats = stats.join(self._calculate_stats(corrected=True), how="outer")
return stats.join(self._calculate_stats_over_time(corrected=True), how="outer")
return stats.join(
self._calculate_stats_over_time(corrected=True), how="outer"
)

def calculate_firstmjd(self) -> pd.DataFrame:
return pd.DataFrame({"firstmjd": self._grouped_value("mjd", which="first")})
return pd.DataFrame(
{"firstmjd": self._grouped_value("mjd", which="first")}
)

def calculate_lastmjd(self) -> pd.DataFrame:
return pd.DataFrame({"lastmjd": self._grouped_value("mjd", which="last")})
return pd.DataFrame(
{"lastmjd": self._grouped_value("mjd", which="last")}
)

def calculate_corrected(self) -> pd.DataFrame:
return pd.DataFrame({"corrected": self._grouped_value("corrected", which="first")})
return pd.DataFrame(
{"corrected": self._grouped_value("corrected", which="first")}
)

def calculate_stellar(self) -> pd.DataFrame:
return pd.DataFrame({"stellar": self._grouped_value("stellar", which="first")})
return pd.DataFrame(
{"stellar": self._grouped_value("stellar", which="first")}
)

def calculate_ndubious(self) -> pd.DataFrame:
return pd.DataFrame({"ndubious": self._grouped_detections()["dubious"].sum()})
return pd.DataFrame(
{"ndubious": self._grouped_detections()["dubious"].sum()}
)

def calculate_saturation_rate(self) -> pd.DataFrame:
total = self._grouped_detections()["corrected"].sum()
saturated = pd.Series(index=total.index, dtype=float)
for survey, threshold in self._THRESHOLD.items():
sat = self._grouped_detections(surveys=(survey,))["mag_corr"].agg(lambda x: (x < threshold).sum())
sat = self._grouped_detections(surveys=(survey,))["mag_corr"].agg(
lambda x: (x < threshold).sum()
)
saturated.loc[sat.index] = sat

rate = np.where(total.ne(0), saturated.astype(float) / total, np.nan)
@@ -76,24 +104,36 @@ def calculate_dmdt(self) -> pd.DataFrame:
dt_min = 0.5

if self._non_detections.size == 0: # Handle no non-detection case
return pd.DataFrame(columns=["dt_first", "dm_first", "sigmadm_first", "dmdt_first"])
return pd.DataFrame(
columns=["dt_first", "dm_first", "sigmadm_first", "dmdt_first"]
)

first_mag = self._grouped_value("mag", which="first")
first_e_mag = self._grouped_value("e_mag", which="first")
first_mjd = self._grouped_value("mjd", which="first")

nd = self._non_detections.set_index(self._JOIN) # Index by join to compute based on it
nd = self._non_detections.set_index(
self._JOIN
) # Index by join to compute based on it

dt = first_mjd - nd["mjd"]
dm = first_mag - nd["diffmaglim"]
sigmadm = first_e_mag - nd["diffmaglim"]
dmdt = (first_mag + first_e_mag - nd["diffmaglim"]) / dt

# Include back fid for grouping and unique identification
results = pd.DataFrame({"dt": dt, "dm": dm, "sigmadm": sigmadm, "dmdt": dmdt}).reset_index()
results = pd.DataFrame(
{"dt": dt, "dm": dm, "sigmadm": sigmadm, "dmdt": dmdt}
).reset_index()
# Only include non-detections at least dt_min days before the first detection
idx = self._group(results[results["dt"] > dt_min])["dmdt"].idxmin().dropna()
idx = (
self._group(results[results["dt"] > dt_min])["dmdt"]
.idxmin()
.dropna()
)

# Drop NaN, since they result from no non-detection before first detection
results = results.dropna().loc[idx].set_index(self._JOIN)
return results.rename(columns={c: f"{c}_first" for c in results.columns})
return results.rename(
columns={c: f"{c}_first" for c in results.columns}
)
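The hunks shown here cover the new deltajd field and formatting changes, but not the calculate_deltajd method itself. Given the existing calculate_firstmjd and calculate_lastmjd above, a plausible implementation is simply the time span between the first and last detection per object. A hypothetical sketch following the same pattern (not the PR's actual code):

    def calculate_deltajd(self) -> pd.DataFrame:
        # Hypothetical: elapsed time (in days) between last and first detection per object
        first = self._grouped_value("mjd", which="first")
        last = self._grouped_value("mjd", which="last")
        return pd.DataFrame({"deltajd": last - first})

Because generate_statistics picks up any method with the statistics prefix, defining such a method would be enough for deltajd to appear in the output next to firstmjd and lastmjd.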