Skip to content

Commit

Permalink
Deprecate EmrJobFlowSensorAsync (#1406)
Browse files Browse the repository at this point in the history
* feat(amazon): deprecate EmrJobFlowSensorAsync
* feat(amazon): fallback poll_interval to poke_interval for EmrJobFlowSensorAsync
* feat(amazon): add deprecation warning to emr hooks and triggerer
  • Loading branch information
Lee-W authored Jan 17, 2024
1 parent 1fa825f commit b697ecc
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 145 deletions.
9 changes: 9 additions & 0 deletions astronomer/providers/amazon/aws/hooks/emr.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from typing import Any, Dict, Optional

from botocore.exceptions import ClientError
Expand Down Expand Up @@ -146,6 +147,14 @@ class EmrJobFlowHookAsync(AwsBaseHookAsync):
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
    """
    Emit a deprecation warning and initialise the hook as an EMR client.

    :param args: positional arguments forwarded unchanged to the parent hook
    :param kwargs: keyword arguments forwarded to the parent hook;
        ``client_type`` is always overwritten with ``"emr"``
    """
    # The two literals are concatenated implicitly, so the first one must end
    # with a space; otherwise the message reads "...2.0.0.Please use ...".
    warnings.warn(
        (
            "This module is deprecated and will be removed in 2.0.0. "
            "Please use :class:`~airflow.providers.amazon.aws.hooks.emr.EmrHook`."
        ),
        DeprecationWarning,
        # stacklevel=2 attributes the warning to the code instantiating the hook.
        stacklevel=2,
    )
    # Force the EMR client type regardless of what the caller passed.
    kwargs["client_type"] = "emr"
    super().__init__(*args, **kwargs)

Expand Down
89 changes: 21 additions & 68 deletions astronomer/providers/amazon/aws/sensors/emr.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,13 @@
from __future__ import annotations

import warnings
from datetime import timedelta
from typing import Any

from airflow import AirflowException
from airflow.exceptions import AirflowSkipException
from airflow.providers.amazon.aws.sensors.emr import (
EmrContainerSensor,
EmrJobFlowSensor,
EmrStepSensor,
)

from astronomer.providers.amazon.aws.triggers.emr import (
EmrJobFlowSensorTrigger,
)
from astronomer.providers.utils.sensor_util import raise_error_or_skip_exception
from astronomer.providers.utils.typing_compat import Context


class EmrContainerSensorAsync(EmrContainerSensor):
"""
Expand All @@ -26,6 +16,16 @@ class EmrContainerSensorAsync(EmrContainerSensor):
"""

def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def]
poll_interval = kwargs.pop("poll_interval")
if poll_interval:
self.poke_interval = poll_interval
warnings.warn(
"Argument `poll_interval` is deprecated and will be removed "
"in a future release. Please use `poke_interval` instead.",
DeprecationWarning,
stacklevel=2,
)

