diff --git a/astronomer/providers/amazon/aws/triggers/s3.py b/astronomer/providers/amazon/aws/triggers/s3.py index 30a1957a8..20f8157c0 100644 --- a/astronomer/providers/amazon/aws/triggers/s3.py +++ b/astronomer/providers/amazon/aws/triggers/s3.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -from datetime import datetime from typing import Any, AsyncIterator from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -91,103 +90,3 @@ async def run(self) -> AsyncIterator[TriggerEvent]: def _get_async_hook(self) -> S3HookAsync: return S3HookAsync(aws_conn_id=self.aws_conn_id, verify=self.hook_params.get("verify")) - - -class S3KeysUnchangedTrigger(BaseTrigger): - """ - S3KeyTrigger is fired as deferred class with params to run the task in trigger worker - - :param bucket_name: Name of the S3 bucket. Only needed when ``bucket_key`` - is not provided as a full s3:// url. - :param prefix: The prefix being waited on. Relative path from bucket root level. - :param inactivity_period: The total seconds of inactivity to designate - keys unchanged. Note, this mechanism is not real time and - this operator may not return until a poke_interval after this period - has passed with no additional objects sensed. - :param min_objects: The minimum number of objects needed for keys unchanged - sensor to be considered valid. - :param inactivity_seconds: reference to the seconds of inactivity - :param previous_objects: The set of object ids found during the last poke. - :param allow_delete: Should this sensor consider objects being deleted - :param aws_conn_id: reference to the s3 connection - :param last_activity_time: last modified or last active time - :param verify: Whether or not to verify SSL certificates for S3 connection. - By default SSL certificates are verified. - """ - - def __init__( - self, - bucket_name: str, - prefix: str, - inactivity_period: float = 60 * 60, - min_objects: int = 1, - inactivity_seconds: int = 0, - previous_objects: set[str] | None = None, - allow_delete: bool = True, - aws_conn_id: str = "aws_default", - last_activity_time: datetime | None = None, - verify: bool | str | None = None, - ): - super().__init__() - self.bucket_name = bucket_name - self.prefix = prefix - if inactivity_period < 0: - raise ValueError("inactivity_period must be non-negative") - if previous_objects is None: - previous_objects = set() - self.inactivity_period = inactivity_period - self.min_objects = min_objects - self.previous_objects = previous_objects - self.inactivity_seconds = inactivity_seconds - self.allow_delete = allow_delete - self.aws_conn_id = aws_conn_id - self.last_activity_time: datetime | None = last_activity_time - self.verify = verify - self.polling_period_seconds = 0 - - def serialize(self) -> tuple[str, dict[str, Any]]: - """Serialize S3KeysUnchangedTrigger arguments and classpath.""" - return ( - "astronomer.providers.amazon.aws.triggers.s3.S3KeysUnchangedTrigger", - { - "bucket_name": self.bucket_name, - "prefix": self.prefix, - "inactivity_period": self.inactivity_period, - "min_objects": self.min_objects, - "previous_objects": self.previous_objects, - "inactivity_seconds": self.inactivity_seconds, - "allow_delete": self.allow_delete, - "aws_conn_id": self.aws_conn_id, - "last_activity_time": self.last_activity_time, - }, - ) - - async def run(self) -> AsyncIterator[TriggerEvent]: - """Make an asynchronous connection using S3HookAsync.""" - try: - hook = self._get_async_hook() - async with await hook.get_client_async() as client: - while True: - result = await hook.is_keys_unchanged( - client, - self.bucket_name, - self.prefix, - self.inactivity_period, - self.min_objects, - self.previous_objects, - self.inactivity_seconds, - self.allow_delete, - self.last_activity_time, - ) - if result.get("status") == "success" or result.get("status") == "error": - yield TriggerEvent(result) - elif result.get("status") == "pending": - self.previous_objects = result.get("previous_objects", set()) - self.last_activity_time = result.get("last_activity_time") - self.inactivity_seconds = result.get("inactivity_seconds", 0) - await asyncio.sleep(self.polling_period_seconds) - except Exception as e: - yield TriggerEvent({"status": "error", "message": str(e)}) - - def _get_async_hook(self) -> S3HookAsync: - return S3HookAsync(aws_conn_id=self.aws_conn_id, verify=self.verify) diff --git a/tests/amazon/aws/triggers/test_s3_triggers.py b/tests/amazon/aws/triggers/test_s3_triggers.py index 9cd043e02..a1ca8c2a7 100644 --- a/tests/amazon/aws/triggers/test_s3_triggers.py +++ b/tests/amazon/aws/triggers/test_s3_triggers.py @@ -1,12 +1,10 @@ import asyncio -from datetime import datetime from unittest import mock import pytest from airflow.triggers.base import TriggerEvent from astronomer.providers.amazon.aws.triggers.s3 import ( - S3KeysUnchangedTrigger, S3KeyTrigger, ) @@ -101,79 +99,3 @@ async def test_run_check_fn_success(self, mock_get_files, mock_client): generator = trigger.run() actual = await generator.asend(None) assert TriggerEvent({"status": "running", "files": [{"Size": 123}]}) == actual - - -class TestS3KeysUnchangedTrigger: - def test_serialization(self): - """ - Asserts that the TaskStateTrigger correctly serializes its arguments - and classpath. - """ - trigger = S3KeysUnchangedTrigger( - bucket_name="test_bucket", - prefix="test", - inactivity_period=1, - min_objects=1, - inactivity_seconds=0, - previous_objects=None, - ) - classpath, kwargs = trigger.serialize() - assert classpath == "astronomer.providers.amazon.aws.triggers.s3.S3KeysUnchangedTrigger" - assert kwargs == { - "bucket_name": "test_bucket", - "prefix": "test", - "inactivity_period": 1, - "min_objects": 1, - "inactivity_seconds": 0, - "previous_objects": set(), - "allow_delete": 1, - "aws_conn_id": "aws_default", - "last_activity_time": None, - } - - @pytest.mark.asyncio - @mock.patch("astronomer.providers.amazon.aws.triggers.s3.S3HookAsync.get_client_async") - async def test_run_wait(self, mock_client): - """Test if the task is run is in trigger successfully.""" - mock_client.return_value.check_key.return_value = True - trigger = S3KeysUnchangedTrigger(bucket_name="test_bucket", prefix="test") - with mock_client: - task = asyncio.create_task(trigger.run().__anext__()) - await asyncio.sleep(0.5) - - assert task.done() is True - asyncio.get_event_loop().stop() - - def test_run_raise_value_error(self): - """ - Test if the S3KeysUnchangedTrigger raises Value error for negative inactivity_period. - """ - with pytest.raises(ValueError): - S3KeysUnchangedTrigger(bucket_name="test_bucket", prefix="test", inactivity_period=-100) - - @pytest.mark.asyncio - @mock.patch("astronomer.providers.amazon.aws.triggers.s3.S3HookAsync.get_client_async") - @mock.patch("astronomer.providers.amazon.aws.triggers.s3.S3HookAsync.is_keys_unchanged") - async def test_run_success(self, mock_is_keys_unchanged, mock_client): - """ - Test if the task is run is in triggerer successfully. - """ - mock_is_keys_unchanged.return_value = {"status": "success"} - trigger = S3KeysUnchangedTrigger(bucket_name="test_bucket", prefix="test") - generator = trigger.run() - actual = await generator.asend(None) - assert TriggerEvent({"status": "success"}) == actual - - @pytest.mark.asyncio - @mock.patch("astronomer.providers.amazon.aws.triggers.s3.S3HookAsync.get_client_async") - @mock.patch("astronomer.providers.amazon.aws.triggers.s3.S3HookAsync.is_keys_unchanged") - async def test_run_pending(self, mock_is_keys_unchanged, mock_client): - """Test if the task is run is in triggerer successfully.""" - mock_is_keys_unchanged.return_value = {"status": "pending", "last_activity_time": datetime.now()} - trigger = S3KeysUnchangedTrigger(bucket_name="test_bucket", prefix="test") - task = asyncio.create_task(trigger.run().__anext__()) - await asyncio.sleep(0.5) - - # TriggerEvent was not returned - assert task.done() is False - asyncio.get_event_loop().stop()