Skip to content

Commit

Permalink
refactor: added oid in command's criteria field
Browse files Browse the repository at this point in the history
  • Loading branch information
pgallardor committed Oct 2, 2023
1 parent 27c72f1 commit 251d4f1
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def get_scribe_messages(classifications_by_classifier: pd.DataFrame):
command = {
"collection": "object",
"type": "update_probabilities",
"criteria": {"_id": idx},
"criteria": {"_id": idx, "oid": kwargs["oids"].get(idx, [])},
"data": {
"classifier_name": row["classifier_name"],
"classifier_version": kwargs["classifier_version"],
Expand All @@ -88,6 +88,7 @@ def get_scribe_messages(classifications_by_classifier: pd.DataFrame):
}
for class_name in class_names:
command["data"].update({class_name: row[class_name]})
print(command)
commands.append(command)
return classifications_by_classifier

Expand Down
22 changes: 15 additions & 7 deletions lc_classification_step/lc_classification/core/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,18 +90,26 @@ def execute(self, messages):
self.logger.debug("Messages received:\n", messages)
self.logger.info("Getting batch alert data")
model_input = create_input_dto(messages)
self.oids = {}
forced = []
prv_candidates = []
dia_objet = []
dia_object = []
for det in model_input._detections._value.iterrows():
if det[1]["forced"]:
forced.append(det[0])
if "diaObjet" in det[1].index:
dia_objet.append(det[0])
if "diaObject" in det[1].index:
dia_object.append(det[0])
if det[1]["parent_candid"] is not None:
prv_candidates.append(det[0])
if "diaObjet" in det[1].index:
dia_objet.append(det[0])
if "diaObject" in det[1].index:
dia_object.append(det[0])
# oid hack for ztf
if self.isztf:
oids = self.oids.get(det[0], [])
if det[1]["oid"] not in oids:
oids.append(det[1]["oid"])
self.oids[det[0]] = oids

if not self.model.can_predict(model_input):
self.logger.info("No data to process")
return (
Expand All @@ -117,7 +125,7 @@ def execute(self, messages):
f"The prv candidates detections are: {prv_candidates}"
)
self.logger.debug(
f"The aids for detections that are forced photometry or prv candidates and do not have the diaObjet field are:{dia_objet}"
f"The aids for detections that are forced photometry or prv candidates and do not have the diaObjet field are:{dia_object}"
)
self.logger.info(
"The number of features is: %i", len(model_input.features)
Expand Down Expand Up @@ -148,7 +156,7 @@ def execute(self, messages):

def post_execute(self, result: Tuple[OutputDTO, List[dict]]):
parsed_result = self.scribe_parser.parse(
result[0], classifier_version=self.classifier_version
result[0], classifier_version=self.classifier_version, oids=self.oids
)
self.produce_scribe(parsed_result.value)
return result

0 comments on commit 251d4f1

Please sign in to comment.