Refactor/key oid candid alerts #378

Merged
merged 16 commits on Jan 17, 2024
Changes from all commits
3 changes: 2 additions & 1 deletion correction_step/correction/_step/step.py
@@ -135,6 +135,7 @@ def produce_scribe(self, detections: list[dict]):
             if not detection.pop("new"):
                 continue
             candid = detection.pop("candid")
+            oid = detection.get("oid")
             is_forced = detection.pop("forced")
             set_on_insert = not detection.get("has_stamp", False)
             extra_fields = detection["extra_fields"].copy()
@@ -150,7 +151,7 @@ def produce_scribe(self, detections: list[dict]):
             scribe_data = {
                 "collection": "forced_photometry" if is_forced else "detection",
                 "type": "update",
-                "criteria": {"_id": candid},
+                "criteria": {"candid": candid, "oid": oid},
                 "data": detection,
                 "options": {"upsert": True, "set_on_insert": set_on_insert},
             }
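With this change, two alerts that share a candid but belong to different objects no longer collide on a single _id. For reference, a minimal sketch of the command the step now sends to the scribe producer; the field values are illustrative, not taken from the PR:

scribe_data = {
    "collection": "detection",  # or "forced_photometry" when is_forced is set
    "type": "update",
    "criteria": {"candid": 123456789, "oid": "ZTF21aaaaaaa"},  # compound key replaces {"_id": candid}
    "data": {"mag": 17.3, "e_mag": 0.05},
    "options": {"upsert": True, "set_on_insert": False},
}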
26 changes: 20 additions & 6 deletions correction_step/correction/core/corrector.py
@@ -26,13 +26,16 @@ def __init__(self, detections: list[dict]):
         """
         self.logger = logging.getLogger(f"alerce.{self.__class__.__name__}")
         self._detections = pd.DataFrame.from_records(detections, exclude={"extra_fields"})
-        self._detections = self._detections.drop_duplicates("candid").set_index("candid")
+        self._detections = self._detections.drop_duplicates(["candid", "oid"]).set_index("candid")

-        self.__extras = {alert["candid"]: alert["extra_fields"] for alert in detections}
-        extras = pd.DataFrame.from_dict(self.__extras, orient="index", columns=self._EXTRA_FIELDS)
-        extras = extras.reset_index(names=["candid"]).drop_duplicates("candid").set_index("candid")
+        self.__extras = [
+            {**alert["extra_fields"], "candid": alert["candid"], "oid": alert["oid"]} for alert in detections
+        ]
+        extras = pd.DataFrame(self.__extras, columns=self._EXTRA_FIELDS + ["candid", "oid"])
+        extras = extras.drop_duplicates(["candid", "oid"]).set_index("candid")

-        self._detections = self._detections.join(extras)
+        self._detections = self._detections.join(extras, how="left", rsuffix="_extra")
+        self._detections = self._detections.drop("oid_extra", axis=1)

     def _survey_mask(self, survey: str):
         """Creates boolean mask of detections whose `sid` matches the given survey name (case-insensitive)
@@ -102,14 +105,25 @@ def corrected_as_records(self) -> list[dict]:
         The records are a list of mappings with the original input pairs and the new pairs together.
         """

+        def find_extra_fields(oid, candid):
+            for extra in self.__extras:
+                if extra["oid"] == oid and extra["candid"] == candid:
+                    result = {**extra}
+                    result.pop("oid")
+                    result.pop("candid")
+                    return result
+            return None
+
         self.logger.debug(f"Correcting {len(self._detections)} detections...")
         corrected = self.corrected_magnitudes().replace(np.inf, self._ZERO_MAG)
         corrected = corrected.assign(corrected=self.corrected, dubious=self.dubious, stellar=self.stellar)
         corrected = self._detections.join(corrected).replace(np.nan, None).drop(columns=self._EXTRA_FIELDS)
         corrected = corrected.replace(-np.inf, None)
         self.logger.debug(f"Corrected {corrected['corrected'].sum()}")
         corrected = corrected.reset_index().to_dict("records")
-        return [{**record, "extra_fields": self.__extras[record["candid"]]} for record in corrected]
+
+        return [{**record, "extra_fields": find_extra_fields(record["oid"], record["candid"])} for record in corrected]

     @staticmethod
     def weighted_mean(values: pd.Series, sigmas: pd.Series) -> float:
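Note that find_extra_fields performs a linear scan of self.__extras for every corrected record. Should that become a bottleneck, a dict keyed by the (oid, candid) pair would give constant-time lookups; a hypothetical alternative, not part of this PR:

# Build once, assuming (oid, candid) pairs are unique after deduplication:
extras_by_key = {
    (e["oid"], e["candid"]): {k: v for k, v in e.items() if k not in ("oid", "candid")}
    for e in self.__extras
}
# Then, per corrected record:
extra_fields = extras_by_key.get((record["oid"], record["candid"]))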
2 changes: 1 addition & 1 deletion correction_step/tests/integration/test_step.py
@@ -43,7 +43,7 @@ def assert_scribe_has_detections(message):
     data = json.loads(message["payload"])
     assert "collection" in data and data["collection"] == "detection"
     assert "type" in data and data["type"] == "update"
-    assert "criteria" in data and "_id" in data["criteria"]
+    assert "criteria" in data and "candid" in data["criteria"] and "oid" in data["criteria"]
     assert "data" in data and len(data["data"]) > 0
     assert "options" in data and "upsert" in data["options"] and "set_on_insert" in data["options"]
     assert data["options"]["upsert"] is True
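For reference, a payload that satisfies the updated assertion might look as follows; the contents are illustrative:

import json

message = {"payload": json.dumps({
    "collection": "detection",
    "type": "update",
    "criteria": {"candid": "c1", "oid": "OID1"},
    "data": {"mag": 17.0},
    "options": {"upsert": True, "set_on_insert": False},
})}
assert_scribe_has_detections(message)  # passes under the compound-key schema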
14 changes: 8 additions & 6 deletions correction_step/tests/unittests/test_corrector.py
@@ -8,7 +8,7 @@
 from correction import Corrector
 from tests.utils import ztf_alert, atlas_alert

-detections = [ztf_alert(candid="c1"), atlas_alert(candid="c2")]
+detections = [ztf_alert(candid="c1", oid="oid_ztf"), atlas_alert(candid="c2", oid="oid_atlas")]
 MAG_CORR_COLS = ["mag_corr", "e_mag_corr", "e_mag_corr_ext"]
 ALL_NEW_COLS = MAG_CORR_COLS + ["dubious", "stellar", "corrected"]

@@ -178,11 +178,12 @@ def test_calculate_coordinates_ignores_forced_photometry():


 def test_coordinates_dataframe_calculates_mean_for_each_aid():
-    corrector = Corrector(detections)
+    detections_duplicate = [ztf_alert(candid="c"), atlas_alert(candid="c")]
+    corrector = Corrector(detections_duplicate)
     assert corrector.mean_coordinates().index == ["OID1"]

-    altered_detections = deepcopy(detections)
-    altered_detections[0]["oid"] = "OID2"
+    altered_detections = deepcopy(detections_duplicate)
+    altered_detections[0]["oid"] = "OID1"
     corrector = Corrector(altered_detections)
     assert corrector.mean_coordinates().index.isin(["OID1", "OID2"]).all()

@@ -193,10 +194,11 @@ def test_coordinates_dataframe_includes_mean_ra_and_mean_dec():


 def test_coordinates_records_has_one_entry_per_aid():
-    corrector = Corrector(detections)
+    test_detections = [ztf_alert(candid="c1"), atlas_alert(candid="c2")]
+    corrector = Corrector(test_detections)
     assert set(corrector.coordinates_as_records()) == {"OID1"}

-    altered_detections = deepcopy(detections)
+    altered_detections = deepcopy(test_detections)
     altered_detections[0]["oid"] = "OID2"
     corrector = Corrector(altered_detections)
     assert set(corrector.coordinates_as_records()) == {"OID1", "OID2"}
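The fixtures now receive explicit oids so the (candid, oid) deduplication is exercised directly. Assuming ztf_alert and atlas_alert are dict factories that accept keyword overrides (their definitions live in tests/utils), the case the new tests care about is:

shared_candid = [ztf_alert(candid="c", oid="OID1"), atlas_alert(candid="c", oid="OID2")]
# Both alerts survive Corrector's drop_duplicates(["candid", "oid"]),
# whereas the old drop_duplicates("candid") would have discarded one of them.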
3 changes: 2 additions & 1 deletion correction_step/tests/unittests/test_step.py
@@ -160,12 +160,13 @@ def __init__(self):
         data = {
             "collection": "detection" if not det["forced"] else "forced_photometry",
             "type": "update",
-            "criteria": {"_id": det["candid"]},
+            "criteria": {"candid": det["candid"], "oid": det["oid"]},
             "data": {k: v for k, v in det.items() if k not in ["candid", "forced", "new"]},
             "options": {"upsert": True, "set_on_insert": not det["has_stamp"]},
         }
         if count == len(message4execute_copy["detections"]):
             flush = True
+
         step.scribe_producer.produce.assert_any_call({"payload": json.dumps(data)}, flush=flush)
8 changes: 6 additions & 2 deletions feature_step/features/core/handlers/_base.py
@@ -94,11 +94,15 @@ def __init__(

         if self.UNIQUE:
             self._alerts.drop_duplicates(self.UNIQUE, inplace=True)
             self.logger.debug(
                 f"{len(self._alerts)} {self._NAME} remain after unique removal"
             )
+        if self.NON_DUPLICATE:
+            self._alerts.drop_duplicates(self.NON_DUPLICATE, inplace=True)
+            self.logger.debug(
+                f"{len(self._alerts)} {self._NAME} remain after duplicate removal"
+            )
         if self.INDEX:
-            self._alerts.drop_duplicates(self.INDEX, inplace=True)
             self._alerts.set_index(self.INDEX, inplace=True)
             self.logger.debug(f"Using column(s) {self.INDEX} for indexing")

@@ -309,7 +313,7 @@ def __add_extra_fields(self, alerts: list[dict], extras: list[str]):
             )
             df = (
                 df.reset_index(names=[self.INDEX])
-                .drop_duplicates(self.INDEX)
+                .drop_duplicates(self.NON_DUPLICATE)
                 .set_index(self.INDEX)
             )
             self._alerts = self._alerts.join(df)
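The handler now applies up to three deduplication passes in order: UNIQUE (repeated measurements), NON_DUPLICATE (the new (oid, candid) pair), and finally indexing on INDEX without dropping candid collisions across objects. A condensed sketch of that pipeline with toy data, not taken from the PR:

import pandas as pd

UNIQUE, NON_DUPLICATE, INDEX = ["id", "fid", "mjd"], ["oid", "candid"], "candid"
alerts = pd.DataFrame.from_records([
    {"id": "a", "fid": 1, "mjd": 59000.0, "oid": "OID1", "candid": "c1"},
    {"id": "a", "fid": 1, "mjd": 59000.0, "oid": "OID1", "candid": "c1"},  # removed by UNIQUE
    {"id": "b", "fid": 1, "mjd": 59001.0, "oid": "OID2", "candid": "c1"},  # kept: same candid, new oid
])
alerts = alerts.drop_duplicates(UNIQUE)         # 2 rows remain
alerts = alerts.drop_duplicates(NON_DUPLICATE)  # still 2 rows
alerts = alerts.set_index(INDEX)                # duplicate "c1" index labels are allowed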
1 change: 1 addition & 0 deletions feature_step/features/core/handlers/detections.py
@@ -28,6 +28,7 @@ class DetectionsHandler(BaseHandler):

     _NAME = "detections"
     INDEX = "candid"
+    NON_DUPLICATE = ["oid", "candid"]
     UNIQUE = ["id", "fid", "mjd"]
     COLUMNS = BaseHandler.COLUMNS + [
         "mag",
1 change: 1 addition & 0 deletions feature_step/features/core/handlers/non_detections.py
@@ -21,6 +21,7 @@ class NonDetectionsHandler(BaseHandler):

     _NAME = "non-detections"
     UNIQUE = ["id", "fid", "mjd"]
+    NON_DUPLICATE = ["oid", "candid"]
     COLUMNS = BaseHandler.COLUMNS + ["diffmaglim"]

     def _post_process(self, **kwargs):
13 changes: 7 additions & 6 deletions feature_step/features/utils/parsers.py
@@ -43,7 +43,7 @@ def get_fid(feature_name: str):
         command = {
             "collection": "object",
             "type": "update_features",
-            "criteria": {"_id": oid},
+            "criteria": {"_id": oid},  # this is wrong, it should be oid: oid, candid: candid? the candid would need to be passed in
             "data": {
                 "features_version": extractor_class.VERSION,
                 "features_group": extractor_class.NAME,
@@ -71,7 +71,7 @@ def _parse_scribe_payload_ztf(features, extractor_class):
         command = {
             "collection": "object",
             "type": "update_features",
-            "criteria": {"_id": oid},
+            "criteria": {"oid": oid},
             "data": {
                 "features_version": extractor_class.VERSION,
                 "features_group": extractor_class.NAME,
@@ -128,7 +128,7 @@ def _parse_output_elasticc(features, alert_data, extractor_class, candids):
         oid = message["oid"]
         candid = candids[oid]
         try:
-            features_dict = features.loc[oid].to_dict()
+            features_dict = features.loc[oid].iloc[0].to_dict()
         except KeyError:  # No feature for the object
             logger = logging.getLogger("alerce")
             logger.info("Could not calculate features of object %s", oid)
@@ -161,18 +161,19 @@ def _parse_output_ztf(features, alert_data, extractor_class, candids):
         oid = message["oid"]
         candid = candids[oid]
         try:
-            features_dict = features.loc[oid].to_dict()
+            features_for_oid = features.loc[oid].to_dict()
+            features_for_oid = features_for_oid if isinstance(features_for_oid, dict) else features_for_oid[0]
         except KeyError:  # No feature for the object
             logger = logging.getLogger("alerce")
             logger.info("Could not calculate features of object %s", oid)
-            features_dict = None
+            features_for_oid = None
         out_message = {
             "oid": oid,
             "candid": candid,
             "detections": message["detections"],
             "non_detections": message["non_detections"],
             "xmatches": message["xmatches"],
-            "features": features_dict,
+            "features": features_for_oid,
         }
         output_messages.append(out_message)
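The .iloc[0] guard in _parse_output_elasticc, and the isinstance fallback in _parse_output_ztf, deal with the fact that a features frame indexed by oid can now carry repeated index labels, in which case .loc returns a DataFrame rather than a Series. A minimal reproduction with toy data:

import pandas as pd

features = pd.DataFrame(
    {"Amplitude": [0.5, 0.7]},
    index=pd.Index(["OID1", "OID1"], name="oid"),  # duplicated index label
)
print(type(features.loc["OID1"]))                       # DataFrame, not Series
features_dict = features.loc["OID1"].iloc[0].to_dict()  # {'Amplitude': 0.5}: first row wins

Note that DataFrame.to_dict() also returns a dict, so the ZTF variant's isinstance(..., dict) fallback to [0] may never actually trigger; the .iloc[0] pattern used in the elasticc path sidesteps that ambiguity.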