Skip to content

Commit

Permalink
Merge pull request #502 from alercebroker/fix/wl_filter_parsing
Browse files Browse the repository at this point in the history
Fix parsing problems in watchlists
  • Loading branch information
HectorxH authored Nov 14, 2024
2 parents 46c4a4a + 61c4445 commit 1b203c0
Show file tree
Hide file tree
Showing 15 changed files with 163 additions and 88 deletions.
4 changes: 3 additions & 1 deletion watchlist_step/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -137,4 +137,6 @@ dmypy.json
# Cython debug symbols
cython_debug/

.ruff_cache
.ruff_cache

sample.csv
8 changes: 6 additions & 2 deletions watchlist_step/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ packages = [{ include = "watchlist_step" }]

[tool.poetry.scripts]
step = "scripts.run_step:step"
sample = "scripts.make_sample:run"

[tool.poetry.dependencies]
python = "~3.10"
Expand All @@ -24,8 +25,11 @@ pytest-cov = "^4.1.0"
black = "^23.0.0"
ruff = "^0.3.5"

[tool.black]
line-length = 88
[tool.ruff]
line-length = 80

[tool.pyright]
typeCheckingMode = "basic"

[build-system]
requires = ["poetry-core"]
Expand Down
48 changes: 48 additions & 0 deletions watchlist_step/scripts/make_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import os
from typing import Any

from psycopg2.sql import SQL

from watchlist_step.db.connection import DatabaseConfig, PsqlDatabase

config: DatabaseConfig = {
"HOST": os.environ["ALERTS_DB_HOST"],
"USER": os.environ["ALERTS_DB_USER"],
"PASSWORD": os.environ["ALERTS_DB_PASSWORD"],
"PORT": os.getenv("ALERTS_DB_PORT", "5432"),
"DB_NAME": os.environ["ALERTS_DB_NAME"],
}


def run(p: float = 0.0025, seed: int = 42, radius: float = 1):
alerts_db = PsqlDatabase(config)
rows: list[tuple[Any, ...]]
with alerts_db.conn() as conn:
with conn.cursor() as cursor:
query = SQL("""
SELECT
p.oid,
meanra,
meandec
FROM
probability p
TABLESAMPLE SYSTEM (%s) REPEATABLE (%s)
JOIN
object o
ON
o.oid = p.oid
WHERE
classifier_name = 'lc_classifier'
AND class_name = 'AGN'
AND ranking = 1
""")
cursor.execute(query, (p * 100.0, seed))
rows = cursor.fetchall()
with open("./sample.csv", "w+") as file:
file.write("name,ra,dec,radius\n")
file.writelines(
map(
lambda row: f"{row[0]},{row[1]:.5f},{row[2]:.4f},{radius}\n",
rows,
)
)
2 changes: 0 additions & 2 deletions watchlist_step/scripts/run_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import os
import sys

from apf.core import get_class

import settings
from watchlist_step.step import WatchlistStep

Expand Down
21 changes: 16 additions & 5 deletions watchlist_step/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
CONSUMER_CONFIG["TOPIC_STRATEGY"] = {
"CLASS": "apf.core.topic_management.DailyTopicStrategy",
"PARAMS": {
"topic_format": os.environ["TOPIC_STRATEGY_FORMAT"].strip().split(","),
"topic_format": os.environ["TOPIC_STRATEGY_FORMAT"]
.strip()
.split(","),
"date_format": "%Y%m%d",
"change_hour": 23,
},
Expand Down Expand Up @@ -57,7 +59,10 @@
"description": "The root schema comprises the entire JSON document.",
"default": {},
"examples": [
{"timestamp_sent": "2020-09-01", "timestamp_received": "2020-09-01"}
{
"timestamp_sent": "2020-09-01",
"timestamp_received": "2020-09-01",
}
],
"required": ["timestamp_sent", "timestamp_received"],
"properties": {
Expand All @@ -83,11 +88,17 @@
},
}

