deprecate BigQueryTableExistenceSensorAsync (#1458)
Lee-W authored Jan 24, 2024
1 parent 8f52697 commit ae5cb97
Showing 3 changed files with 73 additions and 177 deletions.
77 changes: 13 additions & 64 deletions astronomer/providers/google/cloud/sensors/bigquery.py
@@ -2,43 +2,16 @@
from __future__ import annotations

import warnings
from datetime import timedelta
from typing import Any

from airflow.providers.google.cloud.sensors.bigquery import BigQueryTableExistenceSensor

from astronomer.providers.google.cloud.triggers.bigquery import (
BigQueryTableExistenceTrigger,
)
from astronomer.providers.utils.sensor_util import poke, raise_error_or_skip_exception
from astronomer.providers.utils.typing_compat import Context


class BigQueryTableExistenceSensorAsync(BigQueryTableExistenceSensor):
"""
Checks for the existence of a table in Google BigQuery.
:param project_id: The Google Cloud project in which to look for the table.
The connection supplied to the hook must provide
access to the specified project.
:param dataset_id: The name of the dataset in which to look for the table.
:param table_id: The name of the table to check the existence of.
:param gcp_conn_id: The connection ID used to connect to Google Cloud.
:param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud.
This parameter has been deprecated. You should pass the gcp_conn_id parameter instead.
:param delegate_to: (Removed in apache-airflow-providers-google release 10.0.0, use impersonation_chain instead)
The account to impersonate using domain-wide delegation of authority, if any. For this to work, the service
account making the request must have domain-wide delegation enabled.
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param polling_interval: The interval in seconds to wait between checks for table existence.
This class is deprecated.
Please use :class:`~airflow.providers.google.cloud.sensors.bigquery.BigQueryTableExistenceSensor`
and set `deferrable` param to `True` instead.
"""

def __init__(
@@ -47,7 +20,16 @@ def __init__(
polling_interval: float = 5.0,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
warnings.warn(
(
"This class is deprecated. "
"Please use `airflow.providers.google.cloud.sensors.bigquery.BigQueryTableExistenceSensor` "
"and set `deferrable` param to `True` instead."
),
DeprecationWarning,
stacklevel=2,
)
super().__init__(deferrable=True, **kwargs)
# TODO: Remove once deprecated
if polling_interval:
self.poke_interval = polling_interval
@@ -58,36 +40,3 @@ def __init__(
stacklevel=2,
)
self.gcp_conn_id = gcp_conn_id

def execute(self, context: Context) -> None:
"""Airflow runs this method on the worker and defers using the trigger."""
hook_params = {"impersonation_chain": self.impersonation_chain}
if hasattr(self, "delegate_to"): # pragma: no cover
hook_params["delegate_to"] = self.delegate_to

if not poke(self, context):
self.defer(
timeout=timedelta(seconds=self.timeout),
trigger=BigQueryTableExistenceTrigger(
dataset_id=self.dataset_id,
table_id=self.table_id,
project_id=self.project_id,
poke_interval=self.poke_interval,
gcp_conn_id=self.gcp_conn_id,
hook_params=hook_params,
),
method_name="execute_complete",
)

def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None = None) -> str: # type: ignore[return]
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
table_uri = f"{self.project_id}:{self.dataset_id}.{self.table_id}"
self.log.info("Sensor checks existence of table: %s", table_uri)
if event:
if event["status"] == "success":
return event["message"]
raise_error_or_skip_exception(self.soft_fail, event["message"])
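
As the new deprecation message states, the replacement is the upstream sensor with its built-in deferrable mode. A minimal migration sketch, with hypothetical project, dataset, and table names:

    from airflow.providers.google.cloud.sensors.bigquery import BigQueryTableExistenceSensor

    check_table = BigQueryTableExistenceSensor(
        task_id="check_table",
        project_id="my-project",   # hypothetical identifiers
        dataset_id="my_dataset",
        table_id="my_table",
        deferrable=True,           # replaces BigQueryTableExistenceSensorAsync
        poke_interval=5.0,         # replaces the async sensor's polling_interval
    )

The deprecated class now simply forwards to this path by passing deferrable=True to super().__init__(), so existing DAGs keep working while emitting the warning.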
92 changes: 49 additions & 43 deletions astronomer/providers/google/cloud/triggers/bigquery.py
@@ -1,14 +1,8 @@
from __future__ import annotations

import asyncio
from typing import (
Any,
AsyncIterator,
Dict,
Optional,
Sequence,
SupportsAbs,
Tuple,
Union,
)
import warnings
from typing import Any, AsyncIterator, Sequence, SupportsAbs

from aiohttp import ClientSession
from aiohttp.client_exceptions import ClientResponseError
@@ -38,12 +32,12 @@ class BigQueryInsertJobTrigger(BaseTrigger):
def __init__(
self,
conn_id: str,
job_id: Optional[str],
project_id: Optional[str],
dataset_id: Optional[str] = None,
table_id: Optional[str] = None,
delegate_to: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
job_id: str | None,
project_id: str | None,
dataset_id: str | None = None,
table_id: str | None = None,
delegate_to: str | None = None,
impersonation_chain: str | Sequence[str] | None = None,
poll_interval: float = 4.0,
):
super().__init__()
@@ -58,7 +52,7 @@ def __init__(
self.impersonation_chain = impersonation_chain
self.poll_interval = poll_interval

def serialize(self) -> Tuple[str, Dict[str, Any]]:
def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes BigQueryInsertJobTrigger arguments and classpath."""
return (
"astronomer.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger",
@@ -74,7 +68,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
},
)

async def run(self) -> AsyncIterator["TriggerEvent"]:
async def run(self) -> AsyncIterator[TriggerEvent]:
"""Gets current job execution status and yields a TriggerEvent"""
hook = self._get_async_hook()
while True:
@@ -113,7 +107,7 @@ def _get_async_hook(self) -> BigQueryHookAsync:
class BigQueryCheckTrigger(BigQueryInsertJobTrigger):
"""BigQueryCheckTrigger run on the trigger worker"""

def serialize(self) -> Tuple[str, Dict[str, Any]]:
def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes BigQueryCheckTrigger arguments and classpath."""
return (
"astronomer.providers.google.cloud.triggers.bigquery.BigQueryCheckTrigger",
@@ -128,7 +122,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
},
)

async def run(self) -> AsyncIterator["TriggerEvent"]:
async def run(self) -> AsyncIterator[TriggerEvent]:
"""Gets current job execution status and yields a TriggerEvent"""
hook = self._get_async_hook()
while True:
@@ -173,7 +167,7 @@ async def run(self) -> AsyncIterator["TriggerEvent"]:
class BigQueryGetDataTrigger(BigQueryInsertJobTrigger):
"""BigQueryGetDataTrigger run on the trigger worker, inherits from BigQueryInsertJobTrigger class"""

def serialize(self) -> Tuple[str, Dict[str, Any]]:
def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes BigQueryInsertJobTrigger arguments and classpath."""
return (
"astronomer.providers.google.cloud.triggers.bigquery.BigQueryGetDataTrigger",
@@ -189,7 +183,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
},
)

async def run(self) -> AsyncIterator["TriggerEvent"]:
async def run(self) -> AsyncIterator[TriggerEvent]:
"""Gets current job execution status and yields a TriggerEvent with response data"""
hook = self._get_async_hook()
while True:
@@ -248,16 +242,16 @@ def __init__(
conn_id: str,
first_job_id: str,
second_job_id: str,
project_id: Optional[str],
project_id: str | None,
table: str,
metrics_thresholds: Dict[str, int],
date_filter_column: Optional[str] = "ds",
metrics_thresholds: dict[str, int],
date_filter_column: str | None = "ds",
days_back: SupportsAbs[int] = -7,
ratio_formula: str = "max_over_min",
ignore_zero: bool = True,
dataset_id: Optional[str] = None,
table_id: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
dataset_id: str | None = None,
table_id: str | None = None,
impersonation_chain: str | Sequence[str] | None = None,
poll_interval: float = 4.0,
):
super().__init__(
@@ -280,7 +274,7 @@ def __init__(
self.ratio_formula = ratio_formula
self.ignore_zero = ignore_zero

def serialize(self) -> Tuple[str, Dict[str, Any]]:
def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes BigQueryCheckTrigger arguments and classpath."""
return (
"astronomer.providers.google.cloud.triggers.bigquery.BigQueryIntervalCheckTrigger",
@@ -298,7 +292,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
},
)

async def run(self) -> AsyncIterator["TriggerEvent"]:
async def run(self) -> AsyncIterator[TriggerEvent]:
"""Gets current job execution status and yields a TriggerEvent"""
hook = self._get_async_hook()
while True:
@@ -325,14 +319,14 @@ async def run(self) -> AsyncIterator["TriggerEvent"]:

# If empty list, then no records are available
if not first_records:
first_job_row: Optional[str] = None
first_job_row: str | None = None
else:
# Extract only first record from the query results
first_job_row = first_records.pop(0)

# If empty list, then no records are available
if not second_records:
second_job_row: Optional[str] = None
second_job_row: str | None = None
else:
# Extract only first record from the query results
second_job_row = second_records.pop(0)
@@ -391,13 +385,13 @@ def __init__(
self,
conn_id: str,
sql: str,
pass_value: Union[int, float, str],
job_id: Optional[str],
project_id: Optional[str],
pass_value: int | float | str,
job_id: str | None,
project_id: str | None,
tolerance: Any = None,
dataset_id: Optional[str] = None,
table_id: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
dataset_id: str | None = None,
table_id: str | None = None,
impersonation_chain: str | Sequence[str] | None = None,
poll_interval: float = 4.0,
):
super().__init__(
Expand All @@ -413,7 +407,7 @@ def __init__(
self.pass_value = pass_value
self.tolerance = tolerance

def serialize(self) -> Tuple[str, Dict[str, Any]]:
def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes BigQueryValueCheckTrigger arguments and classpath."""
return (
"astronomer.providers.google.cloud.triggers.bigquery.BigQueryValueCheckTrigger",
@@ -430,7 +424,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
},
)

async def run(self) -> AsyncIterator["TriggerEvent"]:
async def run(self) -> AsyncIterator[TriggerEvent]:
"""Gets current job execution status and yields a TriggerEvent"""
hook = self._get_async_hook()
while True:
@@ -462,6 +456,9 @@ class BigQueryTableExistenceTrigger(BaseTrigger):
"""
Initialise the BigQuery Table Existence Trigger with needed parameters.
This class is deprecated and will be removed in 2.0.0.
Use :class:`~airflow.providers.google.cloud.triggers.bigquery.BigQueryTableExistenceTrigger` instead.
:param project_id: Google Cloud Project where the job is running
:param dataset_id: The dataset ID of the requested table.
:param table_id: The table ID of the requested table.
@@ -476,17 +473,26 @@ def __init__(
dataset_id: str,
table_id: str,
gcp_conn_id: str,
hook_params: Dict[str, Any],
hook_params: dict[str, Any],
poke_interval: float = 4.0,
):
warnings.warn(
(
"This module is deprecated and will be removed in 2.0.0. "
"Please use `airflow.providers.google.cloud.triggers.bigquery.BigQueryTableExistenceTrigger`"
),
DeprecationWarning,
stacklevel=2,
)

self.dataset_id = dataset_id
self.project_id = project_id
self.table_id = table_id
self.gcp_conn_id: str = gcp_conn_id
self.poke_interval = poke_interval
self.hook_params = hook_params

def serialize(self) -> Tuple[str, Dict[str, Any]]:
def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes BigQueryTableExistenceTrigger arguments and classpath."""
return (
"astronomer.providers.google.cloud.triggers.bigquery.BigQueryTableExistenceTrigger",
@@ -503,7 +509,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
def _get_async_hook(self) -> BigQueryTableHookAsync:
return BigQueryTableHookAsync(gcp_conn_id=self.gcp_conn_id, **self.hook_params)

async def run(self) -> AsyncIterator["TriggerEvent"]:
async def run(self) -> AsyncIterator[TriggerEvent]:
"""Will run until the table exists in the Google Big Query."""
while True:
try:
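
For code that built this trigger directly, the docstring points to the upstream replacement. A hedged sketch of deferring from a custom sensor; the constructor arguments are assumed to mirror the deprecated trigger shown above:

    from datetime import timedelta

    from airflow.providers.google.cloud.triggers.bigquery import (
        BigQueryTableExistenceTrigger,
    )
    from airflow.sensors.base import BaseSensorOperator

    class MyTableSensor(BaseSensorOperator):
        def execute(self, context):
            # Hand the polling loop to the triggerer instead of blocking a worker slot.
            self.defer(
                timeout=timedelta(seconds=self.timeout),
                trigger=BigQueryTableExistenceTrigger(
                    project_id="my-project",   # hypothetical identifiers
                    dataset_id="my_dataset",
                    table_id="my_table",
                    gcp_conn_id="google_cloud_default",
                    hook_params={},
                    poll_interval=4.0,         # assumption: upstream keeps this name
                ),
                method_name="execute_complete",
            )

        def execute_complete(self, context, event=None):
            # The trigger yields a TriggerEvent with "status" and "message" keys,
            # matching the run() loops in the diff above.
            if event and event["status"] == "success":
                return event["message"]
            raise RuntimeError(event["message"] if event else "Trigger returned no event")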

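A deprecation like this is typically pinned down by a test asserting the warning; a minimal sketch with pytest (test name and keyword arguments are illustrative):

    import pytest

    from astronomer.providers.google.cloud.sensors.bigquery import (
        BigQueryTableExistenceSensorAsync,
    )

    def test_bigquery_table_existence_sensor_async_is_deprecated():
        # __init__ should emit DeprecationWarning before delegating to the
        # upstream sensor with deferrable=True.
        with pytest.warns(DeprecationWarning):
            BigQueryTableExistenceSensorAsync(
                task_id="check_table",
                project_id="my-project",   # hypothetical identifiers
                dataset_id="my_dataset",
                table_id="my_table",
            )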