Commit

Refactors dbt-running functionality into a function, updates taskflows, and addresses the dataset change that highlighted the issue.
MattTriano committed Nov 14, 2024
1 parent cecfb52 commit ff2e9e6
Showing 6 changed files with 156 additions and 120 deletions.
29 changes: 29 additions & 0 deletions airflow/dags/cc_utils/transform.py
@@ -0,0 +1,29 @@
from logging import Logger
import subprocess
import re

from cc_utils.utils import log_as_info


class DbtExecutionError(Exception):
def __init__(self, msg):
super().__init__(msg)


def run_dbt_dataset_transformations(
dataset_name: str, task_logger: Logger, schema: str = "data_raw", dbt_project: str = "re_dbt"
) -> bool:
dbt_cmd = f"""cd /opt/airflow/dbt && \
dbt --warn-error-options \
'{{"include": "all", "exclude": [UnusedResourceConfigPath]}}' \
run --select {dbt_project}.{schema}.{dataset_name}"""
log_as_info(task_logger, f"dbt run command: {dbt_cmd}")
subproc_output = subprocess.run(dbt_cmd, shell=True, capture_output=True, text=True)
raise_exception = False
for el in subproc_output.stdout.split("\n"):
log_as_info(task_logger, f"{el}")
if re.search("(\\d* of \\d* ERROR)", el):
raise DbtExecutionError(
msg=f"dbt model failed. Review the dbt output.\n{subproc_output.stdout}",
)
return True
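
To make the failure handling concrete: the helper scans each line of dbt's stdout and raises DbtExecutionError on the first line matching the "N of M ERROR" pattern. A small self-contained sketch of that detection logic follows (the dbt output lines and model names below are invented for illustration, not captured from a real run):

import re

# Invented dbt stdout lines, for illustration only.
sample_stdout = (
    "1 of 2 OK created sql table model data_raw.example_a ........ [SUCCESS in 1.2s]\n"
    "2 of 2 ERROR creating sql table model data_raw.example_b .... [ERROR in 0.4s]\n"
    "Done. PASS=1 WARN=0 ERROR=1 SKIP=0 TOTAL=2"
)

for el in sample_stdout.split("\n"):
    if re.search("(\\d* of \\d* ERROR)", el):
        # This is the condition under which run_dbt_dataset_transformations
        # raises DbtExecutionError, embedding dbt's full stdout in the message.
        print(f"would raise DbtExecutionError on: {el}")
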
122 changes: 61 additions & 61 deletions airflow/dags/tasks/census_tasks.py
@@ -29,6 +29,8 @@
from cc_utils.file_factory import (
make_dbt_data_raw_model_file,
)
from cc_utils.transform import run_dbt_dataset_transformations
from cc_utils.utils import log_as_info
from cc_utils.validation import (
run_checkpoint,
check_if_checkpoint_exists,
@@ -57,8 +59,9 @@ def ingest_api_dataset_freshness_check(
ingested_api_datasets_df = execute_result_returning_orm_query(
engine=engine, select_query=insert_statement
)
task_logger.info(
f"Max census_api_metadata id value after ingestion: {ingested_api_datasets_df['id'].max()}"
log_as_info(
task_logger,
f"Max census_api_metadata id value after ingestion: {ingested_api_datasets_df['id'].max()}",
)
return ingested_api_datasets_df

@@ -69,11 +72,11 @@ def get_source_dataset_metadata(
) -> CensusAPIDatasetSource:
dataset_base_url = census_dataset.api_call_obj.dataset_base_url
dataset_source = CensusAPIDatasetSource(dataset_base_url=dataset_base_url)
task_logger.info(f"Census dataset {dataset_base_url} metadata details:")
task_logger.info(f" - Dataset variables: {len(dataset_source.variables_df)}")
task_logger.info(f" - Dataset geographies: {len(dataset_source.geographies_df)}")
task_logger.info(f" - Dataset groups: {len(dataset_source.groups_df)}")
task_logger.info(f" - Dataset tags: {len(dataset_source.tags_df)}")
log_as_info(task_logger, f"Census dataset {dataset_base_url} metadata details:")
log_as_info(task_logger, f" - Dataset variables: {len(dataset_source.variables_df)}")
log_as_info(task_logger, f" - Dataset geographies: {len(dataset_source.geographies_df)}")
log_as_info(task_logger, f" - Dataset groups: {len(dataset_source.groups_df)}")
log_as_info(task_logger, f" - Dataset tags: {len(dataset_source.tags_df)}")
return dataset_source


@@ -95,11 +98,11 @@ def record_source_freshness_check(
"time_of_check": [dataset_source.time_of_check],
}
)
task_logger.info(f"dataset name: {freshness_check_record['dataset_name']}")
task_logger.info(
f"dataset last_modified: {freshness_check_record['source_data_last_modified']}"
log_as_info(task_logger, f"dataset name: {freshness_check_record['dataset_name']}")
log_as_info(
task_logger, f"dataset last_modified: {freshness_check_record['source_data_last_modified']}"
)
task_logger.info(f"dataset time_of_check: {freshness_check_record['time_of_check']}")
log_as_info(task_logger, f"dataset time_of_check: {freshness_check_record['time_of_check']}")

engine = get_pg_engine(conn_id=conn_id)
metadata_table = get_reflected_db_table(
@@ -144,8 +147,9 @@ def get_latest_local_freshness_check(
)
latest_dataset_check_df = latest_dataset_check_df.drop(columns=["rn"])
source_last_modified = latest_dataset_check_df["source_data_last_modified"].max()
task_logger.info(
f"Last_modified datetime of latest local dataset update: {source_last_modified}."
log_as_info(
task_logger,
f"Last_modified datetime of latest local dataset update: {source_last_modified}.",
)
return latest_dataset_check_df

@@ -195,13 +199,13 @@ def fresher_source_data_available(
) -> str:
dataset_in_local_dwh = len(freshness_check.local_freshness) > 0

task_logger.info(f"Dataset in local dwh: {dataset_in_local_dwh}")
task_logger.info(f"freshness_check.local_freshness: {freshness_check.local_freshness}")
log_as_info(task_logger, f"Dataset in local dwh: {dataset_in_local_dwh}")
log_as_info(task_logger, f"freshness_check.local_freshness: {freshness_check.local_freshness}")
if dataset_in_local_dwh:
local_last_modified = freshness_check.local_freshness["source_data_last_modified"].max()
task_logger.info(f"Local dataset last modified: {local_last_modified}")
log_as_info(task_logger, f"Local dataset last modified: {local_last_modified}")
source_last_modified = freshness_check.source_freshness["source_data_last_modified"].max()
task_logger.info(f"Source dataset last modified: {source_last_modified}")
log_as_info(task_logger, f"Source dataset last modified: {source_last_modified}")
local_dataset_is_fresh = local_last_modified >= source_last_modified
if local_dataset_is_fresh:
return "update_census_table.local_data_is_fresh"
@@ -223,10 +227,10 @@ def ingest_dataset_metadata(
) -> str:
metadata_df = freshness_check.dataset_source.metadata_catalog_df.copy()
metadata_df["time_of_check"] = freshness_check.source_freshness["time_of_check"].max()
task_logger.info(f"Dataset metadata columns:")
log_as_info(task_logger, f"Dataset metadata columns:")
for col in metadata_df.columns:
task_logger.info(f" {col}")
task_logger.info(f"{metadata_df.T}")
log_as_info(task_logger, f" {col}")
log_as_info(task_logger, f"{metadata_df.T}")
engine = get_pg_engine(conn_id=conn_id)
metadata_table = get_reflected_db_table(
engine=engine, table_name="census_api_dataset_metadata", schema_name="metadata"
@@ -261,7 +265,7 @@ def ingest_dataset_variables_metadata(
.returning(metadata_table)
)
ingested_df = execute_result_returning_orm_query(engine=engine, select_query=insert_statement)
task_logger.info(f"Variables ingested: {len(ingested_df)}")
log_as_info(task_logger, f"Variables ingested: {len(ingested_df)}")
return "success"


@@ -286,7 +290,7 @@ def ingest_dataset_geographies_metadata(
.returning(metadata_table)
)
ingested_df = execute_result_returning_orm_query(engine=engine, select_query=insert_statement)
task_logger.info(f"Geographies ingested: {len(ingested_df)}")
log_as_info(task_logger, f"Geographies ingested: {len(ingested_df)}")
return "success"


@@ -309,7 +313,7 @@ def ingest_dataset_groups_metadata(
insert(metadata_table).values(groups_df.to_dict(orient="records")).returning(metadata_table)
)
ingested_df = execute_result_returning_orm_query(engine=engine, select_query=insert_statement)
task_logger.info(f"Groups ingested: {len(ingested_df)}")
log_as_info(task_logger, f"Groups ingested: {len(ingested_df)}")
return "success"


@@ -332,7 +336,7 @@ def ingest_dataset_tags_metadata(
insert(metadata_table).values(tags_df.to_dict(orient="records")).returning(metadata_table)
)
ingested_df = execute_result_returning_orm_query(engine=engine, select_query=insert_statement)
task_logger.info(f"Tags ingested: {len(ingested_df)}")
log_as_info(task_logger, f"Tags ingested: {len(ingested_df)}")
return "success"


@@ -390,18 +394,18 @@ def request_and_ingest_dataset(
engine = get_pg_engine(conn_id=conn_id)

dataset_df = census_dataset.api_call_obj.make_api_call()
task_logger.info(f"Rows in returned dataset: {len(dataset_df)}")
task_logger.info(f"Columns in returned dataset: {dataset_df.columns}")
log_as_info(task_logger, f"Rows in returned dataset: {len(dataset_df)}")
log_as_info(task_logger, f"Columns in returned dataset: {dataset_df.columns}")
dataset_df["dataset_base_url"] = census_dataset.api_call_obj.dataset_base_url
dataset_df["dataset_id"] = freshness_check.source_freshness["id"].max()
dataset_df["source_data_updated"] = freshness_check.source_freshness[
"source_data_last_modified"
].max()
dataset_df["ingestion_check_time"] = freshness_check.source_freshness["time_of_check"].max()
task_logger.info(f"""dataset_id: {dataset_df["dataset_id"]}""")
task_logger.info(f"""dataset_base_url: {dataset_df["dataset_base_url"]}""")
task_logger.info(f"""source_data_updated: {dataset_df["source_data_updated"]}""")
task_logger.info(f"""ingestion_check_time: {dataset_df["ingestion_check_time"]}""")
log_as_info(task_logger, f"""dataset_id: {dataset_df["dataset_id"]}""")
log_as_info(task_logger, f"""dataset_base_url: {dataset_df["dataset_base_url"]}""")
log_as_info(task_logger, f"""source_data_updated: {dataset_df["source_data_updated"]}""")
log_as_info(task_logger, f"""ingestion_check_time: {dataset_df["ingestion_check_time"]}""")
dataset_df = standardize_column_names(df=dataset_df)
result = dataset_df.to_sql(
name=f"temp_{census_dataset.dataset_name}",
@@ -411,7 +415,7 @@
if_exists="replace",
chunksize=100000,
)
task_logger.info(f"Ingestion result: {result}")
log_as_info(task_logger, f"Ingestion result: {result}")
if result is not None:
return "ingested"
else:
@@ -431,7 +435,7 @@ def record_data_update(conn_id: str, task_logger: Logger, **kwargs) -> str:
engine=engine,
query=f"""SELECT * FROM metadata.dataset_metadata WHERE id = {dataset_id}""",
)
task_logger.info(f"General metadata record pre-update: {pre_update_record}")
log_as_info(task_logger, f"General metadata record pre-update: {pre_update_record}")
metadata_table = get_reflected_db_table(
engine=engine, table_name="dataset_metadata", schema_name="metadata"
)
@@ -441,12 +445,12 @@ def record_data_update(conn_id: str, task_logger: Logger, **kwargs) -> str:
.values(local_data_updated=True)
)
execute_dml_orm_query(engine=engine, dml_stmt=update_query, logger=task_logger)
task_logger.info(f"dataset_id: {dataset_id}")
log_as_info(task_logger, f"dataset_id: {dataset_id}")
post_update_record = execute_result_returning_query(
engine=engine,
query=f"""SELECT * FROM metadata.dataset_metadata WHERE id = {dataset_id}""",
)
task_logger.info(f"General metadata record post-update: {post_update_record}")
log_as_info(task_logger, f"General metadata record post-update: {post_update_record}")
return "success"


@@ -469,10 +473,12 @@ def register_temp_table_asset(
def table_checkpoint_exists(census_dataset: CensusVariableGroupDataset, task_logger: Logger) -> str:
checkpoint_name = f"data_raw.temp_{census_dataset.dataset_name}"
if check_if_checkpoint_exists(checkpoint_name=checkpoint_name, task_logger=task_logger):
task_logger.info(f"GE checkpoint for {checkpoint_name} exists")
log_as_info(task_logger, f"GE checkpoint for {checkpoint_name} exists")
return "update_census_table.raw_data_validation_tg.run_temp_table_checkpoint"
else:
task_logger.info(f"GE checkpoint for {checkpoint_name} doesn't exist yet. Make it maybe?")
log_as_info(
task_logger, f"GE checkpoint for {checkpoint_name} doesn't exist yet. Make it maybe?"
)
return "update_census_table.raw_data_validation_tg.validation_endpoint"


@@ -484,11 +490,12 @@ def run_temp_table_checkpoint(
checkpoint_run_results = run_checkpoint(
checkpoint_name=checkpoint_name, task_logger=task_logger
)
task_logger.info(
f"list_validation_results: {checkpoint_run_results.list_validation_results()}"
log_as_info(
task_logger,
f"list_validation_results: {checkpoint_run_results.list_validation_results()}",
)
task_logger.info(f"validation success: {checkpoint_run_results.success}")
task_logger.info(f"dir(checkpoint_run_results): {dir(checkpoint_run_results)}")
log_as_info(task_logger, f"validation success: {checkpoint_run_results.success}")
log_as_info(task_logger, f"dir(checkpoint_run_results): {dir(checkpoint_run_results)}")
return checkpoint_run_results


@@ -503,7 +510,7 @@ def raw_data_validation_tg(
datasource_name: str,
task_logger: Logger,
):
task_logger.info(f"Entered raw_data_validation_tg task_group")
log_as_info(task_logger, f"Entered raw_data_validation_tg task_group")
register_asset_1 = register_temp_table_asset(
census_dataset=census_dataset, datasource_name=datasource_name, task_logger=task_logger
)
@@ -528,12 +535,12 @@ def table_exists_in_data_raw(
tables_in_data_raw_schema = get_data_table_names_in_schema(
engine=get_pg_engine(conn_id=conn_id), schema_name="data_raw"
)
task_logger.info(f"tables_in_data_raw_schema: {tables_in_data_raw_schema}")
log_as_info(task_logger, f"tables_in_data_raw_schema: {tables_in_data_raw_schema}")
if census_dataset.dataset_name not in tables_in_data_raw_schema:
task_logger.info(f"Table {census_dataset.dataset_name} not in data_raw; creating.")
log_as_info(task_logger, f"Table {census_dataset.dataset_name} not in data_raw; creating.")
return "update_census_table.persist_new_raw_data_tg.create_table_in_data_raw"
else:
task_logger.info(f"Table {census_dataset.dataset_name} in data_raw; skipping.")
log_as_info(task_logger, f"Table {census_dataset.dataset_name} in data_raw; skipping.")
return "update_census_table.persist_new_raw_data_tg.dbt_data_raw_model_exists"


@@ -543,7 +550,7 @@ def create_table_in_data_raw(
) -> str:
try:
table_name = census_dataset.dataset_name
task_logger.info(f"Creating table data_raw.{table_name}")
log_as_info(task_logger, f"Creating table data_raw.{table_name}")
postgres_hook = PostgresHook(postgres_conn_id=conn_id)
conn = postgres_hook.get_conn()
cur = conn.cursor()
@@ -563,8 +570,8 @@ def dbt_data_raw_model_exists(
census_dataset: CensusVariableGroupDataset, task_logger: Logger
) -> str:
dbt_data_raw_model_dir = Path(f"/opt/airflow/dbt/models/data_raw")
task_logger.info(f"dbt data_raw model dir ('{dbt_data_raw_model_dir}')")
task_logger.info(f"Dir exists? {dbt_data_raw_model_dir.is_dir()}")
log_as_info(task_logger, f"dbt data_raw model dir ('{dbt_data_raw_model_dir}')")
log_as_info(task_logger, f"Dir exists? {dbt_data_raw_model_dir.is_dir()}")
table_model_path = dbt_data_raw_model_dir.joinpath(f"{census_dataset.dataset_name}.sql")
if table_model_path.is_file():
return "update_census_table.persist_new_raw_data_tg.update_data_raw_table"
@@ -579,25 +586,18 @@ def make_dbt_data_raw_model(
make_dbt_data_raw_model_file(
table_name=census_dataset.dataset_name, engine=get_pg_engine(conn_id=conn_id)
)
task_logger.info(f"Leaving make_dbt_data_raw_model")
log_as_info(task_logger, f"Leaving make_dbt_data_raw_model")
return "dbt_file_made"


@task(trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS)
def update_data_raw_table(census_dataset: CensusVariableGroupDataset, task_logger: Logger) -> str:
dbt_cmd = f"""cd /opt/airflow/dbt && \
dbt --warn-error-options \
'{{"include": "all", "exclude": [UnusedResourceConfigPath]}}' \
run --select re_dbt.data_raw.{census_dataset.dataset_name}"""
task_logger.info(f"dbt run command: {dbt_cmd}")
subproc_output = subprocess.run(dbt_cmd, shell=True, capture_output=True, text=True)
raise_exception = False
for el in subproc_output.stdout.split("\n"):
task_logger.info(f"{el}")
if re.search("(\\d* of \\d* ERROR)", el):
raise_exception = True
if raise_exception:
raise Exception("dbt model failed. Review the above outputs")
result = run_dbt_dataset_transformations(
dataset_name=census_dataset.dataset_name,
task_logger=task_logger,
schema="data_raw",
)
log_as_info(task_logger, f"dbt transform result: {result}")
return "data_raw_updated"


19 changes: 9 additions & 10 deletions airflow/dags/tasks/socrata_tasks.py
@@ -24,6 +24,7 @@
format_dbt_stub_for_clean_stage,
)
from cc_utils.socrata import SocrataTable, SocrataTableMetadata
from cc_utils.transform import run_dbt_dataset_transformations
from cc_utils.utils import (
get_local_data_raw_dir,
get_lines_in_geojson_file,
@@ -563,20 +564,18 @@ def make_dbt_data_raw_model(conn_id: str, task_logger: Logger, **kwargs) -> Socr


@task(trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS)
def update_data_raw_table(task_logger: Logger, **kwargs) -> SocrataTableMetadata:
def update_data_raw_table(task_logger: Logger, **kwargs) -> str:
ti = kwargs["ti"]
socrata_metadata = ti.xcom_pull(
task_ids="update_socrata_table.raw_data_validation_tg.validation_endpoint"
)
dbt_cmd = f"""cd /opt/airflow/dbt && \
dbt --warn-error-options \
'{{"include": "all", "exclude": [UnusedResourceConfigPath]}}' \
run --select re_dbt.data_raw.{socrata_metadata.table_name}"""
log_as_info(task_logger, f"dbt run command: {dbt_cmd}")
subproc_output = subprocess.run(dbt_cmd, shell=True, capture_output=True, text=True)
for el in subproc_output.stdout.split("\n"):
log_as_info(task_logger, f"{el}")
return socrata_metadata
result = run_dbt_dataset_transformations(
dataset_name=socrata_metadata.table_name,
task_logger=task_logger,
schema="data_raw",
)
log_as_info(task_logger, f"dbt transform result: {result}")
return "data_raw_updated"


@task(trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS)
(Diffs for the remaining changed files did not load in this view.)
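
Pulling the refactor together: both task modules now delegate to the shared helper. A minimal sketch of the intended call pattern from a TaskFlow task follows (the task name and dataset name are hypothetical; the imports and trigger rule mirror the diffs above):

from logging import Logger

from airflow.decorators import task
from airflow.utils.trigger_rule import TriggerRule

from cc_utils.transform import run_dbt_dataset_transformations
from cc_utils.utils import log_as_info


@task(trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS)
def update_example_table(task_logger: Logger) -> str:
    # Hypothetical task: runs the dbt model re_dbt.data_raw.example_dataset.
    # If dbt reports an "N of M ERROR" line, the helper raises DbtExecutionError,
    # which fails this task and includes dbt's stdout in the error message.
    result = run_dbt_dataset_transformations(
        dataset_name="example_dataset",
        task_logger=task_logger,
        schema="data_raw",
    )
    log_as_info(task_logger, f"dbt transform result: {result}")
    return "data_raw_updated"
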
