Skip to content

Commit

Permalink
Deprecate WasbBlobSensorAsync (#1435)
Browse files Browse the repository at this point in the history
Deprecate WasbBlobSensorAsync and proxy it to its Airflow OSS
provider's counterpart

related: #1412
  • Loading branch information
pankajkoti authored Jan 19, 2024
1 parent 2cb1e6b commit 52aaa36
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 114 deletions.
15 changes: 11 additions & 4 deletions astronomer/providers/microsoft/azure/hooks/wasb.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""This module contains the Azure WASB hook's asynchronous implementation."""
from __future__ import annotations

import warnings
from typing import Any, Union

from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
Expand All @@ -14,10 +15,8 @@

class WasbHookAsync(WasbHook):
"""
An async hook that connects to Azure WASB to perform operations.
:param wasb_conn_id: reference to the :ref:`wasb connection <howto/connection:wasb>`
:param public_read: whether an anonymous public read access should be used. default is False
This class is deprecated and will be removed in 2.0.0.
Use :class:`~airflow.providers.microsoft.azure.hooks.wasb.WasbHook` instead.
"""

def __init__(
Expand All @@ -26,6 +25,14 @@ def __init__(
public_read: bool = False,
) -> None:
"""Initialize the hook instance."""
warnings.warn(
(
"This class is deprecated and will be removed in 2.0.0."
"Use :class: `~airflow.providers.microsoft.azure.hooks.wasb.WasbHook` instead."
),
DeprecationWarning,
stacklevel=2,
)
self.conn_id = wasb_conn_id
self.public_read = public_read
self.blob_service_client: BlobServiceClient = self.get_conn()
Expand Down
64 changes: 16 additions & 48 deletions astronomer/providers/microsoft/azure/sensors/wasb.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,75 +7,43 @@
WasbPrefixSensor,
)

from astronomer.providers.microsoft.azure.triggers.wasb import (
WasbBlobSensorTrigger,
WasbPrefixSensorTrigger,
)
from astronomer.providers.microsoft.azure.triggers.wasb import WasbPrefixSensorTrigger
from astronomer.providers.utils.sensor_util import poke, raise_error_or_skip_exception
from astronomer.providers.utils.typing_compat import Context


class WasbBlobSensorAsync(WasbBlobSensor):
"""
Polls asynchronously for the existence of a blob in a WASB container.
:param container_name: name of the container in which the blob should be searched for
:param blob_name: name of the blob to check existence for
:param wasb_conn_id: the connection identifier for connecting to Azure WASB
:param poll_interval: polling period in seconds to check for the status
:param public_read: whether an anonymous public read access should be used. Default is False
This class is deprecated.
Use :class:`~airflow.providers.microsoft.azure.sensors.wasb.WasbBlobSensor`
and set the `deferrable` param to `True` instead.
"""

def __init__(
self,
*,
container_name: str,
blob_name: str,
wasb_conn_id: str = "wasb_default",
public_read: bool = False,
*args: Any,
poll_interval: float = 5.0,
**kwargs: Any,
):
self.container_name = container_name
self.blob_name = blob_name
warnings.warn(
(
"This class is deprecated. "
"Use `airflow.providers.microsoft.azure.sensors.wasb.WasbBlobSensor` "
"and set `deferrable` param to `True` instead."
),
DeprecationWarning,
stacklevel=2,
)
# TODO: Remove once deprecated
if poll_interval:
self.poke_interval = poll_interval
kwargs["poke_interval"] = poll_interval
warnings.warn(
"Argument `poll_interval` is deprecated and will be removed "
"in a future release. Please use `poke_interval` instead.",
DeprecationWarning,
stacklevel=2,
)
super().__init__(container_name=container_name, blob_name=blob_name, **kwargs)
self.wasb_conn_id = wasb_conn_id
self.public_read = public_read

def execute(self, context: Context) -> None:
"""Defers trigger class to poll for state of the job run until it reaches a failure state or success state"""
if not poke(self, context):
self.defer(
timeout=timedelta(seconds=self.timeout),
trigger=WasbBlobSensorTrigger(
container_name=self.container_name,
blob_name=self.blob_name,
wasb_conn_id=self.wasb_conn_id,
public_read=self.public_read,
poke_interval=self.poke_interval,
),
method_name="execute_complete",
)

def execute_complete(self, context: Context, event: Dict[str, str]) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if event:
if event["status"] == "error":
raise_error_or_skip_exception(self.soft_fail, event["message"])
self.log.info(event["message"])
super().__init__(*args, deferrable=True, **kwargs)


class WasbPrefixSensorAsync(WasbPrefixSensor):
Expand Down
19 changes: 11 additions & 8 deletions astronomer/providers/microsoft/azure/triggers/wasb.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import warnings
from typing import Any, AsyncIterator, Dict, List, Optional, Tuple

from airflow.triggers.base import BaseTrigger, TriggerEvent
Expand All @@ -8,14 +9,8 @@

class WasbBlobSensorTrigger(BaseTrigger):
"""
WasbBlobSensorTrigger is fired as deferred class with params to run the task in trigger worker to check for
existence of the given blob in the provided container.
:param container_name: name of the container in which the blob should be searched for
:param blob_name: name of the blob to check existence for
:param wasb_conn_id: the connection identifier for connecting to Azure WASB
:param poke_interval: polling period in seconds to check for the status
:param public_read: whether an anonymous public read access should be used. Default is False
This class is deprecated and will be removed in 2.0.0.
Use :class:`~airflow.providers.microsoft.azure.triggers.wasb.WasbBlobSensorTrigger` instead.
"""

def __init__(
Expand All @@ -26,6 +21,14 @@ def __init__(
public_read: bool = False,
poke_interval: float = 5.0,
):
warnings.warn(
(
"This class is deprecated and will be removed in 2.0.0."
"Use :class: `~airflow.providers.microsoft.azure.triggers.wasb.WasbBlobSensorTrigger` instead"
),
DeprecationWarning,
stacklevel=2,
)
super().__init__()
self.container_name = container_name
self.blob_name = blob_name
Expand Down
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ google =
http =
apache-airflow-providers-http
microsoft.azure =
apache-airflow-providers-microsoft-azure
apache-airflow-providers-microsoft-azure>=8.5.1
sftp =
apache-airflow-providers-sftp
asyncssh>=2.12.0
Expand Down Expand Up @@ -128,7 +128,7 @@ all =
apache-airflow-providers-http
apache-airflow-providers-snowflake
apache-airflow-providers-sftp
apache-airflow-providers-microsoft-azure
apache-airflow-providers-microsoft-azure>=8.5.1
asyncssh>=2.12.0
databricks-sql-connector>=2.0.4;python_version>='3.10'
apache-airflow-providers-dbt-cloud>=2.1.0
Expand Down
63 changes: 11 additions & 52 deletions tests/microsoft/azure/sensors/test_wasb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@

import pytest
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.microsoft.azure.sensors.wasb import WasbBlobSensor

from astronomer.providers.microsoft.azure.sensors.wasb import (
WasbBlobSensorAsync,
WasbPrefixSensorAsync,
)
from astronomer.providers.microsoft.azure.triggers.wasb import (
WasbBlobSensorTrigger,
WasbPrefixSensorTrigger,
)
from astronomer.providers.microsoft.azure.triggers.wasb import WasbPrefixSensorTrigger
from tests.utils.airflow_util import create_context

TEST_DATA_STORAGE_BLOB_NAME = "test_blob_providers_team.txt"
Expand All @@ -21,54 +19,15 @@


class TestWasbBlobSensorAsync:
SENSOR = WasbBlobSensorAsync(
task_id="wasb_blob_sensor_async",
container_name=TEST_DATA_STORAGE_CONTAINER_NAME,
blob_name=TEST_DATA_STORAGE_BLOB_NAME,
)

@mock.patch(f"{MODULE}.WasbBlobSensorAsync.defer")
@mock.patch(f"{MODULE}.WasbBlobSensorAsync.poke", return_value=True)
def test_wasb_blob_sensor_async_finish_before_deferred(self, mock_poke, mock_defer, context):
"""Assert task is not deferred when it receives a finish status before deferring"""
self.SENSOR.execute(create_context(self.SENSOR))
assert not mock_defer.called

@mock.patch(f"{MODULE}.WasbBlobSensorAsync.poke", return_value=False)
def test_wasb_blob_sensor_async(self, mock_poke):
"""Assert execute method defer for wasb blob sensor"""

with pytest.raises(TaskDeferred) as exc:
self.SENSOR.execute(create_context(self.SENSOR))
assert isinstance(exc.value.trigger, WasbBlobSensorTrigger), "Trigger is not a WasbBlobSensorTrigger"

@pytest.mark.parametrize(
"event",
[{"status": "success", "message": "Job completed"}],
)
def test_wasb_blob_sensor_execute_complete_success(self, event):
"""Assert execute_complete log success message when trigger fire with target status."""

with mock.patch.object(self.SENSOR.log, "info") as mock_log_info:
self.SENSOR.execute_complete(context={}, event=event)
mock_log_info.assert_called_with(event["message"])

def test_wasb_blob_sensor_execute_complete_failure(self):
"""Assert execute_complete method raises an exception when the triggerer fires an error event."""

with pytest.raises(AirflowException):
self.SENSOR.execute_complete(context={}, event={"status": "error", "message": ""})

def test_poll_interval_deprecation_warning_wasb_blob(self):
"""Test DeprecationWarning for WasbBlobSensorAsync by setting param poll_interval"""
# TODO: Remove once deprecated
with pytest.warns(expected_warning=DeprecationWarning):
WasbBlobSensorAsync(
task_id="wasb_blob_sensor_async",
container_name=TEST_DATA_STORAGE_CONTAINER_NAME,
blob_name=TEST_DATA_STORAGE_BLOB_NAME,
poll_interval=5.0,
)
def test_init(self):
task = WasbBlobSensorAsync(
task_id="wasb_blob_sensor_async",
container_name=TEST_DATA_STORAGE_CONTAINER_NAME,
blob_name=TEST_DATA_STORAGE_BLOB_NAME,
)

assert isinstance(task, WasbBlobSensor)
assert task.deferrable is True


class TestWasbPrefixSensorAsync:
Expand Down

0 comments on commit 52aaa36

Please sign in to comment.