Add helper functions for uploading target directory artifacts to remote cloud storages
pankajkoti committed Dec 13, 2024
1 parent c66681c commit 27ef798
Showing 4 changed files with 188 additions and 2 deletions.
161 changes: 161 additions & 0 deletions cosmos/helpers.py
@@ -0,0 +1,161 @@
from __future__ import annotations

import os
from pathlib import Path
from typing import Any
from urllib.parse import urlparse

from cosmos import settings
from cosmos.constants import FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP
from cosmos.exceptions import CosmosValueError
from cosmos.settings import remote_target_path, remote_target_path_conn_id


def upload_artifacts_to_aws_s3(project_dir: str, **kwargs: Any) -> None:
    from airflow.providers.amazon.aws.hooks.s3 import S3Hook

    target_dir = f"{project_dir}/target"
    aws_conn_id = kwargs.get("aws_conn_id", S3Hook.default_conn_name)
    bucket_name = kwargs["bucket_name"]
    hook = S3Hook(aws_conn_id=aws_conn_id)

    # Iterate over the files in the target dir and upload them to S3
    for dirpath, _, filenames in os.walk(target_dir):
        for filename in filenames:
            s3_key = (
                f"{kwargs['dag'].dag_id}"
                f"/{kwargs['run_id']}"
                f"/{kwargs['task_instance'].task_id}"
                f"/{kwargs['task_instance']._try_number}"
                f"{dirpath.split(project_dir)[-1]}/{filename}"
            )
            hook.load_file(
                filename=f"{dirpath}/{filename}",
                bucket_name=bucket_name,
                key=s3_key,
                replace=True,
            )


def upload_artifacts_to_gcp_gs(project_dir: str, **kwargs: Any) -> None:
    from airflow.providers.google.cloud.hooks.gcs import GCSHook

    target_dir = f"{project_dir}/target"
    gcp_conn_id = kwargs.get("gcp_conn_id", GCSHook.default_conn_name)
    bucket_name = kwargs["bucket_name"]
    hook = GCSHook(gcp_conn_id=gcp_conn_id)

    # Iterate over the files in the target dir and upload them to GCP GS
    for dirpath, _, filenames in os.walk(target_dir):
        for filename in filenames:
            object_name = (
                f"{kwargs['dag'].dag_id}"
                f"/{kwargs['run_id']}"
                f"/{kwargs['task_instance'].task_id}"
                f"/{kwargs['task_instance']._try_number}"
                f"{dirpath.split(project_dir)[-1]}/{filename}"
            )
            hook.upload(
                filename=f"{dirpath}/{filename}",
                bucket_name=bucket_name,
                object_name=object_name,
            )


def upload_artifacts_to_azure_wasb(project_dir: str, **kwargs: Any) -> None:
    from airflow.providers.microsoft.azure.hooks.wasb import WasbHook

    target_dir = f"{project_dir}/target"
    azure_conn_id = kwargs.get("azure_conn_id", WasbHook.default_conn_name)
    container_name = kwargs["container_name"]
    hook = WasbHook(wasb_conn_id=azure_conn_id)

    # Iterate over the files in the target dir and upload them to WASB container
    for dirpath, _, filenames in os.walk(target_dir):
        for filename in filenames:
            blob_name = (
                f"{kwargs['dag'].dag_id}"
                f"/{kwargs['run_id']}"
                f"/{kwargs['task_instance'].task_id}"
                f"/{kwargs['task_instance']._try_number}"
                f"{dirpath.split(project_dir)[-1]}/{filename}"
            )
            hook.load_file(
                file_path=f"{dirpath}/{filename}",
                container_name=container_name,
                blob_name=blob_name,
                overwrite=True,
            )


def _configure_remote_target_path() -> tuple[Path, str] | tuple[None, None]:
    """Configure the remote target path if it is provided."""
    from airflow.version import version as airflow_version

    if not remote_target_path:
        return None, None

    _configured_target_path = None

    target_path_str = str(remote_target_path)

    remote_conn_id = remote_target_path_conn_id
    if not remote_conn_id:
        target_path_schema = urlparse(target_path_str).scheme
        remote_conn_id = FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP.get(target_path_schema, None)  # type: ignore[assignment]
    if remote_conn_id is None:
        return None, None

    if not settings.AIRFLOW_IO_AVAILABLE:
        raise CosmosValueError(
            f"You're trying to specify remote target path {target_path_str}, but the required "
            f"Object Storage feature is unavailable in Airflow version {airflow_version}. Please upgrade to "
            "Airflow 2.8 or later."
        )

    from airflow.io.path import ObjectStoragePath

    _configured_target_path = ObjectStoragePath(target_path_str, conn_id=remote_conn_id)

    if not _configured_target_path.exists():  # type: ignore[no-untyped-call]
        _configured_target_path.mkdir(parents=True, exist_ok=True)

    return _configured_target_path, remote_conn_id


def _construct_dest_file_path(
    dest_target_dir: Path,
    file_path: str,
    source_target_dir: Path,
    **kwargs: Any,
) -> str:
    """
    Construct the destination path for the artifact files to be uploaded to the remote store.
    """
    dest_target_dir_str = str(dest_target_dir).rstrip("/")

    task_run_identifier = (
        f"{kwargs['dag'].dag_id}"
        f"/{kwargs['run_id']}"
        f"/{kwargs['task_instance'].task_id}"
        f"/{kwargs['task_instance']._try_number}"
    )
    rel_path = os.path.relpath(file_path, source_target_dir).lstrip("/")

    return f"{dest_target_dir_str}/{task_run_identifier}/target/{rel_path}"


def upload_artifacts_to_cloud_storage(project_dir: str, **kwargs: Any) -> None:
    dest_target_dir, dest_conn_id = _configure_remote_target_path()

    if not dest_target_dir:
        raise CosmosValueError("You're trying to upload artifact files, but the remote target path is not configured.")

    from airflow.io.path import ObjectStoragePath

    source_target_dir = Path(project_dir) / "target"
    files = [str(file) for file in source_target_dir.rglob("*") if file.is_file()]
    for file_path in files:
        dest_file_path = _construct_dest_file_path(dest_target_dir, file_path, source_target_dir, **kwargs)
        dest_object_storage_path = ObjectStoragePath(dest_file_path, conn_id=dest_conn_id)
        ObjectStoragePath(file_path).copy(dest_object_storage_path)
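
Note: upload_artifacts_to_cloud_storage relies on the Cosmos settings remote_target_path and remote_target_path_conn_id (exposed in cosmos.settings and typically set via the [cosmos] section of the Airflow configuration) and requires Airflow 2.8+ for ObjectStoragePath. As a minimal sketch that is not part of the commit, the snippet below stubs the Airflow context values to show the destination layout produced by _construct_dest_file_path; the /mnt/artifacts directory stands in for a configured remote target path, and the DAG, run, and task names are made up:

from pathlib import Path
from types import SimpleNamespace

from cosmos.helpers import _construct_dest_file_path

# Stand-ins for the Airflow context entries that the operator merges into callback_args
fake_context = {
    "dag": SimpleNamespace(dag_id="basic_cosmos_dag"),
    "run_id": "manual__2024-12-13T00:00:00+00:00",
    "task_instance": SimpleNamespace(task_id="stg_customers_run", _try_number=1),
}

dest = _construct_dest_file_path(
    dest_target_dir=Path("/mnt/artifacts"),  # placeholder for the configured remote target path
    file_path="/tmp/project/target/run_results.json",
    source_target_dir=Path("/tmp/project/target"),
    **fake_context,
)
print(dest)
# /mnt/artifacts/basic_cosmos_dag/manual__2024-12-13T00:00:00+00:00/stg_customers_run/1/target/run_results.json

The uploaded artifacts therefore land under <remote target path>/<dag_id>/<run_id>/<task_id>/<try_number>/target/, mirroring the layout of the local target directory.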
7 changes: 5 additions & 2 deletions cosmos/operators/local.py
@@ -141,6 +141,7 @@ def __init__(
        invocation_mode: InvocationMode | None = None,
        install_deps: bool = False,
        callback: Callable[[str], None] | None = None,
        callback_args: dict[str, Any] | None = None,
        should_store_compiled_sql: bool = True,
        should_upload_compiled_sql: bool = False,
        append_env: bool = True,
@@ -149,6 +150,7 @@
        self.task_id = task_id
        self.profile_config = profile_config
        self.callback = callback
        self.callback_args = callback_args or {}
        self.compiled_sql = ""
        self.freshness = ""
        self.should_store_compiled_sql = should_store_compiled_sql
@@ -500,9 +502,10 @@ def run_command(
            self.store_freshness_json(tmp_project_dir, context)
            self.store_compiled_sql(tmp_project_dir, context)
            self.upload_compiled_sql(tmp_project_dir, context)
            self.handle_exception(result)
            if self.callback:
                self.callback(tmp_project_dir)
                self.callback_args.update(context)
                self.callback(tmp_project_dir, **self.callback_args)
            self.handle_exception(result)

            return result

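With this change the callback receives the temporary project directory plus callback_args merged with the Airflow task context, and it now runs before handle_exception, so artifacts can be uploaded even when the dbt command failed. A minimal sketch of a user-supplied callback compatible with this wiring (the function name and printed message are illustrative, not part of the commit):

from typing import Any


def my_artifact_callback(project_dir: str, **kwargs: Any) -> None:
    # project_dir is the temporary dbt project dir created by the operator;
    # kwargs holds the user-provided callback_args merged with the Airflow
    # context, so keys such as "dag", "run_id" and "task_instance" are available.
    task_id = kwargs["task_instance"].task_id
    print(f"dbt artifacts for {task_id} are under {project_dir}/target")

Passing callback=my_artifact_callback (optionally with callback_args) to any of the local operators hooks it into the same code path as the helpers added in cosmos/helpers.py.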
19 changes: 19 additions & 0 deletions dev/dags/basic_cosmos_dag.py
@@ -7,6 +7,9 @@
from pathlib import Path

from cosmos import DbtDag, ProfileConfig, ProjectConfig
from cosmos.helpers import (
    upload_artifacts_to_cloud_storage,
)
from cosmos.profiles import PostgresUserPasswordProfileMapping

DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt"
@@ -32,6 +35,22 @@
    operator_args={
        "install_deps": True,  # install any necessary dependencies before running any dbt command
        "full_refresh": True,  # used only in dbt commands that support this flag
        # --------------------------------------------------------------
        # Callback function to upload artifacts using Airflow Object storage and Cosmos remote_target_path setting on Airflow 2.8 and above
        "callback": upload_artifacts_to_cloud_storage,
        # --------------------------------------------------------------
        # Callback function to upload artifacts to AWS S3 for Airflow < 2.8
        # "callback": upload_artifacts_to_aws_s3,
        # "callback_args": {"aws_conn_id": "aws_s3_conn", "bucket_name": "cosmos-artifacts-upload"},
        # --------------------------------------------------------------
        # Callback function to upload artifacts to GCP GS for Airflow < 2.8
        # "callback": upload_artifacts_to_gcp_gs,
        # "callback_args": {"gcp_conn_id": "gcp_gs_conn", "bucket_name": "cosmos-artifacts-upload"},
        # --------------------------------------------------------------
        # Callback function to upload artifacts to Azure WASB for Airflow < 2.8
        # "callback": upload_artifacts_to_azure_wasb,
        # "callback_args": {"azure_conn_id": "azure_wasb_conn", "container_name": "cosmos-artifacts-upload"},
        # --------------------------------------------------------------
    },
    # normal dag parameters
    schedule_interval="@daily",
3 changes: 3 additions & 0 deletions dev/dags/example_operators.py
@@ -5,6 +5,7 @@
from airflow import DAG

from cosmos import DbtCloneLocalOperator, DbtRunLocalOperator, DbtSeedLocalOperator, ProfileConfig
from cosmos.helpers import upload_artifacts_to_aws_s3

DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt"
DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH))
@@ -26,6 +27,8 @@
        dbt_cmd_flags=["--select", "raw_customers"],
        install_deps=True,
        append_env=True,
        callback=upload_artifacts_to_aws_s3,
        callback_args={"aws_conn_id": "aws_s3_conn", "bucket_name": "cosmos-artifacts-upload"},
    )
    run_operator = DbtRunLocalOperator(
        profile_config=profile_config,
