Merge pull request #902 from CityofToronto/feat/894-collisions-counts-replicator-daily-status-message

Feat/894 collisions counts replicator daily status message + dynamic DAG generation
chmnata authored Mar 14, 2024
2 parents f79fd6e + 04d5b06 commit 406db7c
Showing 5 changed files with 189 additions and 153 deletions.
64 changes: 0 additions & 64 deletions dags/collisions_replicator.py

This file was deleted.

53 changes: 29 additions & 24 deletions dags/common_tasks.py
@@ -1,5 +1,5 @@

from psycopg2 import sql
from psycopg2 import sql, Error
from typing import Tuple
import logging
# pylint: disable=import-error
@@ -47,11 +47,6 @@ def copy_table(conn_id:str, table:Tuple[str, str], **context) -> None:
``schema.table``, and the destination table in the same format
``schema.table``.
"""
# push an extra failure message to be sent to Slack in case of failing
context["task_instance"].xcom_push(
"extra_msg",
f"Failed to copy `{table[0]}` to `{table[1]}`."
)
# separate tables and schemas
try:
src_schema, src_table = table[0].split(".")
@@ -94,25 +89,35 @@ def copy_table(conn_id:str, table:Tuple[str, str], **context) -> None:
sql.Identifier(src_schema), sql.Identifier(src_table),
sql.Identifier(dst_schema), sql.Identifier(dst_table),
)
with con, con.cursor() as cur:
# truncate the destination table
cur.execute(truncate_query)
# get the column names of the source table
cur.execute(source_columns_query, [src_schema, src_table])
src_columns = [r[0] for r in cur.fetchall()]
# copy all the data
insert_query = sql.SQL(
"INSERT INTO {}.{} ({}) SELECT {} FROM {}.{}"
).format(
sql.Identifier(dst_schema), sql.Identifier(dst_table),
sql.SQL(', ').join(map(sql.Identifier, src_columns)),
sql.SQL(', ').join(map(sql.Identifier, src_columns)),
sql.Identifier(src_schema), sql.Identifier(src_table)
)
cur.execute(insert_query)
# copy the table's comment
cur.execute(comment_query)

try:
with con, con.cursor() as cur:
# truncate the destination table
cur.execute(truncate_query)
# get the column names of the source table
cur.execute(source_columns_query, [src_schema, src_table])
src_columns = [r[0] for r in cur.fetchall()]
# copy all the data
insert_query = sql.SQL(
"INSERT INTO {}.{} ({}) SELECT {} FROM {}.{}"
).format(
sql.Identifier(dst_schema), sql.Identifier(dst_table),
sql.SQL(', ').join(map(sql.Identifier, src_columns)),
sql.SQL(', ').join(map(sql.Identifier, src_columns)),
sql.Identifier(src_schema), sql.Identifier(src_table)
)
cur.execute(insert_query)
# copy the table's comment
cur.execute(comment_query)
#catch psycopg2 errors:
except Error as e:
# push an extra failure message to be sent to Slack in case of failing
context["task_instance"].xcom_push(
key="extra_msg",
value=f"Failed to copy `{table[0]}` to `{table[1]}`: `{str(e).strip()}`."
)
raise AirflowFailException(e)

LOGGER.info(f"Successfully copied {table[0]} to {table[1]}.")

@task.short_circuit(ignore_downstream_trigger_rules=False, retries=0) #only skip immediately downstream task
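For reference, a minimal sketch of the psycopg2 ``sql`` composition pattern used in ``copy_table`` above, with hypothetical schema, table, and column names (at runtime the tables come from the Airflow variable and the columns from ``information_schema``):

from psycopg2 import sql

# hypothetical source/destination tables and columns, for illustration only
src_columns = ["id", "accdate", "geom"]
insert_query = sql.SQL("INSERT INTO {}.{} ({}) SELECT {} FROM {}.{}").format(
    sql.Identifier("collisions"), sql.Identifier("acc"),
    sql.SQL(", ").join(map(sql.Identifier, src_columns)),
    sql.SQL(", ").join(map(sql.Identifier, src_columns)),
    sql.Identifier("move_staging"), sql.Identifier("acc"),
)
# rendering the statement requires a live connection or cursor for proper quoting:
# insert_query.as_string(con) ->
#   INSERT INTO "collisions"."acc" ("id", "accdate", "geom")
#   SELECT "id", "accdate", "geom" FROM "move_staging"."acc"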
64 changes: 0 additions & 64 deletions dags/counts_replicator.py

This file was deleted.

53 changes: 52 additions & 1 deletion dags/dag_functions.py
@@ -98,7 +98,7 @@ def task_fail_slack_alert(
# in case of a string (or the default empty string)
extra_msg_str = extra_msg

if isinstance(extra_msg_str, tuple):
if isinstance(extra_msg_str, tuple) or isinstance(extra_msg_str, list):
#recursively collapse extra_msg_str's which are in the form of a list with new lines.
extra_msg_str = '\n'.join(
['\n'.join(item) if isinstance(item, list) else item for item in extra_msg_str]
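For example, the nested ``extra_msg`` pushed by the replicator ``status_message`` task (see ``dags/replicators.py`` below) collapses as follows; the table names are illustrative:

extra_msg = [
    "One or more tables failed to copy:",
    ["Failed to copy `move_staging.acc` to `collisions.acc`: `relation does not exist`."],
]
extra_msg_str = "\n".join(
    "\n".join(item) if isinstance(item, list) else item for item in extra_msg
)
# extra_msg_str:
# One or more tables failed to copy:
# Failed to copy `move_staging.acc` to `collisions.acc`: `relation does not exist`.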
@@ -152,3 +152,54 @@ def get_readme_docmd(readme_path, dag_name):
doc_md_key = '<!-- ' + dag_name + '_doc_md -->'
doc_md_regex = '(?<=' + doc_md_key + '\n)[\s\S]+(?=\n' + doc_md_key + ')'
return re.findall(doc_md_regex, contents)[0]

def send_slack_msg(
context: dict,
msg: str,
attachments: Optional[list] = None,
blocks: Optional[list] = None,
use_proxy: Optional[bool] = False,
dev_mode: Optional[bool] = None
) -> Any:
"""Sends a message to Slack.
Args:
context: The calling Airflow task's context.
msg: A string message to be sent to Slack.
attachments: List of dictionaries representing Slack attachments.
blocks: List of dictionaries representing Slack blocks.
use_proxy: A boolean to indicate whether to use a proxy or not. Proxy
usage is required to make the Slack webhook call on on-premises
servers (default False).
dev_mode: A boolean to indicate if working in development mode to send
Slack alerts to data_pipeline_dev instead of the regular
data_pipeline (default None, to be determined based on the location
of the file).
"""
if dev_mode or (dev_mode is None and not is_prod_mode()):
SLACK_CONN_ID = "slack_data_pipeline_dev"
else:
SLACK_CONN_ID = "slack_data_pipeline"

if use_proxy:
# get the proxy credentials from the Airflow connection ``slack``. It
# contains username and password to set the proxy <username>:<password>
proxy=(
f"http://{BaseHook.get_connection('slack').password}"
f"@{json.loads(BaseHook.get_connection('slack').extra)['url']}"
)
else:
proxy = None

slack_alert = SlackWebhookOperator(
task_id="slack_test",
slack_webhook_conn_id=SLACK_CONN_ID,
message=msg,
username="airflow",
attachments=attachments,
blocks=blocks,
proxy=proxy,
)
return slack_alert.execute(context=context)
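A minimal usage sketch from inside a task, assuming ``send_slack_msg`` is imported from this module (the task name and message text are hypothetical; the replicator ``status_message`` task below calls it the same way):

from airflow.decorators import task

@task
def notify_success(**context):
    send_slack_msg(
        context=context,
        msg="`collisions_replicator` DAG succeeded :white_check_mark:",
        use_proxy=True,  # required when the webhook call must go through the on-premises proxy
    )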
108 changes: 108 additions & 0 deletions dags/replicators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
#!/data/airflow/airflow_venv/bin/python3
# -*- coding: utf-8 -*-
# noqa: D415
import os
import sys
from datetime import timedelta
import pendulum
# pylint: disable=import-error
from airflow.decorators import dag, task
from airflow.models import Variable
from airflow.exceptions import AirflowFailException

# import custom operators and helper functions
repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
sys.path.insert(0, repo_path)
# pylint: disable=wrong-import-position
from dags.dag_functions import task_fail_slack_alert, send_slack_msg
# pylint: enable=import-error

def create_replicator_dag(dag_id, short_name, tables_var, conn, doc_md, default_args):
@dag(
dag_id=dag_id,
default_args=default_args,
catchup=False,
max_active_runs=5,
max_active_tasks=5,
schedule=None, #triggered externally
doc_md=doc_md,
tags=[short_name, "replicator"]
)
def replicator_DAG():
f"""The main function of the {short_name} DAG."""
from dags.common_tasks import (
wait_for_external_trigger, get_variable, copy_table
)

# Returns a list of source and destination tables stored in the given
# Airflow variable.
tables = get_variable.override(task_id="get_list_of_tables")(tables_var)

# Copies tables
copy_tables = copy_table.override(task_id="copy_tables", on_failure_callback = None).partial(conn_id=conn).expand(table=tables)

@task(
retries=0,
trigger_rule='all_done',
doc_md="""A status message to report DAG success OR any failures from the `copy_tables` task."""
)
def status_message(tables, **context):
ti = context["ti"]
failures = []
#iterate through mapped tasks to find any failure messages
for m_i in range(0, len(tables)):
failure_msg = ti.xcom_pull(key="extra_msg", task_ids="copy_tables", map_indexes=m_i)
if failure_msg is not None:
failures.append(failure_msg)
if failures == []:
send_slack_msg(
context=context,
msg=f"`{dag_id}` DAG succeeded :white_check_mark:"
)
else: #add details of failures to task_fail_slack_alert
failure_extra_msg = ['One or more tables failed to copy:', failures]
context.get("task_instance").xcom_push(key="extra_msg", value=failure_extra_msg)
raise AirflowFailException('One or more tables failed to copy.')

# Waits for an external trigger
wait_for_external_trigger() >> tables >> copy_tables >> status_message(tables=tables)

generated_dag = replicator_DAG()

return generated_dag

#get replicator details from airflow variable
REPLICATORS = Variable.get('replicators', deserialize_json=True)

#generate replicator DAGs from dict
for replicator, dag_items in REPLICATORS.items():
DAG_NAME = dag_items['dag_name']
DAG_OWNERS = Variable.get("dag_owners", deserialize_json=True).get(DAG_NAME, ["Unknown"])

default_args = {
"owner": ",".join(DAG_OWNERS),
"depends_on_past": False,
"start_date": pendulum.datetime(2023, 10, 31, tz="America/Toronto"),
"email_on_failure": False,
"retries": 3,
"retry_delay": timedelta(minutes=60),
"on_failure_callback": task_fail_slack_alert,
}

doc_md = f"""### The Daily {replicator} Replicator DAG
This DAG runs daily to copy MOVE's {replicator} tables from the ``move_staging``
schema, which is updated by MOVE's ``bigdata_replicator`` DAG, to the
``{replicator}`` schema. This DAG runs only when triggered by MOVE's
DAG."""

globals()[DAG_NAME] = (
create_replicator_dag(
dag_id=DAG_NAME,
short_name=replicator,
tables_var=dag_items['tables'],
conn=dag_items['conn'],
doc_md=doc_md,
default_args=default_args
)
)
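A hypothetical shape for the ``replicators`` Airflow variable that drives the loop above; the keys match what the code reads, the DAG names mirror the deleted standalone DAGs, and the ``tables``/``conn`` values are placeholders:

REPLICATORS = {
    "collisions": {
        "dag_name": "collisions_replicator",  # becomes the generated dag_id
        "tables": "collisions_tables",        # Airflow variable listing [source, destination] table pairs
        "conn": "collisions_bot",             # Postgres connection used by copy_table
    },
    "counts": {
        "dag_name": "counts_replicator",
        "tables": "counts_tables",
        "conn": "counts_bot",
    },
}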