warnings.warn(
(
"This module is deprecated. "
Expand Down Expand Up @@ -59,65 +59,18 @@ def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def]

class EmrJobFlowSensorAsync(EmrJobFlowSensor):
"""
Async EMR Job flow sensor polls for the cluster state until it reaches
any of the target states.
If it fails the sensor errors, failing the task.
With the default target states, sensor waits cluster to be terminated.
When target_states is set to ['RUNNING', 'WAITING'] sensor waits
until job flow to be ready (after 'STARTING' and 'BOOTSTRAPPING' states)
:param job_flow_id: job_flow_id to check the state of cluster
:param target_states: the target states, sensor waits until
job flow reaches any of these states
:param failed_states: the failure states, sensor fails when
job flow reaches any of these states
This class is deprecated.
Please use :class: `~airflow.providers.amazon.aws.sensors.emr.EmrJobFlowSensor`.
"""

def __init__(
self,
*,
poll_interval: float = 5,
**kwargs: Any,
):
self.poll_interval = poll_interval
super().__init__(**kwargs)

def execute(self, context: Context) -> None:
"""Defers trigger class to poll for state of the job run until it reaches a failure state or success state"""
emr_client = self.hook.conn
self.log.info("Poking cluster %s", self.job_flow_id)
response = emr_client.describe_cluster(ClusterId=self.job_flow_id)
state = response["Cluster"]["Status"]["State"]
self.log.info("Job flow currently %s", state)

if state == "TERMINATED":
return None

if state == "TERMINATED_WITH_ERRORS":
if self.soft_fail: # pragma: no cover
AirflowSkipException(f"EMR job failed: {self.failure_message_from_response(response)}")
raise AirflowException(f"EMR job failed: {self.failure_message_from_response(response)}")

self.defer(
timeout=timedelta(seconds=self.timeout),
trigger=EmrJobFlowSensorTrigger(
job_flow_id=self.job_flow_id,
aws_conn_id=self.aws_conn_id,
target_states=self.target_states,
failed_states=self.failed_states,
poll_interval=self.poll_interval,
def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def]
warnings.warn(
(
"This module is deprecated. "
"Please use `airflow.providers.amazon.aws.sensors.emr.EmrJobFlowSensor` "
"and set deferrable to True instead."
),
method_name="execute_complete",
DeprecationWarning,
stacklevel=2,
)

def execute_complete(self, context: Context, event: dict[str, str]) -> None: # type: ignore[override]
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if event:
if event["status"] == "error":
raise_error_or_skip_exception(self.soft_fail, event["message"])
self.log.info(event["message"])
return None
return super().__init__(*args, deferrable=True, **kwargs)
23 changes: 17 additions & 6 deletions astronomer/providers/amazon/aws/triggers/emr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

import asyncio
from typing import Any, AsyncIterator, Dict, Iterable, Optional, Tuple
import warnings
from typing import Any, AsyncIterator, Iterable

from airflow.triggers.base import BaseTrigger, TriggerEvent

Expand All @@ -25,7 +28,7 @@ def __init__(
job_id: str,
aws_conn_id: str = "aws_default",
poll_interval: int = 10,
max_tries: Optional[int] = None,
max_tries: int | None = None,
**kwargs: Any,
):
self.virtual_cluster_id = virtual_cluster_id
Expand Down Expand Up @@ -54,17 +57,25 @@ def __init__(
job_flow_id: str,
aws_conn_id: str,
poll_interval: float,
target_states: Optional[Iterable[str]] = None,
failed_states: Optional[Iterable[str]] = None,
target_states: Iterable[str] | None = None,
failed_states: Iterable[str] | None = None,
):
warnings.warn(
(
"This module is deprecated and will be removed in 2.0.0."
"Please use :class: `~airflow.providers.amazon.aws.triggers.emr.EmrTerminateJobFlowTrigger."
),
DeprecationWarning,
stacklevel=2,
)
super().__init__()
self.job_flow_id = job_flow_id
self.aws_conn_id = aws_conn_id
self.poll_interval = poll_interval
self.target_states = target_states or ["TERMINATED"]
self.failed_states = failed_states or ["TERMINATED_WITH_ERRORS"]

def serialize(self) -> Tuple[str, Dict[str, Any]]:
def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes EmrJobFlowSensorTrigger arguments and classpath."""
return (
"astronomer.providers.amazon.aws.triggers.emr.EmrJobFlowSensorTrigger",
Expand All @@ -77,7 +88,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
},
)

async def run(self) -> AsyncIterator["TriggerEvent"]:
async def run(self) -> AsyncIterator[TriggerEvent]:
"""Make async connection to EMR container, polls for the target job state"""
hook = EmrJobFlowHookAsync(aws_conn_id=self.aws_conn_id)
try:
Expand Down
79 changes: 8 additions & 71 deletions tests/amazon/aws/sensors/test_emr_sensors.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,10 @@
from unittest import mock
from unittest.mock import PropertyMock

import pytest
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.amazon.aws.sensors.emr import (
EmrStepSensor,
)
from airflow.providers.amazon.aws.sensors.emr import EmrContainerSensor, EmrJobFlowSensor, EmrStepSensor

from astronomer.providers.amazon.aws.sensors.emr import (
EmrContainerSensorAsync,
EmrJobFlowSensorAsync,
EmrStepSensorAsync,
)
from astronomer.providers.amazon.aws.triggers.emr import (
EmrJobFlowSensorTrigger,
)

TASK_ID = "test_emr_container_sensor"
VIRTUAL_CLUSTER_ID = "test_cluster_1"
Expand All @@ -38,7 +28,7 @@ def test_init(self):
max_retries=1,
aws_conn_id=AWS_CONN_ID,
)
assert isinstance(task, EmrContainerSensorAsync)
assert isinstance(task, EmrContainerSensor)
assert task.deferrable is True


Expand All @@ -54,63 +44,10 @@ def test_init(self):


class TestEmrJobFlowSensorAsync:
TASK = EmrJobFlowSensorAsync(
task_id=TASK_ID,
job_flow_id=JOB_ID,
)

@mock.patch(f"{MODULE}.EmrJobFlowSensorAsync.defer")
@mock.patch(f"{MODULE}.EmrJobFlowSensorAsync.hook", new_callable=PropertyMock)
def test_emr_job_flow_sensor_async_finish_before_deferred(self, mock_hook, mock_defer, context):
"""Assert task is not deferred when it receives a finish status before deferring"""
mock_hook.return_value.conn.describe_cluster.return_value = {
"Cluster": {"Status": {"State": "TERMINATED"}}
}
self.TASK.execute(context)
assert not mock_defer.called

@mock.patch(f"{MODULE}.EmrJobFlowSensorAsync.defer")
@mock.patch(f"{MODULE}.EmrJobFlowSensorAsync.hook", new_callable=PropertyMock)
def test_emr_job_flow_sensor_async_failed_before_deferred(self, mock_hook, mock_defer, context):
"""Assert task is not deferred when it receives a finish status before deferring"""
mock_hook.return_value.conn.describe_cluster.return_value = {
"Cluster": {"Status": {"State": "TERMINATED_WITH_ERRORS"}}
}
with pytest.raises(AirflowException):
self.TASK.execute(context)
assert not mock_defer.called

@pytest.mark.parametrize("status", ("STARTING", "BOOTSTRAPPING", "RUNNING", "WAITING"))
@mock.patch(f"{MODULE}.EmrJobFlowSensorAsync.hook", new_callable=PropertyMock)
def test_emr_job_flow_sensor_async(self, mock_hook, status, context):
"""
Asserts that a task is deferred and a EmrJobFlowSensorTrigger will be fired
when the EmrJobFlowSensorAsync is executed.
"""
mock_hook.return_value.conn.describe_cluster.return_value = {"Cluster": {"Status": {"State": status}}}
with pytest.raises(TaskDeferred) as exc:
self.TASK.execute(context)
assert isinstance(
exc.value.trigger, EmrJobFlowSensorTrigger
), "Trigger is not a EmrJobFlowSensorTrigger"

def test_emr_flow_sensor_async_execute_failure(self, context):
"""Test EMR flow sensor with an AirflowException is raised in case of error event"""

with pytest.raises(AirflowException):
self.TASK.execute_complete(
context=None, event={"status": "error", "message": "test failure message"}
)

def test_emr_job_flow_sensor_async_execute_complete(self):
"""Asserts that logging occurs as expected"""

assert (
self.TASK.execute_complete(context=None, event={"status": "success", "message": "Job completed"})
is None
def test_init(self):
task = EmrJobFlowSensorAsync(
task_id=TASK_ID,
job_flow_id=JOB_ID,
)

def test_emr_job_flow_sensor_async_execute_complete_event_none(self):
"""Asserts that logging occurs as expected"""

assert self.TASK.execute_complete(context=None, event=None) is None
assert isinstance(task, EmrJobFlowSensor)
assert task.deferrable is True

0 comments on commit b697ecc

Please sign in to comment.