diff --git a/dags/collisions_replicator.py b/dags/collisions_replicator.py
deleted file mode 100644
index 99b32dac6..000000000
--- a/dags/collisions_replicator.py
+++ /dev/null
@@ -1,64 +0,0 @@
-#!/data/airflow/airflow_venv/bin/python3
-# -*- coding: utf-8 -*-
-# noqa: D415
-r"""### The Daily Collision Replicator DAG
-
-This DAG runs daily to copy MOVE's collisions tables from the ``move_staging``
-schema, which is updated by the MOVE's ``bigdata_replicator`` DAG, to the
-``collisions`` schema. This DAG runs only when it is triggered by the MOVE's
-DAG.
-"""
-import os
-import sys
-from datetime import timedelta
-from functools import partial
-import pendulum
-# pylint: disable=import-error
-from airflow.decorators import dag
-from airflow.models import Variable
-
-# import custom operators and helper functions
-repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
-sys.path.insert(0, repo_path)
-# pylint: disable=wrong-import-position
-from dags.dag_functions import task_fail_slack_alert
-# pylint: enable=import-error
-
-DAG_NAME = "collisions_replicator"
-DAG_OWNERS = Variable.get("dag_owners", deserialize_json=True).get(DAG_NAME, ["Unknown"])
-
-default_args = {
-    "owner": ",".join(DAG_OWNERS),
-    "depends_on_past": False,
-    "start_date": pendulum.datetime(2023, 10, 31, tz="America/Toronto"),
-    "email_on_failure": False,
-    "retries": 3,
-    "retry_delay": timedelta(minutes=60),
-    "on_failure_callback": task_fail_slack_alert,
-}
-
-@dag(
-    dag_id=DAG_NAME,
-    default_args=default_args,
-    catchup=False,
-    max_active_runs=5,
-    max_active_tasks=5,
-    schedule=None,
-    doc_md=__doc__,
-    tags=["collisions"]
-)
-def collisions_replicator():
-    """The main function of the collisions DAG."""
-    from dags.common_tasks import (
-        wait_for_external_trigger, get_variable, copy_table
-    )
-
-    # Returns a list of source and destination tables stored in the given
-    # Airflow variable.
-    tables = get_variable.override(task_id="get_list_of_tables")("collisions_tables")
-    # Waits for an external trigger
-    wait_for_external_trigger() >> tables
-    # Copies tables
-    copy_table.override(task_id="copy_tables").partial(conn_id="collisions_bot").expand(table=tables)
-
-collisions_replicator()
diff --git a/dags/common_tasks.py b/dags/common_tasks.py
index b85768191..fc45448f7 100644
--- a/dags/common_tasks.py
+++ b/dags/common_tasks.py
@@ -1,5 +1,5 @@
-from psycopg2 import sql
+from psycopg2 import sql, Error
 from typing import Tuple
 import logging
 # pylint: disable=import-error
@@ -47,11 +47,6 @@ def copy_table(conn_id:str, table:Tuple[str, str], **context) -> None:
             ``schema.table``, and the destination table in the same format
             ``schema.table``.
     """
-    # push an extra failure message to be sent to Slack in case of failing
-    context["task_instance"].xcom_push(
-        "extra_msg",
-        f"Failed to copy `{table[0]}` to `{table[1]}`."
-    )
     # separate tables and schemas
     try:
         src_schema, src_table = table[0].split(".")
@@ -94,25 +89,35 @@ def copy_table(conn_id:str, table:Tuple[str, str], **context) -> None:
         sql.Identifier(src_schema), sql.Identifier(src_table),
         sql.Identifier(dst_schema), sql.Identifier(dst_table),
     )
-    with con, con.cursor() as cur:
-        # truncate the destination table
-        cur.execute(truncate_query)
-        # get the column names of the source table
-        cur.execute(source_columns_query, [src_schema, src_table])
-        src_columns = [r[0] for r in cur.fetchall()]
-        # copy all the data
-        insert_query = sql.SQL(
-            "INSERT INTO {}.{} ({}) SELECT {} FROM {}.{}"
-        ).format(
-            sql.Identifier(dst_schema), sql.Identifier(dst_table),
-            sql.SQL(', ').join(map(sql.Identifier, src_columns)),
-            sql.SQL(', ').join(map(sql.Identifier, src_columns)),
-            sql.Identifier(src_schema), sql.Identifier(src_table)
-        )
-        cur.execute(insert_query)
-        # copy the table's comment
-        cur.execute(comment_query)
+    try:
+        with con, con.cursor() as cur:
+            # truncate the destination table
+            cur.execute(truncate_query)
+            # get the column names of the source table
+            cur.execute(source_columns_query, [src_schema, src_table])
+            src_columns = [r[0] for r in cur.fetchall()]
+            # copy all the data
+            insert_query = sql.SQL(
+                "INSERT INTO {}.{} ({}) SELECT {} FROM {}.{}"
+            ).format(
+                sql.Identifier(dst_schema), sql.Identifier(dst_table),
+                sql.SQL(', ').join(map(sql.Identifier, src_columns)),
+                sql.SQL(', ').join(map(sql.Identifier, src_columns)),
+                sql.Identifier(src_schema), sql.Identifier(src_table)
+            )
+            cur.execute(insert_query)
+            # copy the table's comment
+            cur.execute(comment_query)
+    #catch psycopg2 errors:
+    except Error as e:
+        # push an extra failure message to be sent to Slack in case of failing
+        context["task_instance"].xcom_push(
+            key="extra_msg",
+            value=f"Failed to copy `{table[0]}` to `{table[1]}`: `{str(e).strip()}`."
+        )
+        raise AirflowFailException(e)
+
     LOGGER.info(f"Successfully copied {table[0]} to {table[1]}.")
 
 @task.short_circuit(ignore_downstream_trigger_rules=False, retries=0) #only skip immediately downstream task
diff --git a/dags/counts_replicator.py b/dags/counts_replicator.py
deleted file mode 100644
index 4196cf443..000000000
--- a/dags/counts_replicator.py
+++ /dev/null
@@ -1,64 +0,0 @@
-#!/data/airflow/airflow_venv/bin/python3
-# -*- coding: utf-8 -*-
-# noqa: D415
-r"""### The Daily counts Replicator DAG
-
-This DAG runs daily to copy MOVE's counts tables from the ``move_staging``
-schema, which is updated by the MOVE's ``bigdata_replicator`` DAG, to the
-``traffic`` schema. This DAG runs only when it is triggered by the MOVE's
-DAG.
-""" -import os -import sys -from datetime import timedelta -from functools import partial -import pendulum -# pylint: disable=import-error -from airflow.decorators import dag -from airflow.models import Variable - -# import custom operators and helper functions -repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) -sys.path.insert(0, repo_path) -# pylint: disable=wrong-import-position -from dags.dag_functions import task_fail_slack_alert -# pylint: enable=import-error - -DAG_NAME = "counts_replicator" -DAG_OWNERS = Variable.get("dag_owners", deserialize_json=True).get(DAG_NAME, ["Unknown"]) - -default_args = { - "owner": ",".join(DAG_OWNERS), - "depends_on_past": False, - "start_date": pendulum.datetime(2023, 10, 31, tz="America/Toronto"), - "email_on_failure": False, - "retries": 3, - "retry_delay": timedelta(minutes=60), - "on_failure_callback": task_fail_slack_alert, -} - -@dag( - dag_id=DAG_NAME, - default_args=default_args, - catchup=False, - max_active_runs=5, - max_active_tasks=5, - schedule=None, - doc_md=__doc__, - tags=["counts"] -) -def counts_replicator(): - """The main function of the counts DAG.""" - from dags.common_tasks import ( - wait_for_external_trigger, get_variable, copy_table - ) - - # Returns a list of source and destination tables stored in the given - # Airflow variable. - tables = get_variable.override(task_id="get_list_of_tables")("counts_tables") - # Waits for an external trigger - wait_for_external_trigger() >> tables - # Copies tables - copy_table.override(task_id="copy_tables").partial(conn_id="traffic_bot").expand(table=tables) - -counts_replicator() diff --git a/dags/dag_functions.py b/dags/dag_functions.py index 54390833f..d2cc751d3 100644 --- a/dags/dag_functions.py +++ b/dags/dag_functions.py @@ -98,7 +98,7 @@ def task_fail_slack_alert( # in case of a string (or the default empty string) extra_msg_str = extra_msg - if isinstance(extra_msg_str, tuple): + if isinstance(extra_msg_str, tuple) or isinstance(extra_msg_str, list): #recursively collapse extra_msg_str's which are in the form of a list with new lines. extra_msg_str = '\n'.join( ['\n'.join(item) if isinstance(item, list) else item for item in extra_msg_str] @@ -152,3 +152,54 @@ def get_readme_docmd(readme_path, dag_name): doc_md_key = '' doc_md_regex = '(?<=' + doc_md_key + '\n)[\s\S]+(?=\n' + doc_md_key + ')' return re.findall(doc_md_regex, contents)[0] + +def send_slack_msg( + context: dict, + msg: str, + attachments: Optional[list] = None, + blocks: Optional[list] = None, + use_proxy: Optional[bool] = False, + dev_mode: Optional[bool] = None +) -> Any: + """Sends a message to Slack. + + Args: + context: The calling Airflow task's context. + msg : A string message be sent to Slack. + slack_conn_id: ID of the Airflow connection with the details of the + Slack channel to send messages to. + attachments: List of dictionaries representing Slack attachments. + blocks: List of dictionaries representing Slack blocks. + use_proxy: A boolean to indicate whether to use a proxy or not. Proxy + usage is required to make the Slack webhook call on on-premises + servers (default False). + dev_mode: A boolean to indicate if working in development mode to send + Slack alerts to data_pipeline_dev instead of the regular + data_pipeline (default None, to be determined based on the location + of the file). 
+    """
+    if dev_mode or (dev_mode is None and not is_prod_mode()):
+        SLACK_CONN_ID = "slack_data_pipeline_dev"
+    else:
+        SLACK_CONN_ID = "slack_data_pipeline"
+
+    if use_proxy:
+        # get the proxy credentials from the Airflow connection ``slack``. It
+        # contains username and password to set the proxy:
+        proxy=(
+            f"http://{BaseHook.get_connection('slack').password}"
+            f"@{json.loads(BaseHook.get_connection('slack').extra)['url']}"
+        )
+    else:
+        proxy = None
+
+    slack_alert = SlackWebhookOperator(
+        task_id="slack_test",
+        slack_webhook_conn_id=SLACK_CONN_ID,
+        message=msg,
+        username="airflow",
+        attachments=attachments,
+        blocks=blocks,
+        proxy=proxy,
+    )
+    return slack_alert.execute(context=context)
\ No newline at end of file
diff --git a/dags/replicators.py b/dags/replicators.py
new file mode 100644
index 000000000..ba47c2be8
--- /dev/null
+++ b/dags/replicators.py
@@ -0,0 +1,108 @@
+#!/data/airflow/airflow_venv/bin/python3
+# -*- coding: utf-8 -*-
+# noqa: D415
+import os
+import sys
+from datetime import timedelta
+import pendulum
+# pylint: disable=import-error
+from airflow.decorators import dag, task
+from airflow.models import Variable
+from airflow.exceptions import AirflowFailException
+
+# import custom operators and helper functions
+repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
+sys.path.insert(0, repo_path)
+# pylint: disable=wrong-import-position
+from dags.dag_functions import task_fail_slack_alert, send_slack_msg
+# pylint: enable=import-error
+
+def create_replicator_dag(dag_id, short_name, tables_var, conn, doc_md, default_args):
+    @dag(
+        dag_id=dag_id,
+        default_args=default_args,
+        catchup=False,
+        max_active_runs=5,
+        max_active_tasks=5,
+        schedule=None, #triggered externally
+        doc_md=doc_md,
+        tags=[short_name, "replicator"]
+    )
+    def replicator_DAG():
+        f"""The main function of the {short_name} DAG."""
+        from dags.common_tasks import (
+            wait_for_external_trigger, get_variable, copy_table
+        )
+
+        # Returns a list of source and destination tables stored in the given
+        # Airflow variable.
+        tables = get_variable.override(task_id="get_list_of_tables")(tables_var)
+
+        # Copies tables
+        copy_tables = copy_table.override(task_id="copy_tables", on_failure_callback=None).partial(conn_id=conn).expand(table=tables)
+
+        @task(
+            retries=0,
+            trigger_rule='all_done',
+            doc_md="""A status message to report DAG success OR any failures from the `copy_tables` task."""
+        )
+        def status_message(tables, **context):
+            ti = context["ti"]
+            failures = []
+            #iterate through mapped tasks to find any failure messages
+            for m_i in range(0, len(tables)):
+                failure_msg = ti.xcom_pull(key="extra_msg", task_ids="copy_tables", map_indexes=m_i)
+                if failure_msg is not None:
+                    failures.append(failure_msg)
+            if failures == []:
+                send_slack_msg(
+                    context=context,
+                    msg=f"`{dag_id}` DAG succeeded :white_check_mark:"
+                )
+            else: #add details of failures to task_fail_slack_alert
+                failure_extra_msg = ['One or more tables failed to copy:', failures]
+                context.get("task_instance").xcom_push(key="extra_msg", value=failure_extra_msg)
+                raise AirflowFailException('One or more tables failed to copy.')
+
+        # Waits for an external trigger
+        wait_for_external_trigger() >> tables >> copy_tables >> status_message(tables=tables)
+
+    generated_dag = replicator_DAG()
+
+    return generated_dag
+
+#get replicator details from airflow variable
+REPLICATORS = Variable.get('replicators', deserialize_json=True)
+
+#generate replicator DAGs from dict
+for replicator, dag_items in REPLICATORS.items():
+    DAG_NAME = dag_items['dag_name']
+    DAG_OWNERS = Variable.get("dag_owners", deserialize_json=True).get(DAG_NAME, ["Unknown"])
+
+    default_args = {
+        "owner": ",".join(DAG_OWNERS),
+        "depends_on_past": False,
+        "start_date": pendulum.datetime(2023, 10, 31, tz="America/Toronto"),
+        "email_on_failure": False,
+        "retries": 3,
+        "retry_delay": timedelta(minutes=60),
+        "on_failure_callback": task_fail_slack_alert,
+    }
+
+    doc_md = f"""### The Daily {replicator} Replicator DAG
+
+    This DAG runs daily to copy MOVE's {replicator} tables from the ``move_staging``
+    schema, which is updated by the MOVE's ``bigdata_replicator`` DAG, to the
+    ``{replicator}`` schema. This DAG runs only when it is triggered by the MOVE's
+    DAG."""
+
+    globals()[DAG_NAME] = (
+        create_replicator_dag(
+            dag_id=DAG_NAME,
+            short_name=replicator,
+            tables_var=dag_items['tables'],
+            conn=dag_items['conn'],
+            doc_md=doc_md,
+            default_args=default_args
+        )
+    )
\ No newline at end of file
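Reviewer note: the generated DAGs above are driven by a ``replicators`` Airflow Variable whose actual contents are not part of this diff. Based on the keys read in ``dags/replicators.py`` (``dag_name``, ``tables``, ``conn``) and the values hard-coded in the two deleted DAGs, the variable presumably looks roughly like the sketch below; the concrete values are illustrative only, not the production configuration.

# Illustrative sketch (assumed, not part of the diff) of the ``replicators``
# Airflow Variable consumed by dags/replicators.py. Key names come from the
# code above; values are inferred from the deleted collisions/counts DAGs.
replicators_example = {
    "collisions": {
        "dag_name": "collisions_replicator",  # becomes the generated dag_id
        "tables": "collisions_tables",        # Variable listing [source, destination] table pairs
        "conn": "collisions_bot",             # conn_id passed to copy_table
    },
    "counts": {
        "dag_name": "counts_replicator",
        "tables": "counts_tables",
        "conn": "traffic_bot",
    },
}
# It could be loaded with, e.g., Variable.set("replicators", replicators_example, serialize_json=True).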