Skip to content

Commit

Permalink
Merge pull request #366 from alercebroker/refactor/magstat_using_oid
Browse files Browse the repository at this point in the history
Magstats step using oid
  • Loading branch information
dirodriguezm authored Jan 2, 2024
2 parents 8f3e200 + 7d210ea commit eecf416
Show file tree
Hide file tree
Showing 9 changed files with 201 additions and 100 deletions.
2 changes: 1 addition & 1 deletion magstats_step/magstats_step/core/magstats.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


class MagnitudeStatistics(BaseStatistics):
_JOIN = ["aid", "sid", "fid"]
_JOIN = ["oid", "sid", "fid"]
# Saturation threshold for each survey (only applies to corrected magnitudes)
_THRESHOLD = {"ZTF": 13.2}

Expand Down
4 changes: 2 additions & 2 deletions magstats_step/magstats_step/core/objstats.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


class ObjectStatistics(BaseStatistics):
_JOIN = "aid"
_JOIN = "oid"

def __init__(self, detections: List[dict]):
super().__init__(detections)
Expand Down Expand Up @@ -47,7 +47,7 @@ def average(series): # Needs wrapper to use the sigmas in the agg call
return self._weighted_mean(series, sigmas.loc[series.index])

sigmas = self._arcsec2deg(self._detections[f"e_{label}"])
grouped_sigmas = self._group(sigmas.set_axis(self._detections["aid"]))
grouped_sigmas = self._group(sigmas.set_axis(self._detections["oid"]))
return pd.DataFrame(
{
f"mean{label}": self._grouped_detections()[label].agg(average),
Expand Down
22 changes: 11 additions & 11 deletions magstats_step/magstats_step/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ def _execute(self, messages: dict):
magstats = magstats_calculator.generate_statistics(
self.excluded
).reset_index()
magstats = magstats.set_index("aid").replace({np.nan: None})
for aid in stats:
magstats = magstats.set_index("oid").replace({np.nan: None})
for oid in stats:
try:
stats[aid]["magstats"] = magstats.loc[aid].to_dict("records")
stats[oid]["magstats"] = magstats.loc[oid].to_dict("records")
except TypeError:
stats[aid]["magstats"] = [magstats.loc[aid].to_dict()]
stats[oid]["magstats"] = [magstats.loc[oid].to_dict()]

return stats

Expand All @@ -63,12 +63,12 @@ def _execute_ztf(self, messages: dict):
magstats = magstats_calculator.generate_statistics(
self.excluded
).reset_index()
magstats = magstats.set_index("aid").replace({np.nan: None})
for aid in stats:
magstats = magstats.set_index("oid").replace({np.nan: None})
for oid in stats:
try:
stats[aid]["magstats"] = magstats.loc[aid].to_dict("records")
stats[oid]["magstats"] = magstats.loc[oid].to_dict("records")
except TypeError:
stats[aid]["magstats"] = [magstats.loc[aid].to_dict()]
stats[oid]["magstats"] = [magstats.loc[oid].to_dict()]

return stats

Expand All @@ -80,11 +80,11 @@ def execute(self, messages: dict):

# it seems that we'll have to produce different commands in this
def produce_scribe(self, result: dict):
for aid, stats in result.items():
for oid, stats in result.items():
command = {
"collection": "object",
"type": "update",
"criteria": {"_id": aid},
"criteria": {"_id": oid},
"data": stats
| {
"loc": {
Expand All @@ -107,7 +107,7 @@ def produce_scribe_ztf(self, result: dict):
{
"collection": "magstats",
"type": "upsert",
"criteria": {"oid": oid},
"criteria": {"_id": oid},
"data": stats,
}
for oid in oids
Expand Down
11 changes: 8 additions & 3 deletions magstats_step/tests/integration/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import uuid
import os
import pathlib

from confluent_kafka.admin import AdminClient, NewTopic
from apf.producers import KafkaProducer
Expand Down Expand Up @@ -56,7 +57,7 @@ def env_variables():
env_variables_dict = {
"PRODUCER_SCHEMA_PATH": "",
"CONSUMER_SCHEMA_PATH": "",
"METRIS_SCHEMA_PATH": "../schemas/magstats_step//metrics.json",
"METRIS_SCHEMA_PATH": "../schemas/magstats_step/metrics.json",
"SCRIBE_SCHEMA_PATH": "../schemas/scribe_step/scribe.avsc",
"CONSUMER_SERVER": "localhost:9092",
"CONSUMER_TOPICS": "correction",
Expand Down Expand Up @@ -95,8 +96,12 @@ def produce_messages(topic):
{
"PARAMS": {"bootstrap.servers": "localhost:9092"},
"TOPIC": topic,
"SCHEMA_PATH": os.path.join(
os.path.dirname(__file__), "../../schema.avsc"
"SCHEMA_PATH": str(
pathlib.Path(
pathlib.Path(__file__).parent.parent.parent.parent,
"schemas/correction_step",
"output.avsc",
)
),
}
)
Expand Down
2 changes: 1 addition & 1 deletion magstats_step/tests/integration/test_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
def assert_message_schema(command):
if command["collection"] == "magstats":
assert command["type"] == "upsert"
assert "oid" in command["criteria"]
assert "_id" in command["criteria"]
elif command["collection"] == "object":
assert command["type"] == "update"
assert "_id" in command["criteria"]
Expand Down
21 changes: 14 additions & 7 deletions magstats_step/tests/unittests/data/messages.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
import random
import pathlib

from fastavro import schema
from fastavro import utils

SCHEMA = schema.load_schema("schema.avsc")
SCHEMA_PATH = str(
pathlib.Path(
pathlib.Path(__file__).parent.parent.parent.parent.parent,
"schemas/correction_step",
"output.avsc",
)
)

SCHEMA = schema.load_schema(SCHEMA_PATH)
random.seed(42)

aids_pool = [f"AID22X{i}" for i in range(10)]
Expand All @@ -12,20 +21,18 @@
data = list(utils.generate_many(SCHEMA, 10))
for d in data:
aid = random.choice(aids_pool)
d["aid"] = aid
sid = "ZTF" if random.random() < 0.5 else "ATLAS"
oid = random.choice(oids_pool)
d["oid"] = oid
sid = "ZTF" if random.random() < 0.5 else "ATLAS"
for detection in d["detections"]:
detection["aid"] = aid
detection["oid"] = oid
detection["sid"] = sid
detection["fid"] = "g" if random.random() < 0.5 else "r"
detection["forced"] = False

if sid == "ZTF":
detection["oid"] = oid
for non_detection in d["non_detections"]:
non_detection["aid"] = aid
non_detection["oid"] = oid
non_detection["sid"] = sid
non_detection["fid"] = "g" if random.random() < 0.5 else "r"
if sid == "ZTF":
non_detection["oid"] = oid
Loading

0 comments on commit eecf416

Please sign in to comment.