if os.getenv("CONSUMER_KAFKA_USERNAME") and os.getenv("CONSUMER_KAFKA_PASSWORD"):
if os.getenv("CONSUMER_KAFKA_USERNAME") and os.getenv(
"CONSUMER_KAFKA_PASSWORD"
):
CONSUMER_CONFIG["PARAMS"]["security.protocol"] = "SASL_SSL"
CONSUMER_CONFIG["PARAMS"]["sasl.mechanism"] = "SCRAM-SHA-512"
CONSUMER_CONFIG["PARAMS"]["sasl.username"] = os.getenv("CONSUMER_KAFKA_USERNAME")
CONSUMER_CONFIG["PARAMS"]["sasl.password"] = os.getenv("CONSUMER_KAFKA_PASSWORD")
CONSUMER_CONFIG["PARAMS"]["sasl.username"] = os.getenv(
"CONSUMER_KAFKA_USERNAME"
)
CONSUMER_CONFIG["PARAMS"]["sasl.password"] = os.getenv(
"CONSUMER_KAFKA_PASSWORD"
)
if os.getenv("METRICS_KAFKA_USERNAME") and os.getenv("METRICS_KAFKA_PASSWORD"):
METRICS_CONFIG["PARAMS"]["PARAMS"]["security.protocol"] = "SASL_SSL"
METRICS_CONFIG["PARAMS"]["PARAMS"]["sasl.mechanism"] = "SCRAM-SHA-512"
Expand Down
15 changes: 6 additions & 9 deletions watchlist_step/tests/integration/conftest.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
import os
from io import BytesIO
import pathlib

import psycopg2
import pytest
from apf.producers import KafkaProducer, KafkaSchemalessProducer
from apf.consumers import KafkaConsumer
from confluent_kafka import Producer
from apf.producers import KafkaSchemalessProducer
from confluent_kafka.admin import AdminClient, NewTopic
from fastavro.schema import load_schema
from fastavro.utils import generate_many
from confluent_kafka.admin import AdminClient, NewTopic
from fastavro import writer, parse_schema

from watchlist_step.db.connection import PsqlDatabase
from tests.integration.mocks.mock_alerts import ztf_extra_fields_generator


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -53,10 +49,11 @@ def produce_message(topic):
message["dec"] = 53.34521158573315
message["candid"] = str(1000151433015015014 + i)
producer.produce(message)

producer.producer.flush()
del producer


def consume_message(config):
consumer = KafkaConsumer(config)
for msg in consumer.consume():
Expand All @@ -72,7 +69,7 @@ def is_responsive_kafka(url):
try:
f.result()
return True
except Exception as e:
except Exception:
return False


Expand Down Expand Up @@ -101,7 +98,7 @@ def is_responsive_users_database(docker_ip, port):
)
conn.close()
return True
except Exception as e:
except Exception:
return False


Expand Down
13 changes: 7 additions & 6 deletions watchlist_step/tests/integration/test_step_integration.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import pytest
import pathlib
from apf.consumers import KafkaConsumer
from watchlist_step.step import WatchlistStep

import pytest

from watchlist_step.step import WatchlistStep

SORTING_HAT_SCHEMA_PATH = pathlib.Path(
pathlib.Path(__file__).parent.parent.parent.parent,
"schemas/sorting_hat_step",
"output.avsc",
)


@pytest.fixture
def step_creator():
def create_step(strategy_name, config):
Expand All @@ -30,7 +31,7 @@ class TestStep:
"bootstrap.servers": "localhost:9092",
"group.id": "test_integration",
"auto.offset.reset": "beginning",
"enable.partition.eof": True
"enable.partition.eof": True,
},
"consume.timeout": 30,
"consume.messages": 2,
Expand All @@ -45,7 +46,7 @@ class TestStep:
"PASSWORD": "postgres",
"PORT": 5432,
"DB_NAME": "postgres",
}
},
}

