diff --git a/astronomer/providers/amazon/aws/sensors/s3.py b/astronomer/providers/amazon/aws/sensors/s3.py index 94ec86cc8..a4e902060 100644 --- a/astronomer/providers/amazon/aws/sensors/s3.py +++ b/astronomer/providers/amazon/aws/sensors/s3.py @@ -9,10 +9,9 @@ from airflow.sensors.base import BaseSensorOperator from astronomer.providers.amazon.aws.triggers.s3 import ( - S3KeysUnchangedTrigger, S3KeyTrigger, ) -from astronomer.providers.utils.sensor_util import poke, raise_error_or_skip_exception +from astronomer.providers.utils.sensor_util import raise_error_or_skip_exception from astronomer.providers.utils.typing_compat import Context @@ -155,72 +154,21 @@ def __init__( class S3KeysUnchangedSensorAsync(S3KeysUnchangedSensor): """ - Checks for changes in the number of objects at prefix in AWS S3 - bucket and returns True if the inactivity period has passed with no - increase in the number of objects. Note, this sensor will not behave correctly - in reschedule mode, as the state of the listed objects in the S3 bucket will - be lost between rescheduled invocations. - - :param bucket_name: Name of the S3 bucket - :param prefix: The prefix being waited on. Relative path from bucket root level. - :param aws_conn_id: a reference to the s3 connection - :param verify: Whether or not to verify SSL certificates for S3 connection. - By default SSL certificates are verified. - You can provide the following values: - - - ``False``: do not validate SSL certificates. SSL will still be used - (unless use_ssl is False), but SSL certificates will not be - verified. - - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses. - You can specify this argument if you want to use a different - CA cert bundle than the one used by botocore. - :param inactivity_period: The total seconds of inactivity to designate - keys unchanged. Note, this mechanism is not real time and - this operator may not return until a poke_interval after this period - has passed with no additional objects sensed. - :param min_objects: The minimum number of objects needed for keys unchanged - sensor to be considered valid. - :param previous_objects: The set of object ids found during the last poke. - :param allow_delete: Should this sensor consider objects being deleted - between pokes valid behavior. If true a warning message will be logged - when this happens. If false an error will be raised. + This class is deprecated. + Please use :class: `~airflow.providers.amazon.aws.sensors.s3.S3KeysUnchangedSensor`. """ - def __init__( - self, - **kwargs: Any, - ) -> None: - super().__init__(**kwargs) - - def execute(self, context: Context) -> None: - """Defers Trigger class to check for changes in the number of objects at prefix in AWS S3""" - if not poke(self, context): - self.defer( - timeout=timedelta(seconds=self.timeout), - trigger=S3KeysUnchangedTrigger( - bucket_name=self.bucket_name, - prefix=self.prefix, - inactivity_period=self.inactivity_period, - min_objects=self.min_objects, - previous_objects=self.previous_objects, - inactivity_seconds=self.inactivity_seconds, - allow_delete=self.allow_delete, - aws_conn_id=self.aws_conn_id, - verify=self.verify, - last_activity_time=self.last_activity_time, - ), - method_name="execute_complete", - ) - - def execute_complete(self, context: Context, event: Any = None) -> None: - """ - Callback for when the trigger fires - returns immediately. - Relies on trigger to throw an exception, otherwise it assumes execution was - successful. - """ - if event["status"] == "error": - raise_error_or_skip_exception(self.soft_fail, event["message"]) - return None + def __init__(self, *args, **kwargs): # type: ignore[no-untyped-def] + warnings.warn( + ( + "This module is deprecated. " + "Please use `airflow.providers.amazon.aws.sensors.s3.S3KeysUnchangedSensor` " + "and set deferrable to True instead." + ), + DeprecationWarning, + stacklevel=2, + ) + return super().__init__(*args, deferrable=True, **kwargs) class S3PrefixSensorAsync(BaseSensorOperator): diff --git a/astronomer/providers/amazon/aws/triggers/s3.py b/astronomer/providers/amazon/aws/triggers/s3.py index 30a1957a8..20f8157c0 100644 --- a/astronomer/providers/amazon/aws/triggers/s3.py +++ b/astronomer/providers/amazon/aws/triggers/s3.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -from datetime import datetime from typing import Any, AsyncIterator from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -91,103 +90,3 @@ async def run(self) -> AsyncIterator[TriggerEvent]: def _get_async_hook(self) -> S3HookAsync: return S3HookAsync(aws_conn_id=self.aws_conn_id, verify=self.hook_params.get("verify")) - - -class S3KeysUnchangedTrigger(BaseTrigger): - """ - S3KeyTrigger is fired as deferred class with params to run the task in trigger worker - - :param bucket_name: Name of the S3 bucket. Only needed when ``bucket_key`` - is not provided as a full s3:// url. - :param prefix: The prefix being waited on. Relative path from bucket root level. - :param inactivity_period: The total seconds of inactivity to designate - keys unchanged. Note, this mechanism is not real time and - this operator may not return until a poke_interval after this period - has passed with no additional objects sensed. - :param min_objects: The minimum number of objects needed for keys unchanged - sensor to be considered valid. - :param inactivity_seconds: reference to the seconds of inactivity - :param previous_objects: The set of object ids found during the last poke. - :param allow_delete: Should this sensor consider objects being deleted - :param aws_conn_id: reference to the s3 connection - :param last_activity_time: last modified or last active time - :param verify: Whether or not to verify SSL certificates for S3 connection. - By default SSL certificates are verified. - """ - - def __init__( - self, - bucket_name: str, - prefix: str, - inactivity_period: float = 60 * 60, - min_objects: int = 1, - inactivity_seconds: int = 0, - previous_objects: set[str] | None = None, - allow_delete: bool = True, - aws_conn_id: str = "aws_default", - last_activity_time: datetime | None = None, - verify: bool | str | None = None, - ): - super().__init__() - self.bucket_name = bucket_name - self.prefix = prefix - if inactivity_period < 0: - raise ValueError("inactivity_period must be non-negative") - if previous_objects is None: - previous_objects = set() - self.inactivity_period = inactivity_period - self.min_objects = min_objects - self.previous_objects = previous_objects - self.inactivity_seconds = inactivity_seconds - self.allow_delete = allow_delete - self.aws_conn_id = aws_conn_id - self.last_activity_time: datetime | None = last_activity_time - self.verify = verify - self.polling_period_seconds = 0 - - def serialize(self) -> tuple[str, dict[str, Any]]: - """Serialize S3KeysUnchangedTrigger arguments and classpath.""" - return ( - "astronomer.providers.amazon.aws.triggers.s3.S3KeysUnchangedTrigger", - { - "bucket_name": self.bucket_name, - "prefix": self.prefix, - "inactivity_period": self.inactivity_period, - "min_objects": self.min_objects, - "previous_objects": self.previous_objects, - "inactivity_seconds": self.inactivity_seconds, - "allow_delete": self.allow_delete, - "aws_conn_id": self.aws_conn_id, - "last_activity_time": self.last_activity_time, - }, - ) - - async def run(self) -> AsyncIterator[TriggerEvent]: - """Make an asynchronous connection using S3HookAsync.""" - try: - hook = self._get_async_hook() - async with await hook.get_client_async() as client: - while True: - result = await hook.is_keys_unchanged( - client, - self.bucket_name, - self.prefix, - self.inactivity_period, - self.min_objects, - self.previous_objects, - self.inactivity_seconds, - self.allow_delete, - self.last_activity_time, - ) - if result.get("status") == "success" or result.get("status") == "error": - yield TriggerEvent(result) - elif result.get("status") == "pending": - self.previous_objects = result.get("previous_objects", set()) - self.last_activity_time = result.get("last_activity_time") - self.inactivity_seconds = result.get("inactivity_seconds", 0) - await asyncio.sleep(self.polling_period_seconds) - except Exception as e: - yield TriggerEvent({"status": "error", "message": str(e)}) - - def _get_async_hook(self) -> S3HookAsync: - return S3HookAsync(aws_conn_id=self.aws_conn_id, verify=self.verify) diff --git a/tests/amazon/aws/sensors/test_s3_sensors.py b/tests/amazon/aws/sensors/test_s3_sensors.py index 8ba3047ad..4dfae42d6 100644 --- a/tests/amazon/aws/sensors/test_s3_sensors.py +++ b/tests/amazon/aws/sensors/test_s3_sensors.py @@ -7,6 +7,7 @@ from airflow.exceptions import AirflowException, AirflowSkipException, TaskDeferred from airflow.models import DAG, DagRun, TaskInstance from airflow.models.variable import Variable +from airflow.providers.amazon.aws.sensors.s3 import S3KeysUnchangedSensor from airflow.utils import timezone from parameterized import parameterized @@ -17,7 +18,6 @@ S3PrefixSensorAsync, ) from astronomer.providers.amazon.aws.triggers.s3 import ( - S3KeysUnchangedTrigger, S3KeyTrigger, ) @@ -293,80 +293,12 @@ def test_soft_fail_enable(self, context): class TestS3KeysUnchangedSensorAsync: - @mock.patch(f"{MODULE}.S3KeysUnchangedSensorAsync.defer") - @mock.patch(f"{MODULE}.S3KeysUnchangedSensorAsync.poke", return_value=True) - def test_s3_keys_unchanged_sensor_async_finish_before_deferred(self, mock_poke, mock_defer, context): - """Assert task is not deferred when it receives a finish status before deferring""" - S3KeysUnchangedSensorAsync( - task_id="s3_keys_unchanged_sensor", bucket_name="test_bucket", prefix="test" - ) - assert not mock_defer.called - - @mock.patch(f"{MODULE}.S3KeysUnchangedSensorAsync.poke", return_value=False) - @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook") - def test_s3_keys_unchanged_sensor_check_trigger_instance(self, mock_hook, mock_poke, context): - """ - Asserts that a task is deferred and an S3KeysUnchangedTrigger will be fired - when the S3KeysUnchangedSensorAsync is executed. - """ - mock_hook.check_for_key.return_value = False - - sensor = S3KeysUnchangedSensorAsync( + def test_init(self): + task = S3KeysUnchangedSensorAsync( task_id="s3_keys_unchanged_sensor", bucket_name="test_bucket", prefix="test" ) - - with pytest.raises(TaskDeferred) as exc: - sensor.execute(context) - - assert isinstance( - exc.value.trigger, S3KeysUnchangedTrigger - ), "Trigger is not a S3KeysUnchangedTrigger" - - @parameterized.expand([["bucket", "test"]]) - @mock.patch(f"{MODULE}.S3KeysUnchangedSensorAsync.poke", return_value=False) - @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook") - def test_s3_keys_unchanged_sensor_execute_complete_success(self, bucket, prefix, mock_hook, mock_poke): - """ - Asserts that a task completed with success status - """ - mock_hook.check_for_key.return_value = False - - sensor = S3KeysUnchangedSensorAsync( - task_id="s3_keys_unchanged_sensor", - bucket_name=bucket, - prefix=prefix, - ) - assert sensor.execute_complete(context={}, event={"status": "success"}) is None - - @parameterized.expand([["bucket", "test"]]) - @mock.patch(f"{MODULE}.S3KeysUnchangedSensorAsync.poke", return_value=False) - @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook") - def test_s3_keys_unchanged_sensor_execute_complete_error(self, bucket, prefix, mock_hook, mock_poke): - """ - Asserts that a task is completed with error. - """ - mock_hook.check_for_key.return_value = False - - sensor = S3KeysUnchangedSensorAsync( - task_id="s3_keys_unchanged_sensor", - bucket_name=bucket, - prefix=prefix, - ) - with pytest.raises(AirflowException): - sensor.execute_complete(context={}, event={"status": "error", "message": "Mocked error"}) - - @mock.patch(f"{MODULE}.S3KeysUnchangedSensorAsync.poke", return_value=False) - def test_s3_keys_unchanged_sensor_raise_value_error(self, mock_poke): - """ - Test if the S3KeysUnchangedTrigger raises Value error for negative inactivity_period. - """ - with pytest.raises(ValueError): - S3KeysUnchangedSensorAsync( - task_id="s3_keys_unchanged_sensor", - bucket_name="test_bucket", - prefix="test", - inactivity_period=-100, - ) + assert isinstance(task, S3KeysUnchangedSensor) + assert task.deferrable is True class TestS3KeySizeSensorAsync(unittest.TestCase): diff --git a/tests/amazon/aws/triggers/test_s3_triggers.py b/tests/amazon/aws/triggers/test_s3_triggers.py index 9cd043e02..a1ca8c2a7 100644 --- a/tests/amazon/aws/triggers/test_s3_triggers.py +++ b/tests/amazon/aws/triggers/test_s3_triggers.py @@ -1,12 +1,10 @@ import asyncio -from datetime import datetime from unittest import mock import pytest from airflow.triggers.base import TriggerEvent from astronomer.providers.amazon.aws.triggers.s3 import ( - S3KeysUnchangedTrigger, S3KeyTrigger, ) @@ -101,79 +99,3 @@ async def test_run_check_fn_success(self, mock_get_files, mock_client): generator = trigger.run() actual = await generator.asend(None) assert TriggerEvent({"status": "running", "files": [{"Size": 123}]}) == actual - - -class TestS3KeysUnchangedTrigger: - def test_serialization(self): - """ - Asserts that the TaskStateTrigger correctly serializes its arguments - and classpath. - """ - trigger = S3KeysUnchangedTrigger( - bucket_name="test_bucket", - prefix="test", - inactivity_period=1, - min_objects=1, - inactivity_seconds=0, - previous_objects=None, - ) - classpath, kwargs = trigger.serialize() - assert classpath == "astronomer.providers.amazon.aws.triggers.s3.S3KeysUnchangedTrigger" - assert kwargs == { - "bucket_name": "test_bucket", - "prefix": "test", - "inactivity_period": 1, - "min_objects": 1, - "inactivity_seconds": 0, - "previous_objects": set(), - "allow_delete": 1, - "aws_conn_id": "aws_default", - "last_activity_time": None, - } - - @pytest.mark.asyncio - @mock.patch("astronomer.providers.amazon.aws.triggers.s3.S3HookAsync.get_client_async") - async def test_run_wait(self, mock_client): - """Test if the task is run is in trigger successfully.""" - mock_client.return_value.check_key.return_value = True - trigger = S3KeysUnchangedTrigger(bucket_name="test_bucket", prefix="test") - with mock_client: - task = asyncio.create_task(trigger.run().__anext__()) - await asyncio.sleep(0.5) - - assert task.done() is True - asyncio.get_event_loop().stop() - - def test_run_raise_value_error(self): - """ - Test if the S3KeysUnchangedTrigger raises Value error for negative inactivity_period. - """ - with pytest.raises(ValueError): - S3KeysUnchangedTrigger(bucket_name="test_bucket", prefix="test", inactivity_period=-100) - - @pytest.mark.asyncio - @mock.patch("astronomer.providers.amazon.aws.triggers.s3.S3HookAsync.get_client_async") - @mock.patch("astronomer.providers.amazon.aws.triggers.s3.S3HookAsync.is_keys_unchanged") - async def test_run_success(self, mock_is_keys_unchanged, mock_client): - """ - Test if the task is run is in triggerer successfully. - """ - mock_is_keys_unchanged.return_value = {"status": "success"} - trigger = S3KeysUnchangedTrigger(bucket_name="test_bucket", prefix="test") - generator = trigger.run() - actual = await generator.asend(None) - assert TriggerEvent({"status": "success"}) == actual - - @pytest.mark.asyncio - @mock.patch("astronomer.providers.amazon.aws.triggers.s3.S3HookAsync.get_client_async") - @mock.patch("astronomer.providers.amazon.aws.triggers.s3.S3HookAsync.is_keys_unchanged") - async def test_run_pending(self, mock_is_keys_unchanged, mock_client): - """Test if the task is run is in triggerer successfully.""" - mock_is_keys_unchanged.return_value = {"status": "pending", "last_activity_time": datetime.now()} - trigger = S3KeysUnchangedTrigger(bucket_name="test_bucket", prefix="test") - task = asyncio.create_task(trigger.run().__anext__()) - await asyncio.sleep(0.5) - - # TriggerEvent was not returned - assert task.done() is False - asyncio.get_event_loop().stop()