diff --git a/.github/workflows/staging_ci.yaml b/.github/workflows/staging_ci.yaml index 66db25bb7..9d998ffda 100644 --- a/.github/workflows/staging_ci.yaml +++ b/.github/workflows/staging_ci.yaml @@ -284,6 +284,7 @@ jobs: poetry run python main.py deploy add-package alert-archive-step --chart=alert-archive-step --values=alert_archive_step-helm-values --chart-folder=alert_archiving_step poetry run python main.py deploy add-package correction-step --values=correction_step-helm-values --chart-folder=correction_step poetry run python main.py deploy add-package early-classification-step --chart=early-classifier --values=early_classification_step-helm-values --chart-folder=early_classification_step + poetry run python main.py deploy add-package feature-step-atlas --values=feature_step_atlas-helm-values --chart-folder=feature_step poetry run python main.py deploy add-package feature-step --values=feature_step-helm-values --chart-folder=feature_step poetry run python main.py deploy add-package lc-classifier-step-ztf --chart=lc-classifier-step --values=lc_classification_step-helm-values --chart-folder=lc_classification_step poetry run python main.py deploy add-package lightcurve-step --values=lightcurve-step-helm-values diff --git a/feature_step/features/core/atlas.py b/feature_step/features/core/atlas.py new file mode 100644 index 000000000..7bf4a974f --- /dev/null +++ b/feature_step/features/core/atlas.py @@ -0,0 +1,21 @@ +from typing import List +import pandas as pd +from importlib import metadata as pymetadata + + +class AtlasFeatureExtractor: + NAME = "atlas_lc_features" + VERSION = pymetadata.version("feature-step") + SURVEYS = ("ATLAS",) + + def __init__( + self, + detections: List[dict], + non_detections: List[dict], + xmatch: List[dict], + **kwargs, + ): + pass + + def generate_features(self): + return pd.DataFrame() diff --git a/feature_step/features/core/handlers/detections.py b/feature_step/features/core/handlers/detections.py index 337358fc1..71e61903e 100644 
--- a/feature_step/features/core/handlers/detections.py +++ b/feature_step/features/core/handlers/detections.py @@ -6,6 +6,10 @@ from ._base import BaseHandler +class NoDetectionsException(Exception): + pass + + class DetectionsHandler(BaseHandler): """Class for handling detections. @@ -37,6 +41,19 @@ class DetectionsHandler(BaseHandler): "forced", ] + def __init__( + self, + alerts: list[dict], + *, + surveys: str | tuple[str] = (), + bands: str | tuple[str] = (), + **kwargs, + ): + if len(alerts) == 0: + raise NoDetectionsException() + + super().__init__(alerts, surveys=surveys, bands=bands, **kwargs) + def _post_process(self, **kwargs): """Handles legacy alerts (renames old field names to the new conventions) and sets the `mag_ml` and `e_mag_ml` fields based on the `corr` argument. This is in addition to base diff --git a/feature_step/features/step.py b/feature_step/features/step.py index 0bd1e3aa9..644aeb7ee 100644 --- a/feature_step/features/step.py +++ b/feature_step/features/step.py @@ -7,6 +7,7 @@ from features.core.elasticc import ELAsTiCCFeatureExtractor from features.core.ztf import ZTFFeatureExtractor +from features.core.handlers.detections import NoDetectionsException from features.utils.metrics import get_sid from features.utils.parsers import parse_output, parse_scribe_payload @@ -81,9 +82,13 @@ def execute(self, messages): ] messages_aid_oid[message["aid"]] = list(set(oids_of_aid)) - features_extractor = self.features_extractor( - detections, non_detections, xmatch - ) + try: + features_extractor = self.features_extractor( + detections, non_detections, xmatch + ) + except NoDetectionsException: + return [] + features = features_extractor.generate_features() if len(features) > 0: self.produce_to_scribe(messages_aid_oid, features) diff --git a/feature_step/features/utils/parsers.py b/feature_step/features/utils/parsers.py index 78b334668..1925ebce1 100644 --- a/feature_step/features/utils/parsers.py +++ b/feature_step/features/utils/parsers.py @@ 
-25,6 +25,8 @@ def parse_scribe_payload( ) if extractor_class.NAME == "elasticc_lc_features": return _parse_scribe_payload_elasticc(features, extractor_class) + if extractor_class.NAME == "atlas_lc_features": + return _parse_scribe_payload_atlas() else: raise Exception( 'Cannot parse scribe payload for extractor "{}"'.format( @@ -92,6 +94,10 @@ def _parse_scribe_payload_ztf( return commands_list +def _parse_scribe_payload_atlas(): + return [] + + def parse_output( features: pd.DataFrame, alert_data: list[dict], @@ -118,6 +124,8 @@ def parse_output( return _parse_output_elasticc( features, alert_data, extractor_class, candids ) + elif extractor_class.NAME == "atlas_lc_features": + return _parse_output_atlas() else: raise Exception( 'Cannot parse output for extractor "{}"'.format( @@ -189,3 +197,7 @@ def _parse_output_ztf(features, alert_data, extractor_class, candids): output_messages.append(out_message) return output_messages + + +def _parse_output_atlas(): + return [] diff --git a/feature_step/features/utils/selector.py b/feature_step/features/utils/selector.py index 8fbe45219..59bf680b0 100644 --- a/feature_step/features/utils/selector.py +++ b/feature_step/features/utils/selector.py @@ -1,4 +1,5 @@ from features.core.elasticc import ELAsTiCCFeatureExtractor +from features.core.atlas import AtlasFeatureExtractor from features.core.ztf import ZTFFeatureExtractor from typing import Callable @@ -11,9 +12,15 @@ def __init__(self, name) -> None: def selector( name: str, -) -> type[ZTFFeatureExtractor] | type[ELAsTiCCFeatureExtractor]: +) -> ( + type[ZTFFeatureExtractor] + | type[ELAsTiCCFeatureExtractor] + | type[AtlasFeatureExtractor] +): if name.lower() == "ztf": return ZTFFeatureExtractor if name.lower() == "elasticc": return ELAsTiCCFeatureExtractor + if name.lower() == "atlas": + return AtlasFeatureExtractor raise ExtractorNotFoundException(name) diff --git a/feature_step/tests/integration/conftest.py b/feature_step/tests/integration/conftest.py index 
603fa480c..1c5385933 100644 --- a/feature_step/tests/integration/conftest.py +++ b/feature_step/tests/integration/conftest.py @@ -23,6 +23,12 @@ os.environ["SCRIBE_TOPIC"] = "" +SCHEMA_PATH = pathlib.Path( + pathlib.Path(__file__).parent.parent.parent.parent, + "schemas/xmatch_step", + "output.avsc", +) + @pytest.fixture(scope="session") def docker_compose_file(pytestconfig): try: @@ -63,29 +69,60 @@ def kafka_service(docker_ip, docker_services): docker_services.wait_until_responsive( timeout=30.0, pause=0.1, check=lambda: is_responsive_kafka(server) ) - schema_path = pathlib.Path( - pathlib.Path(__file__).parent.parent.parent.parent, - "schemas/xmatch_step", - "output.avsc", - ) + return server + +@pytest.fixture(scope="session") +def elasticc_messages(kafka_service): config = { "PARAMS": {"bootstrap.servers": "localhost:9092"}, "TOPIC": "elasticc", - "SCHEMA_PATH": schema_path, + "SCHEMA_PATH": SCHEMA_PATH, } producer = KafkaProducer(config) data_elasticc = generate_elasticc_batch(5, ELASTICC_BANDS) - data_ztf = generate_input_batch(5) for data in data_elasticc: producer.produce(data) - producer.producer.flush(10) + producer.producer.flush(5) + +@pytest.fixture(scope="session") +def ztf_messages(kafka_service): config = { "PARAMS": {"bootstrap.servers": "localhost:9092"}, "TOPIC": "ztf", - "SCHEMA_PATH": schema_path, + "SCHEMA_PATH": SCHEMA_PATH, } + data_ztf = generate_input_batch(5) producer = KafkaProducer(config) for data in data_ztf: producer.produce(data) - producer.producer.flush(10) - return server + producer.producer.flush(5) + +@pytest.fixture(scope="session") +def atlas_messages(kafka_service): + # Produce messages for atlas test + config = { + "PARAMS": {"bootstrap.servers": "localhost:9092"}, + "TOPIC": "atlas", + "SCHEMA_PATH": SCHEMA_PATH, + } + data_atlas = generate_input_batch(5) + producer = KafkaProducer(config) + for data in data_atlas: + data["tid"] = "atlas" + producer.produce(data) + producer.producer.flush(5) + 
+@pytest.fixture(scope="session") +def atlas_messages_ztf_topic(kafka_service): + # Produce messages for atlas test + config = { + "PARAMS": {"bootstrap.servers": "localhost:9092"}, + "TOPIC": "ztf", + "SCHEMA_PATH": SCHEMA_PATH, + } + data_atlas = generate_input_batch(5) + producer = KafkaProducer(config) + for data in data_atlas: + data["tid"] = "atlas" + producer.produce(data) + producer.producer.flush(5) diff --git a/feature_step/tests/integration/test_step.py b/feature_step/tests/integration/test_step.py index 0edd33830..1369a7ea6 100644 --- a/feature_step/tests/integration/test_step.py +++ b/feature_step/tests/integration/test_step.py @@ -43,7 +43,7 @@ } -def test_step_ztf(kafka_service): +def test_step_ztf(ztf_messages): CONSUMER_CONFIG["TOPICS"] = ["ztf"] step_config = { "PRODUCER_CONFIG": PRODUCER_CONFIG, @@ -54,11 +54,12 @@ def test_step_ztf(kafka_service): step = FeaturesComputer( extractor, config=step_config, + ) step.start() -def test_step_elasticc(kafka_service): +def test_step_elasticc(elasticc_messages): CONSUMER_CONFIG["TOPICS"] = ["elasticc"] step_config = { "PRODUCER_CONFIG": PRODUCER_CONFIG, @@ -71,3 +72,31 @@ def test_step_elasticc(kafka_service): config=step_config, ) step.start() + +def test_step_atlas(atlas_messages): + CONSUMER_CONFIG["TOPICS"] = ["atlas"] + step_config = { + "PRODUCER_CONFIG": PRODUCER_CONFIG, + "CONSUMER_CONFIG": CONSUMER_CONFIG, + "SCRIBE_PRODUCER_CONFIG": SCRIBE_PRODUCER_CONFIG, + } + extractor = selector("atlas") + step = FeaturesComputer( + extractor, + config=step_config, + ) + step.start() + +def test_step_ztf_atlas_messages(atlas_messages_ztf_topic): + CONSUMER_CONFIG["TOPICS"] = ["ztf"] + step_config = { + "PRODUCER_CONFIG": PRODUCER_CONFIG, + "CONSUMER_CONFIG": CONSUMER_CONFIG, + "SCRIBE_PRODUCER_CONFIG": SCRIBE_PRODUCER_CONFIG, + } + extractor = selector("ztf") + step = FeaturesComputer( + extractor, + config=step_config, + ) + step.start() diff --git a/feature_step/tests/unittest/test_extractor_selector.py 
b/feature_step/tests/unittest/test_extractor_selector.py index a3c9a9694..8434e160e 100644 --- a/feature_step/tests/unittest/test_extractor_selector.py +++ b/feature_step/tests/unittest/test_extractor_selector.py @@ -25,6 +25,15 @@ def test_get_elasticc_extractor(self): selected_extractor = selector(input_str) self.assertEqual(selected_extractor.NAME, "elasticc_lc_features") + def test_get_atlas_extractor(self): + input_str = "atlas" + selected_extractor = selector(input_str) + self.assertEqual(selected_extractor.NAME, "atlas_lc_features") + + input_str = "ATLAS" + selected_extractor = selector(input_str) + self.assertEqual(selected_extractor.NAME, "atlas_lc_features") + def test_extractor_not_found(self): input_str = "dummy" with self.assertRaises(ExtractorNotFoundException): diff --git a/feature_step/tests/unittest/test_step.py b/feature_step/tests/unittest/test_step.py index 1b3283af2..e1b8f8a8f 100644 --- a/feature_step/tests/unittest/test_step.py +++ b/feature_step/tests/unittest/test_step.py @@ -8,6 +8,7 @@ features_df_for_execute, messages_for_execute, ) +from features.utils.selector import selector CONSUMER_CONFIG = { "CLASS": "unittest.mock.MagicMock", @@ -354,3 +355,39 @@ def test_execute(self): self.step.scribe_producer.produce.call_count ) self.assertEqual(scribe_producer_call_count, 2) + +class ZTFWithAtlasDataTestCase(unittest.TestCase): + def setUp(self): + self.maxDiff = None + self.step_config = { + "PRODUCER_CONFIG": PRODUCER_CONFIG, + "CONSUMER_CONFIG": CONSUMER_CONFIG, + "SCRIBE_PRODUCER_CONFIG": SCRIBE_PRODUCER_CONFIG, + "FEATURE_VERSION": "v1", + "STEP_METADATA": { + "STEP_VERSION": "feature", + "STEP_ID": "feature", + "STEP_NAME": "feature", + "STEP_COMMENTS": "feature", + "FEATURE_VERSION": "1.0-test", + }, + } + extractor = selector("ztf") + self.step = FeaturesComputer( + config=self.step_config, extractor=extractor + ) + self.step.scribe_producer = mock.create_autospec(GenericProducer) + self.step.scribe_producer.produce = 
mock.MagicMock() + + def test_execute_ztf_extractor_with_atlas_messages(self): + import copy + expected_output = [] + messages_for_test = copy.deepcopy(messages_for_execute) + for message in messages_for_test: + for detection in message["detections"]: + detection["tid"] = "atlas" + + result = self.step.execute(messages_for_test) + + self.assertEqual(result, expected_output) + self.step.scribe_producer.produce.assert_not_called()