def test_should_insert_matches_if_matches_returned(
Expand All @@ -55,7 +56,7 @@ def test_should_insert_matches_if_matches_returned(
step_creator,
):
self.consumer_config["PARAMS"]["group.id"] = "test_integration"
#consumer = KafkaConsumer(self.consumer_config)
# consumer = KafkaConsumer(self.consumer_config)
strategy_name = "SortingHat"
step = step_creator(
strategy_name,
Expand Down
2 changes: 1 addition & 1 deletion watchlist_step/tests/unittest/test_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def wl_step():
"PASSWORD": "password",
"PORT": 5433,
"DB_NAME": "postgres",
}
},
}
return WatchlistStep(
strategy_name=strategy_name,
Expand Down
39 changes: 26 additions & 13 deletions watchlist_step/tests/unittest/test_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,22 @@
class TestSortingHatStrategy:
sorting_hat_strat = SortingHatStrategy()
alerts = {
(3, 4): {"ra": 1, "dec": 2, "oid": 3, "candid": 4, "mjd": 10, "mag": 12},
(7, 8): {"ra": 5, "dec": 6, "oid": 7, "candid": 8, "mjd": 11, "mag": 13},
(3, 4): {
"ra": 1,
"dec": 2,
"oid": 3,
"candid": 4,
"mjd": 10,
"mag": 12,
},
(7, 8): {
"ra": 5,
"dec": 6,
"oid": 7,
"candid": 8,
"mjd": 11,
"mag": 13,
},
(1, 2): {"ra": 7, "dec": 5, "oid": 1, "candid": 2, "mjd": 15, "mag": 7},
}

Expand All @@ -14,7 +28,10 @@ class TestSortingHatStrategy:
3,
4,
100,
{"fields": {"sorting_hat": ["mag"], "features": ["a", "b"]}, "filters": {}},
{
"fields": {"sorting_hat": ["mag"], "features": ["a", "b"]},
"filters": {},
},
),
(
7,
Expand Down Expand Up @@ -42,14 +59,10 @@ def test_get_coordinates(self):
assert result[1] == (5, 6, 7, 8)

def test_get_new_values(self):
result_values, result_filters = self.sorting_hat_strat.get_new_values(
self.matches, self.alerts
)

assert len(result_filters) == 2
assert result_filters[0] == self.matches[0][3]
assert result_filters[1] == self.matches[1][3]
new_values = self.sorting_hat_strat.get_new_values(self.matches, self.alerts)

assert len(result_values) == 2
assert result_values[0] == (3, 4, 100, {"mag": 12})
assert result_values[1] == (7, 8, 101, {"mag": 13, "mjd": 11})
assert len(new_values) == 3
assert new_values[0][-1] == self.matches[0][3]
assert new_values[1][-1] == self.matches[1][3]
assert new_values[0][:-1] == (3, 4, 100, {"mag": 12})
assert new_values[1][:-1] == (7, 8, 101, {"mag": 13, "mjd": 11})
9 changes: 8 additions & 1 deletion watchlist_step/watchlist_step/db/connection.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from typing import TypedDict

import psycopg2

DatabaseConfig = TypedDict(
"DatabaseConfig",
{"USER": str, "PASSWORD": str, "HOST": str, "PORT": str, "DB_NAME": str},
)


class PsqlDatabase:
def __init__(self, config: dict) -> None:
def __init__(self, config: DatabaseConfig) -> None:
url = self.__format_db_url(config)
self.conn = lambda: psycopg2.connect(url)

Expand Down
8 changes: 1 addition & 7 deletions watchlist_step/watchlist_step/db/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,8 @@ def create_match_query(len, base_radius=30 / 3600):
watchlist_target.dec,
positions.ra,
positions.dec,
{}
LEAST(watchlist_target.radius, {})
)
AND q3c_dist(
watchlist_target.ra,
watchlist_target.dec,
positions.ra,
positions.dec
) < watchlist_target.radius
"""
).format(
SQL(", ").join(SQL("(%s, %s, %s, %s)") for _ in range(len)),
Expand Down
4 changes: 2 additions & 2 deletions watchlist_step/watchlist_step/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ def satisfies_filter(values: dict, type: str, params: dict) -> bool:
match type:
case "constant":
return constant(values, **params)
case "all":
case "and":
return _all(values, **params)
case "any":
case "or":
return _any(values, **params)
case "no filter":
return True
Expand Down
Loading

0 comments on commit 1b203c0

Please sign in to comment.