Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Atlas features placeholder #325

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/staging_ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ jobs:
poetry run python main.py deploy add-package alert-archive-step --chart=alert-archive-step --values=alert_archive_step-helm-values --chart-folder=alert_archiving_step
poetry run python main.py deploy add-package correction-step --values=correction_step-helm-values --chart-folder=correction_step
poetry run python main.py deploy add-package early-classification-step --chart=early-classifier --values=early_classification_step-helm-values --chart-folder=early_classification_step
poetry run python main.py deploy add-package feature-step-atlas --values=feature_step_atlas-helm-values --chart-folder=feature_step
poetry run python main.py deploy add-package feature-step --values=feature_step-helm-values --chart-folder=feature_step
poetry run python main.py deploy add-package lc-classifier-step-ztf --chart=lc-classifier-step --values=lc_classification_step-helm-values --chart-folder=lc_classification_step
poetry run python main.py deploy add-package lightcurve-step --values=lightcurve-step-helm-values
Expand Down
21 changes: 21 additions & 0 deletions feature_step/features/core/atlas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import List
import pandas as pd
from importlib import metadata as pymetadata


class AtlasFeatureExtractor:
    """Placeholder feature extractor for the ATLAS survey.

    Mirrors the constructor signature of the other survey extractors so it
    can be returned by the extractor ``selector``; it currently computes no
    features at all.
    """

    # Name under which these features are registered downstream.
    NAME = "atlas_lc_features"
    # Read from the installed ``feature-step`` package metadata at import
    # time; raises ``PackageNotFoundError`` if the package is not installed.
    VERSION = pymetadata.version("feature-step")
    # Surveys this extractor is able to process.
    SURVEYS = ("ATLAS",)

    def __init__(
        self,
        detections: List[dict],
        non_detections: List[dict],
        xmatch: List[dict],
        **kwargs,
    ):
        # Placeholder: the inputs are accepted only for interface parity
        # with the ZTF/ELAsTiCC extractors and are deliberately ignored.
        pass

    def generate_features(self):
        """Return an empty DataFrame — no ATLAS features implemented yet."""
        return pd.DataFrame()
17 changes: 17 additions & 0 deletions feature_step/features/core/handlers/detections.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
from ._base import BaseHandler


class NoDetectionsException(Exception):
    """Raised when a detections handler is constructed with no alerts."""


class DetectionsHandler(BaseHandler):
"""Class for handling detections.

Expand Down Expand Up @@ -37,6 +41,19 @@ class DetectionsHandler(BaseHandler):
"forced",
]

def __init__(
    self,
    alerts: list[dict],
    *,
    surveys: str | tuple[str, ...] = (),
    bands: str | tuple[str, ...] = (),
    **kwargs,
):
    """Initialize the handler, rejecting empty input.

    Args:
        alerts: Alert dictionaries to handle; must be non-empty.
        surveys: Survey name(s) used to filter the alerts.
        bands: Band name(s) used to filter the alerts.
        **kwargs: Extra options forwarded to ``BaseHandler``.

    Raises:
        NoDetectionsException: If ``alerts`` is empty.
    """
    # Idiomatic emptiness check (was `len(alerts) == 0`); an empty batch
    # cannot produce features, so fail fast before the base handler runs.
    if not alerts:
        raise NoDetectionsException(
            "Cannot create DetectionsHandler without detections"
        )

    super().__init__(alerts, surveys=surveys, bands=bands, **kwargs)

def _post_process(self, **kwargs):
"""Handles legacy alerts (renames old field names to the new conventions) and sets the
`mag_ml` and `e_mag_ml` fields based on the `corr` argument. This is in addition to base
Expand Down
11 changes: 8 additions & 3 deletions feature_step/features/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from features.core.elasticc import ELAsTiCCFeatureExtractor
from features.core.ztf import ZTFFeatureExtractor
from features.core.handlers.detections import NoDetectionsException
from features.utils.metrics import get_sid
from features.utils.parsers import parse_output, parse_scribe_payload

Expand Down Expand Up @@ -81,9 +82,13 @@ def execute(self, messages):
]
messages_aid_oid[message["aid"]] = list(set(oids_of_aid))

features_extractor = self.features_extractor(
detections, non_detections, xmatch
)
try:
features_extractor = self.features_extractor(
detections, non_detections, xmatch
)
except NoDetectionsException:
return []

features = features_extractor.generate_features()
if len(features) > 0:
self.produce_to_scribe(messages_aid_oid, features)
Expand Down
12 changes: 12 additions & 0 deletions feature_step/features/utils/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def parse_scribe_payload(
)
if extractor_class.NAME == "elasticc_lc_features":
return _parse_scribe_payload_elasticc(features, extractor_class)
if extractor_class.NAME == "atlas_lc_features":
return _parse_scribe_payload_atlas()
else:
raise Exception(
'Cannot parse scribe payload for extractor "{}"'.format(
Expand Down Expand Up @@ -92,6 +94,10 @@ def _parse_scribe_payload_ztf(
return commands_list


def _parse_scribe_payload_atlas():
return []


def parse_output(
features: pd.DataFrame,
alert_data: list[dict],
Expand All @@ -118,6 +124,8 @@ def parse_output(
return _parse_output_elasticc(
features, alert_data, extractor_class, candids
)
elif extractor_class.NAME == "atlas_lc_features":
return _parse_output_atlas()
else:
raise Exception(
'Cannot parse output for extractor "{}"'.format(
Expand Down Expand Up @@ -189,3 +197,7 @@ def _parse_output_ztf(features, alert_data, extractor_class, candids):
output_messages.append(out_message)

return output_messages


def _parse_output_atlas():
return []
9 changes: 8 additions & 1 deletion feature_step/features/utils/selector.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from features.core.elasticc import ELAsTiCCFeatureExtractor
from features.core.atlas import AtlasFeatureExtractor
from features.core.ztf import ZTFFeatureExtractor
from typing import Callable

Expand All @@ -11,9 +12,15 @@ def __init__(self, name) -> None:

def selector(
    name: str,
) -> (
    type[ZTFFeatureExtractor]
    | type[ELAsTiCCFeatureExtractor]
    | type[AtlasFeatureExtractor]
):
    """Return the feature-extractor class for the given survey name.

    Args:
        name: Survey identifier, case-insensitive ("ztf", "elasticc",
            or "atlas").

    Returns:
        The extractor class (not an instance) for the requested survey.

    Raises:
        ExtractorNotFoundException: If ``name`` matches no known extractor.
    """
    # Dispatch table keeps the mapping in one place and lowers the input
    # once instead of on every comparison.
    extractors = {
        "ztf": ZTFFeatureExtractor,
        "elasticc": ELAsTiCCFeatureExtractor,
        "atlas": AtlasFeatureExtractor,
    }
    try:
        return extractors[name.lower()]
    except KeyError:
        raise ExtractorNotFoundException(name) from None
59 changes: 48 additions & 11 deletions feature_step/tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@
os.environ["SCRIBE_TOPIC"] = ""


# Absolute path to the xmatch step's Avro output schema, four directory
# levels up from this conftest (the repository root).
SCHEMA_PATH = (
    pathlib.Path(__file__).parent.parent.parent.parent
    / "schemas"
    / "xmatch_step"
    / "output.avsc"
)

@pytest.fixture(scope="session")
def docker_compose_file(pytestconfig):
try:
Expand Down Expand Up @@ -63,29 +69,60 @@ def kafka_service(docker_ip, docker_services):
docker_services.wait_until_responsive(
timeout=30.0, pause=0.1, check=lambda: is_responsive_kafka(server)
)
schema_path = pathlib.Path(
pathlib.Path(__file__).parent.parent.parent.parent,
"schemas/xmatch_step",
"output.avsc",
)
return server

@pytest.fixture(scope="session")
def elasticc_messages(kafka_service):
config = {
"PARAMS": {"bootstrap.servers": "localhost:9092"},
"TOPIC": "elasticc",
"SCHEMA_PATH": schema_path,
"SCHEMA_PATH": SCHEMA_PATH,
}
producer = KafkaProducer(config)
data_elasticc = generate_elasticc_batch(5, ELASTICC_BANDS)
data_ztf = generate_input_batch(5)
for data in data_elasticc:
producer.produce(data)
producer.producer.flush(10)
producer.producer.flush(5)

@pytest.fixture(scope="session")
def ztf_messages(kafka_service):
    """Produce a batch of ZTF alerts to the ``ztf`` topic.

    Removed a stray ``return server`` (``server`` is not defined in this
    fixture and would raise ``NameError``) and a duplicated ``flush`` call.
    """
    config = {
        "PARAMS": {"bootstrap.servers": "localhost:9092"},
        "TOPIC": "ztf",
        "SCHEMA_PATH": SCHEMA_PATH,
    }
    data_ztf = generate_input_batch(5)
    producer = KafkaProducer(config)
    for data in data_ztf:
        producer.produce(data)
    # Block until delivery (5 s timeout) so tests can consume immediately.
    producer.producer.flush(5)

@pytest.fixture(scope="session")
def atlas_messages(kafka_service):
    """Publish a batch of ATLAS-tagged alerts to the ``atlas`` topic."""
    config = {
        "PARAMS": {"bootstrap.servers": "localhost:9092"},
        "TOPIC": "atlas",
        "SCHEMA_PATH": SCHEMA_PATH,
    }
    batch = generate_input_batch(5)
    producer = KafkaProducer(config)
    for alert in batch:
        # Re-tag each alert so the step sees it as an ATLAS detection.
        alert["tid"] = "atlas"
        producer.produce(alert)
    producer.producer.flush(5)

@pytest.fixture(scope="session")
def atlas_messages_ztf_topic(kafka_service):
    """Publish ATLAS-tagged alerts to the ``ztf`` topic (cross-survey test)."""
    config = {
        "PARAMS": {"bootstrap.servers": "localhost:9092"},
        "TOPIC": "ztf",
        "SCHEMA_PATH": SCHEMA_PATH,
    }
    batch = generate_input_batch(5)
    producer = KafkaProducer(config)
    for alert in batch:
        # Re-tag each alert so the ZTF consumer receives ATLAS detections.
        alert["tid"] = "atlas"
        producer.produce(alert)
    producer.producer.flush(5)
33 changes: 31 additions & 2 deletions feature_step/tests/integration/test_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
}


def test_step_ztf(kafka_service):
def test_step_ztf(ztf_messages):
CONSUMER_CONFIG["TOPICS"] = ["ztf"]
step_config = {
"PRODUCER_CONFIG": PRODUCER_CONFIG,
Expand All @@ -54,11 +54,12 @@ def test_step_ztf(kafka_service):
step = FeaturesComputer(
extractor,
config=step_config,

)
step.start()


def test_step_elasticc(kafka_service):
def test_step_elasticc(elasticc_messages):
CONSUMER_CONFIG["TOPICS"] = ["elasticc"]
step_config = {
"PRODUCER_CONFIG": PRODUCER_CONFIG,
Expand All @@ -71,3 +72,31 @@ def test_step_elasticc(kafka_service):
config=step_config,
)
step.start()

def test_step_atlas(atlas_messages):
    """Run the step end-to-end with the ATLAS extractor on the atlas topic."""
    CONSUMER_CONFIG["TOPICS"] = ["atlas"]
    step = FeaturesComputer(
        selector("atlas"),
        config={
            "PRODUCER_CONFIG": PRODUCER_CONFIG,
            "CONSUMER_CONFIG": CONSUMER_CONFIG,
            "SCRIBE_PRODUCER_CONFIG": SCRIBE_PRODUCER_CONFIG,
        },
    )
    step.start()

def test_step_ztf_atlas_messages(atlas_messages_ztf_topic):
    """Run the ZTF extractor against a ztf topic filled with ATLAS alerts."""
    CONSUMER_CONFIG["TOPICS"] = ["ztf"]
    step = FeaturesComputer(
        selector("ztf"),
        config={
            "PRODUCER_CONFIG": PRODUCER_CONFIG,
            "CONSUMER_CONFIG": CONSUMER_CONFIG,
            "SCRIBE_PRODUCER_CONFIG": SCRIBE_PRODUCER_CONFIG,
        },
    )
    step.start()
9 changes: 9 additions & 0 deletions feature_step/tests/unittest/test_extractor_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ def test_get_elasticc_extractor(self):
selected_extractor = selector(input_str)
self.assertEqual(selected_extractor.NAME, "elasticc_lc_features")

def test_get_atlas_extractor(self):
    """Selector returns the ATLAS extractor, case-insensitively.

    Renamed from a duplicated ``test_get_elasticc_extractor``: two methods
    with the same name in one class make the later definition silently
    shadow the earlier one, so the real elasticc test would never run.
    """
    selected_extractor = selector("atlas")
    self.assertEqual(selected_extractor.NAME, "atlas_lc_features")

    # Upper-case input must resolve to the same extractor.
    selected_extractor = selector("ATLAS")
    self.assertEqual(selected_extractor.NAME, "atlas_lc_features")

def test_extractor_not_found(self):
input_str = "dummy"
with self.assertRaises(ExtractorNotFoundException):
Expand Down
37 changes: 37 additions & 0 deletions feature_step/tests/unittest/test_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
features_df_for_execute,
messages_for_execute,
)
from features.utils.selector import selector

CONSUMER_CONFIG = {
"CLASS": "unittest.mock.MagicMock",
Expand Down Expand Up @@ -354,3 +355,39 @@ def test_execute(self):
self.step.scribe_producer.produce.call_count
)
self.assertEqual(scribe_producer_call_count, 2)

class ZTFWithAtlasDataTestCase(unittest.TestCase):
    """Check that the ZTF extractor produces nothing for ATLAS-only alerts."""

    def setUp(self):
        self.maxDiff = None
        self.step_config = {
            "PRODUCER_CONFIG": PRODUCER_CONFIG,
            "CONSUMER_CONFIG": CONSUMER_CONFIG,
            "SCRIBE_PRODUCER_CONFIG": SCRIBE_PRODUCER_CONFIG,
            "FEATURE_VERSION": "v1",
            "STEP_METADATA": {
                "STEP_VERSION": "feature",
                "STEP_ID": "feature",
                "STEP_NAME": "feature",
                "STEP_COMMENTS": "feature",
                "FEATURE_VERSION": "1.0-test",
            },
        }
        extractor = selector("ztf")
        self.step = FeaturesComputer(
            config=self.step_config, extractor=extractor
        )
        # Replace the scribe producer with a mock so no messages leave the test.
        self.step.scribe_producer = mock.create_autospec(GenericProducer)
        self.step.scribe_producer.produce = mock.MagicMock()

    def test_execute_ztf_extractor_with_atlas_messages(self):
        import copy

        # BUG FIX: the original executed the untouched ``messages_for_execute``
        # and only mutated the deep copy afterwards, so the ATLAS path was
        # never exercised. Re-tag the copy first, then execute it.
        messages_for_test = copy.deepcopy(messages_for_execute)
        for message in messages_for_test:
            for detection in message["detections"]:
                detection["tid"] = "atlas"

        result = self.step.execute(messages_for_test)

        # With every detection tagged as ATLAS the ZTF extractor has nothing
        # to work with: execute() should return empty and produce nothing.
        self.assertEqual(result, [])
        self.step.scribe_producer.produce.assert_not_called()