diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ede09b388..a36c1b26b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -57,7 +57,7 @@ repos: - --py37-plus - --keep-runtime-typing - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.2 + rev: v0.8.3 hooks: - id: ruff args: diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index 6605bf20d..3e3103266 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -99,6 +99,7 @@ def create_test_task_metadata( extra_context = {} task_owner = "" + airflow_task_config = {} if test_indirect_selection != TestIndirectSelection.EAGER: task_args["indirect_selection"] = test_indirect_selection.value if node is not None: @@ -111,6 +112,7 @@ def create_test_task_metadata( extra_context = {"dbt_node_config": node.context_dict} task_owner = node.owner + airflow_task_config = node.airflow_task_config elif render_config is not None: # TestBehavior.AFTER_ALL task_args["select"] = render_config.select @@ -120,6 +122,7 @@ def create_test_task_metadata( return TaskMetadata( id=test_task_name, owner=task_owner, + airflow_task_config=airflow_task_config, operator_class=calculate_operator_class( execution_mode=execution_mode, dbt_class="DbtTest", @@ -214,6 +217,7 @@ def create_task_metadata( task_metadata = TaskMetadata( id=task_id, owner=node.owner, + airflow_task_config=node.airflow_task_config, operator_class=calculate_operator_class( execution_mode=execution_mode, dbt_class=dbt_resource_to_class[node.resource_type] ), diff --git a/cosmos/constants.py b/cosmos/constants.py index b45170445..8378e8d10 100644 --- a/cosmos/constants.py +++ b/cosmos/constants.py @@ -10,6 +10,7 @@ DEFAULT_DBT_PROFILE_NAME = "cosmos_profile" DEFAULT_DBT_TARGET_NAME = "cosmos_target" DEFAULT_COSMOS_CACHE_DIR_NAME = "cosmos" +DEFAULT_TARGET_PATH = "target" DBT_LOG_PATH_ENVVAR = "DBT_LOG_PATH" DBT_LOG_DIR_NAME = "logs" DBT_TARGET_PATH_ENVVAR = "DBT_TARGET_PATH" diff --git a/cosmos/core/airflow.py b/cosmos/core/airflow.py index 6f1064649..e25404aed 100644 --- a/cosmos/core/airflow.py +++ b/cosmos/core/airflow.py @@ -32,6 +32,9 @@ def get_airflow_task(task: Task, dag: DAG, task_group: TaskGroup | None = None) if task.owner != "": task_kwargs["owner"] = task.owner + for k, v in task.airflow_task_config.items(): + task_kwargs[k] = v + airflow_task = Operator( task_id=task.id, dag=dag, diff --git a/cosmos/core/graph/entities.py b/cosmos/core/graph/entities.py index 6bf9ff046..cdd5485a6 100644 --- a/cosmos/core/graph/entities.py +++ b/cosmos/core/graph/entities.py @@ -58,6 +58,7 @@ class Task(CosmosEntity): """ owner: str = "" + airflow_task_config: Dict[str, Any] = field(default_factory=dict) operator_class: str = "airflow.operators.empty.EmptyOperator" arguments: Dict[str, Any] = field(default_factory=dict) extra_context: Dict[str, Any] = field(default_factory=dict) diff --git a/cosmos/dbt/graph.py b/cosmos/dbt/graph.py index be37ec298..04a7425e7 100644 --- a/cosmos/dbt/graph.py +++ b/cosmos/dbt/graph.py @@ -13,7 +13,7 @@ from functools import cached_property from pathlib import Path from subprocess import PIPE, Popen -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional from airflow.models import Variable @@ -67,6 +67,33 @@ class DbtNode: has_freshness: bool = False has_test: bool = False + @property + def airflow_task_config(self) -> Dict[str, Any]: + """ + This method is designed to extend the dbt project's functionality by incorporating Airflow-related metadata into the 
dbt YAML configuration. + Since dbt projects are independent of Airflow, adding Airflow-specific information to the `meta` field within the dbt YAML allows Airflow tasks to + utilize this information during execution. + + Examples: pool, pool_slots, queue, ... + Returns: + Dict[str, Any]: A dictionary containing custom metadata configurations for integration with Airflow. + """ + + if "meta" in self.config: + meta = self.config["meta"] + if "cosmos" in meta: + cosmos = meta["cosmos"] + if isinstance(cosmos, dict): + if "operator_kwargs" in cosmos: + operator_kwargs = cosmos["operator_kwargs"] + if isinstance(operator_kwargs, dict): + return operator_kwargs + else: + logger.error("Invalid type: 'operator_kwargs' in meta.cosmos must be a dict.") + else: + logger.error("Invalid type: 'cosmos' in meta must be a dict.") + return {} + @property def resource_name(self) -> str: """ @@ -133,6 +160,7 @@ def is_freshness_effective(freshness: Optional[dict[str, Any]]) -> bool: def run_command(command: list[str], tmp_dir: Path, env_vars: dict[str, str]) -> str: """Run a command in a subprocess, returning the stdout.""" + command = [str(arg) if arg is not None else "" for arg in command] logger.info("Running command: `%s`", " ".join(command)) logger.debug("Environment variable keys: %s", env_vars.keys()) process = Popen( diff --git a/cosmos/io.py b/cosmos/io.py new file mode 100644 index 000000000..0cce873e5 --- /dev/null +++ b/cosmos/io.py @@ -0,0 +1,217 @@ +from __future__ import annotations + +import os +from pathlib import Path +from typing import Any +from urllib.parse import urlparse + +from cosmos import settings +from cosmos.constants import DEFAULT_TARGET_PATH, FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP +from cosmos.exceptions import CosmosValueError +from cosmos.settings import remote_target_path, remote_target_path_conn_id + + +def upload_to_aws_s3( + project_dir: str, + bucket_name: str, + aws_conn_id: str | None = None, + source_subpath: str = DEFAULT_TARGET_PATH, + **kwargs: Any, +) -> None: + """ + Helper function demonstrating how to upload files to AWS S3 that can be used as a callback. + + :param project_dir: Path of the cloned project directory which Cosmos tasks work from. + :param bucket_name: Name of the S3 bucket to upload to. + :param aws_conn_id: AWS connection ID to use when uploading files. + :param source_subpath: Path of the source directory sub-path to upload files from. + """ + from airflow.providers.amazon.aws.hooks.s3 import S3Hook + + target_dir = f"{project_dir}/{source_subpath}" + aws_conn_id = aws_conn_id if aws_conn_id else S3Hook.default_conn_name + hook = S3Hook(aws_conn_id=aws_conn_id) + context = kwargs["context"] + + # Iterate over the files in the target dir and upload them to S3 + for dirpath, _, filenames in os.walk(target_dir): + for filename in filenames: + s3_key = ( + f"{context['dag'].dag_id}" + f"/{context['run_id']}" + f"/{context['task_instance'].task_id}" + f"/{context['task_instance']._try_number}" + f"{dirpath.split(project_dir)[-1]}/{filename}" + ) + hook.load_file( + filename=f"{dirpath}/{filename}", + bucket_name=bucket_name, + key=s3_key, + replace=True, + ) + + +def upload_to_gcp_gs( + project_dir: str, + bucket_name: str, + gcp_conn_id: str | None = None, + source_subpath: str = DEFAULT_TARGET_PATH, + **kwargs: Any, +) -> None: + """ + Helper function demonstrating how to upload files to GCP GS that can be used as a callback. + + :param project_dir: Path of the cloned project directory which Cosmos tasks work from. 
+ :param bucket_name: Name of the GCP GS bucket to upload to. + :param gcp_conn_id: GCP connection ID to use when uploading files. + :param source_subpath: Path of the source directory sub-path to upload files from. + """ + from airflow.providers.google.cloud.hooks.gcs import GCSHook + + target_dir = f"{project_dir}/{source_subpath}" + gcp_conn_id = gcp_conn_id if gcp_conn_id else GCSHook.default_conn_name + # bucket_name = kwargs["bucket_name"] + hook = GCSHook(gcp_conn_id=gcp_conn_id) + context = kwargs["context"] + + # Iterate over the files in the target dir and upload them to GCP GS + for dirpath, _, filenames in os.walk(target_dir): + for filename in filenames: + object_name = ( + f"{context['dag'].dag_id}" + f"/{context['run_id']}" + f"/{context['task_instance'].task_id}" + f"/{context['task_instance']._try_number}" + f"{dirpath.split(project_dir)[-1]}/{filename}" + ) + hook.upload( + filename=f"{dirpath}/{filename}", + bucket_name=bucket_name, + object_name=object_name, + ) + + +def upload_to_azure_wasb( + project_dir: str, + container_name: str, + azure_conn_id: str | None = None, + source_subpath: str = DEFAULT_TARGET_PATH, + **kwargs: Any, +) -> None: + """ + Helper function demonstrating how to upload files to Azure WASB that can be used as a callback. + + :param project_dir: Path of the cloned project directory which Cosmos tasks work from. + :param container_name: Name of the Azure WASB container to upload files to. + :param azure_conn_id: Azure connection ID to use when uploading files. + :param source_subpath: Path of the source directory sub-path to upload files from. + """ + from airflow.providers.microsoft.azure.hooks.wasb import WasbHook + + target_dir = f"{project_dir}/{source_subpath}" + azure_conn_id = azure_conn_id if azure_conn_id else WasbHook.default_conn_name + # container_name = kwargs["container_name"] + hook = WasbHook(wasb_conn_id=azure_conn_id) + context = kwargs["context"] + + # Iterate over the files in the target dir and upload them to WASB container + for dirpath, _, filenames in os.walk(target_dir): + for filename in filenames: + blob_name = ( + f"{context['dag'].dag_id}" + f"/{context['run_id']}" + f"/{context['task_instance'].task_id}" + f"/{context['task_instance']._try_number}" + f"{dirpath.split(project_dir)[-1]}/{filename}" + ) + hook.load_file( + file_path=f"{dirpath}/{filename}", + container_name=container_name, + blob_name=blob_name, + overwrite=True, + ) + + +def _configure_remote_target_path() -> tuple[Path, str] | tuple[None, None]: + """Configure the remote target path if it is provided.""" + from airflow.version import version as airflow_version + + if not remote_target_path: + return None, None + + _configured_target_path = None + + target_path_str = str(remote_target_path) + + remote_conn_id = remote_target_path_conn_id + if not remote_conn_id: + target_path_schema = urlparse(target_path_str).scheme + remote_conn_id = FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP.get(target_path_schema, None) # type: ignore[assignment] + if remote_conn_id is None: + return None, None + + if not settings.AIRFLOW_IO_AVAILABLE: + raise CosmosValueError( + f"You're trying to specify remote target path {target_path_str}, but the required " + f"Object Storage feature is unavailable in Airflow version {airflow_version}. Please upgrade to " + "Airflow 2.8 or later." 
+ ) + + from airflow.io.path import ObjectStoragePath + + _configured_target_path = ObjectStoragePath(target_path_str, conn_id=remote_conn_id) + + if not _configured_target_path.exists(): # type: ignore[no-untyped-call] + _configured_target_path.mkdir(parents=True, exist_ok=True) + + return _configured_target_path, remote_conn_id + + +def _construct_dest_file_path( + dest_target_dir: Path, + file_path: str, + source_target_dir: Path, + source_subpath: str, + **kwargs: Any, +) -> str: + """ + Construct the destination path for the artifact files to be uploaded to the remote store. + """ + dest_target_dir_str = str(dest_target_dir).rstrip("/") + + context = kwargs["context"] + task_run_identifier = ( + f"{context['dag'].dag_id}" + f"/{context['run_id']}" + f"/{context['task_instance'].task_id}" + f"/{context['task_instance']._try_number}" + ) + rel_path = os.path.relpath(file_path, source_target_dir).lstrip("/") + + return f"{dest_target_dir_str}/{task_run_identifier}/{source_subpath}/{rel_path}" + + +def upload_to_cloud_storage(project_dir: str, source_subpath: str = DEFAULT_TARGET_PATH, **kwargs: Any) -> None: + """ + Helper function demonstrating how to upload files to remote object stores that can be used as a callback. This + example requires Airflow 2.8 or later and relies on the Cosmos ``remote_target_path`` and + ``remote_target_path_conn_id`` settings to determine where the files are uploaded. + + :param project_dir: Path of the cloned project directory which Cosmos tasks work from. + :param source_subpath: Path of the source directory sub-path to upload files from. + """ + dest_target_dir, dest_conn_id = _configure_remote_target_path() + + if not dest_target_dir: + raise CosmosValueError("You're trying to upload artifact files, but the remote target path is not configured.") + + from airflow.io.path import ObjectStoragePath + + source_target_dir = Path(project_dir) / f"{source_subpath}" + files = [str(file) for file in source_target_dir.rglob("*") if file.is_file()] + for file_path in files: + dest_file_path = _construct_dest_file_path( + dest_target_dir, file_path, source_target_dir, source_subpath, **kwargs + ) + dest_object_storage_path = ObjectStoragePath(dest_file_path, conn_id=dest_conn_id) + ObjectStoragePath(file_path).copy(dest_object_storage_path) diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index bf47ab4aa..2a56c33e3 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -141,6 +141,7 @@ def __init__( invocation_mode: InvocationMode | None = None, install_deps: bool = False, callback: Callable[[str], None] | None = None, + callback_args: dict[str, Any] | None = None, should_store_compiled_sql: bool = True, should_upload_compiled_sql: bool = False, append_env: bool = True, @@ -149,6 +150,7 @@ self.task_id = task_id self.profile_config = profile_config self.callback = callback + self.callback_args = callback_args or {} self.compiled_sql = "" self.freshness = "" self.should_store_compiled_sql = should_store_compiled_sql @@ -500,9 +502,10 @@ def run_command( self.store_freshness_json(tmp_project_dir, context) self.store_compiled_sql(tmp_project_dir, context) self.upload_compiled_sql(tmp_project_dir, context) - self.handle_exception(result) if self.callback: - self.callback(tmp_project_dir) + self.callback_args.update({"context": context}) + self.callback(tmp_project_dir, **self.callback_args) + self.handle_exception(result) return result diff --git a/dev/dags/cosmos_callback_dag.py 
b/dev/dags/cosmos_callback_dag.py new file mode 100644 index 000000000..b1f1b3701 --- /dev/null +++ b/dev/dags/cosmos_callback_dag.py @@ -0,0 +1,60 @@ +""" +An example DAG that uses Cosmos to render a dbt project into an Airflow DAG. +""" + +import os +from datetime import datetime +from pathlib import Path + +from cosmos import DbtDag, ProfileConfig, ProjectConfig +from cosmos.io import upload_to_cloud_storage +from cosmos.profiles import PostgresUserPasswordProfileMapping + +DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt" +DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) + +profile_config = ProfileConfig( + profile_name="default", + target_name="dev", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="example_conn", + profile_args={"schema": "public"}, + disable_event_tracking=True, + ), +) + +# [START cosmos_callback_example] +cosmos_callback_dag = DbtDag( + # dbt/cosmos-specific parameters + project_config=ProjectConfig( + DBT_ROOT_PATH / "jaffle_shop", + ), + profile_config=profile_config, + operator_args={ + "install_deps": True, # install any necessary dependencies before running any dbt command + "full_refresh": True, # used only in dbt commands that support this flag + # -------------------------------------------------------------- + # Callback function to upload files using Airflow Object storage and Cosmos remote_target_path setting on Airflow 2.8 and above + "callback": upload_to_cloud_storage, + # -------------------------------------------------------------- + # Callback function to upload files to AWS S3, works for Airflow < 2.8 too + # "callback": upload_to_aws_s3, + # "callback_args": {"aws_conn_id": "aws_s3_conn", "bucket_name": "cosmos-artifacts-upload"}, + # -------------------------------------------------------------- + # Callback function to upload files to GCP GS, works for Airflow < 2.8 too + # "callback": upload_to_gcp_gs, + # "callback_args": {"gcp_conn_id": "gcp_gs_conn", "bucket_name": "cosmos-artifacts-upload"}, + # -------------------------------------------------------------- + # Callback function to upload files to Azure WASB, works for Airflow < 2.8 too + # "callback": upload_to_azure_wasb, + # "callback_args": {"azure_conn_id": "azure_wasb_conn", "container_name": "cosmos-artifacts-upload"}, + # -------------------------------------------------------------- + }, + # normal dag parameters + schedule_interval="@daily", + start_date=datetime(2023, 1, 1), + catchup=False, + dag_id="cosmos_callback_dag", + default_args={"retries": 2}, +) +# [END cosmos_callback_example] diff --git a/dev/dags/example_operators.py b/dev/dags/example_operators.py index 1c8624a34..1e583b12e 100644 --- a/dev/dags/example_operators.py +++ b/dev/dags/example_operators.py @@ -1,10 +1,13 @@ import os from datetime import datetime from pathlib import Path +from typing import Any from airflow import DAG +from airflow.operators.python import PythonOperator from cosmos import DbtCloneLocalOperator, DbtRunLocalOperator, DbtSeedLocalOperator, ProfileConfig +from cosmos.io import upload_to_aws_s3 DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt" DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) @@ -18,7 +21,19 @@ profiles_yml_filepath=DBT_PROFILE_PATH, ) + +def check_s3_file(bucket_name: str, file_key: str, aws_conn_id: str = "aws_default", **context: Any) -> bool: + """Check if a file exists in the given S3 bucket.""" + from airflow.providers.amazon.aws.hooks.s3 import S3Hook + + s3_key = 
f"{context['dag'].dag_id}/{context['run_id']}/seed/0/{file_key}" + print(f"Checking if file {s3_key} exists in S3 bucket...") + hook = S3Hook(aws_conn_id=aws_conn_id) + return hook.check_for_key(key=s3_key, bucket_name=bucket_name) + + with DAG("example_operators", start_date=datetime(2024, 1, 1), catchup=False) as dag: + # [START single_operator_callback] seed_operator = DbtSeedLocalOperator( profile_config=profile_config, project_dir=DBT_PROJ_DIR, @@ -26,7 +41,32 @@ dbt_cmd_flags=["--select", "raw_customers"], install_deps=True, append_env=True, + # -------------------------------------------------------------- + # Callback function to upload artifacts to AWS S3 + callback=upload_to_aws_s3, + callback_args={"aws_conn_id": "aws_s3_conn", "bucket_name": "cosmos-artifacts-upload"}, + # -------------------------------------------------------------- + # Callback function to upload artifacts to GCP GS + # callback=upload_to_gcp_gs, + # callback_args={"gcp_conn_id": "gcp_gs_conn", "bucket_name": "cosmos-artifacts-upload"}, + # -------------------------------------------------------------- + # Callback function to upload artifacts to Azure WASB + # callback=upload_to_azure_wasb, + # callback_args={"azure_conn_id": "azure_wasb_conn", "container_name": "cosmos-artifacts-upload"}, + # -------------------------------------------------------------- ) + # [END single_operator_callback] + + check_file_uploaded_task = PythonOperator( + task_id="check_file_uploaded_task", + python_callable=check_s3_file, + op_kwargs={ + "aws_conn_id": "aws_s3_conn", + "bucket_name": "cosmos-artifacts-upload", + "file_key": "target/run_results.json", + }, + ) + run_operator = DbtRunLocalOperator( profile_config=profile_config, project_dir=DBT_PROJ_DIR, @@ -48,3 +88,4 @@ # [END clone_example] seed_operator >> run_operator >> clone_operator + seed_operator >> check_file_uploaded_task diff --git a/docs/_static/custom_airflow_pool.png b/docs/_static/custom_airflow_pool.png new file mode 100644 index 000000000..4b4163e66 Binary files /dev/null and b/docs/_static/custom_airflow_pool.png differ diff --git a/docs/configuration/callbacks.rst b/docs/configuration/callbacks.rst new file mode 100644 index 000000000..9ffbc246f --- /dev/null +++ b/docs/configuration/callbacks.rst @@ -0,0 +1,51 @@ +.. _callbacks: + +Callbacks +========= + +Cosmos supports callback functions that execute at the end of a task's execution when using ``ExecutionMode.LOCAL``. +These callbacks can be used for various purposes, such as uploading files from the target directory to remote +storage. While this feature has been available for some time, users may not be fully aware of its capabilities. + +With the Cosmos 1.8.0 release, several helper functions were added in the ``cosmos/io.py`` module. These functions +provide examples of callback functions that can be hooked into Cosmos DAGs to upload files from the project’s +target directory to remote cloud storage providers such as AWS S3, GCP GS, and Azure WASB. + +Example: Using Callbacks with a Single Operator +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To demonstrate how to specify a callback function for uploading files from the target directory, here’s an example +using a single operator in an Airflow DAG: + +.. 
literalinclude:: ../../dev/dags/example_operators.py + :language: python + :start-after: [START single_operator_callback] + :end-before: [END single_operator_callback] + +Example: Using Callbacks with ``remote_target_path`` (Airflow 2.8+) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If you're using Airflow 2.8 or later, you can leverage the :ref:`remote_target_path` configuration to upload files +from the target directory to remote storage. Below is an example of how to define a callback helper function in your +``DbtDag`` that utilizes this configuration: + +.. literalinclude:: ../../dev/dags/cosmos_callback_dag.py + :language: python + :start-after: [START cosmos_callback_example] + :end-before: [END cosmos_callback_example] + +Custom Callbacks +~~~~~~~~~~~~~~~~ + +The helper functions introduced in Cosmos 1.8.0 are just examples of how callback functions can be written and passed +to Cosmos DAGs. Users are not limited to using these predefined functions; they can also create their own custom +callback functions to meet specific needs. These custom functions can be provided to Cosmos DAGs, where they will +receive the path to the cloned project directory and the Airflow task context, which includes DAG and task instance +metadata. + +Limitations and Contributions +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Currently, callback support is available only when using ``ExecutionMode.LOCAL``. Contributions to extend this +functionality to other execution modes are welcome and encouraged. You can reference the implementation for +``ExecutionMode.LOCAL`` to add support for other modes. diff --git a/docs/configuration/index.rst b/docs/configuration/index.rst index f6e60f61b..6c47884e9 100644 --- a/docs/configuration/index.rst +++ b/docs/configuration/index.rst @@ -27,3 +27,4 @@ Cosmos offers a number of configuration options to customize its behavior. For m Compiled SQL Logging Caching + Callbacks diff --git a/docs/getting_started/custom-airflow-properties.rst b/docs/getting_started/custom-airflow-properties.rst new file mode 100644 index 000000000..90490a099 --- /dev/null +++ b/docs/getting_started/custom-airflow-properties.rst @@ -0,0 +1,33 @@ +.. _custom-airflow-properties: + +Airflow Configuration Overrides with Astronomer Cosmos +====================================================== + +**Astronomer Cosmos** allows you to override Airflow configurations for each dbt task (dbt operator) via the dbt YAML file. + +Sample dbt Model YAML +++++++++++++++++++++++ + +.. code-block:: yaml + + version: 2 + models: + - name: name + description: description + meta: + cosmos: + operator_kwargs: + pool: abcd + + + + +Explanation ++++++++++++++ + +By adding Airflow operator arguments under ``operator_kwargs`` within the **cosmos** entry of the **meta** field, you can set independent Airflow configurations for each task. +For example, in the YAML above, the **pool** setting is applied to the specific dbt task. +This approach allows for more granular control over Airflow settings per task within your dbt model definitions. + +.. 
image:: ../_static/custom_airflow_pool.png + :alt: Result of applying Custom Airflow Pool diff --git a/tests/airflow/test_graph.py b/tests/airflow/test_graph.py index d2f943bab..fc0070e8b 100644 --- a/tests/airflow/test_graph.py +++ b/tests/airflow/test_graph.py @@ -62,7 +62,7 @@ depends_on=[parent_node.unique_id], file_path=SAMPLE_PROJ_PATH / "gen3/models/child.sql", tags=["nightly"], - config={"materialized": "table"}, + config={"materialized": "table", "meta": {"cosmos": {"operator_kwargs": {"queue": "custom_queue"}}}}, ) child2_node = DbtNode( @@ -71,7 +71,7 @@ depends_on=[parent_node.unique_id], file_path=SAMPLE_PROJ_PATH / "gen3/models/child2_v2.sql", tags=["nightly"], - config={"materialized": "table"}, + config={"materialized": "table", "meta": {"cosmos": {"operator_kwargs": {"pool": "custom_pool"}}}}, ) sample_nodes_list = [parent_seed, parent_node, test_parent_node, child_node, child2_node] @@ -750,3 +750,42 @@ def test_owner(dbt_extra_config, expected_owner): assert len(output.leaves) == 1 assert output.leaves[0].owner == expected_owner + + +def test_custom_meta(): + with DAG("test-id", start_date=datetime(2022, 1, 1)) as dag: + task_args = { + "project_dir": SAMPLE_PROJ_PATH, + "conn_id": "fake_conn", + "profile_config": ProfileConfig( + profile_name="default", + target_name="default", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="fake_conn", + profile_args={"schema": "public"}, + ), + ), + } + build_airflow_graph( + nodes=sample_nodes, + dag=dag, + execution_mode=ExecutionMode.LOCAL, + test_indirect_selection=TestIndirectSelection.EAGER, + task_args=task_args, + render_config=RenderConfig( + test_behavior=TestBehavior.AFTER_EACH, + source_rendering_behavior=SOURCE_RENDERING_BEHAVIOR, + ), + dbt_project_name="astro_shop", + ) + # test custom meta (queue, pool) + for task in dag.tasks: + if task.task_id == "child2_v2_run": + assert task.pool == "custom_pool" + else: + assert task.pool == "default_pool" + + if task.task_id == "child_run": + assert task.queue == "custom_queue" + else: + assert task.queue == "default" diff --git a/tests/dbt/test_graph.py b/tests/dbt/test_graph.py index 59d71869d..499cc219c 100644 --- a/tests/dbt/test_graph.py +++ b/tests/dbt/test_graph.py @@ -1121,6 +1121,20 @@ def test_run_command(mock_popen, stdout, returncode): assert return_value == stdout +@patch("cosmos.dbt.graph.Popen") +def test_run_command_none_argument(mock_popen, caplog): + fake_command = ["invalid-cmd", None] + fake_dir = Path("fake_dir") + env_vars = {"fake": "env_var"} + + mock_popen.return_value.communicate.return_value = ("Invalid None argument", None) + with pytest.raises(CosmosLoadDbtException) as exc_info: + run_command(fake_command, fake_dir, env_vars) + + expected = "Unable to run ['invalid-cmd', ''] due to the error:\nInvalid None argument" + assert str(exc_info.value) == expected + + def test_parse_dbt_ls_output_real_life_customer_bug(caplog): dbt_ls_output = """ 11:20:43 Running with dbt=1.7.6 diff --git a/tests/test_example_dags.py b/tests/test_example_dags.py index 5d7a1a70d..762985b59 100644 --- a/tests/test_example_dags.py +++ b/tests/test_example_dags.py @@ -30,7 +30,7 @@ MIN_VER_DAG_FILE: dict[str, list[str]] = { "2.4": ["cosmos_seed_dag.py"], - "2.8": ["cosmos_manifest_example.py", "simple_dag_async.py"], + "2.8": ["cosmos_manifest_example.py", "simple_dag_async.py", "cosmos_callback_dag.py"], } IGNORED_DAG_FILES = ["performance_dag.py", "jaffle_shop_kubernetes.py"] diff --git a/tests/test_io.py b/tests/test_io.py new file mode 100644 index 
000000000..7410f0588 --- /dev/null +++ b/tests/test_io.py @@ -0,0 +1,168 @@ +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from cosmos.constants import DEFAULT_TARGET_PATH, _default_s3_conn +from cosmos.exceptions import CosmosValueError +from cosmos.io import ( + _configure_remote_target_path, + _construct_dest_file_path, + upload_to_aws_s3, + upload_to_azure_wasb, + upload_to_cloud_storage, + upload_to_gcp_gs, +) +from cosmos.settings import AIRFLOW_IO_AVAILABLE + + +@pytest.fixture +def dummy_kwargs(): + """Fixture for reusable test kwargs.""" + return { + "context": { + "dag": MagicMock(dag_id="test_dag"), + "run_id": "test_run_id", + "task_instance": MagicMock(task_id="test_task", _try_number=1), + }, + "bucket_name": "test_bucket", + "container_name": "test_container", + } + + +def test_upload_artifacts_to_aws_s3(dummy_kwargs): + """Test upload_artifacts_to_aws_s3.""" + with patch("airflow.providers.amazon.aws.hooks.s3.S3Hook") as mock_hook, patch("os.walk") as mock_walk: + mock_walk.return_value = [("/target", [], ["file1.txt", "file2.txt"])] + + upload_to_aws_s3("/project_dir", **dummy_kwargs) + + mock_walk.assert_called_once_with("/project_dir/target") + hook_instance = mock_hook.return_value + assert hook_instance.load_file.call_count == 2 + + +def test_upload_artifacts_to_gcp_gs(dummy_kwargs): + """Test upload_artifacts_to_gcp_gs.""" + with patch("airflow.providers.google.cloud.hooks.gcs.GCSHook") as mock_hook, patch("os.walk") as mock_walk: + mock_walk.return_value = [("/target", [], ["file1.txt", "file2.txt"])] + + upload_to_gcp_gs("/project_dir", **dummy_kwargs) + + mock_walk.assert_called_once_with("/project_dir/target") + hook_instance = mock_hook.return_value + assert hook_instance.upload.call_count == 2 + + +def test_upload_artifacts_to_azure_wasb(dummy_kwargs): + """Test upload_artifacts_to_azure_wasb.""" + with patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook") as mock_hook, patch("os.walk") as mock_walk: + mock_walk.return_value = [("/target", [], ["file1.txt", "file2.txt"])] + + upload_to_azure_wasb("/project_dir", **dummy_kwargs) + + mock_walk.assert_called_once_with("/project_dir/target") + hook_instance = mock_hook.return_value + assert hook_instance.load_file.call_count == 2 + + +def test_configure_remote_target_path_no_remote_target(): + """Test _configure_remote_target_path when no remote target path is set.""" + with patch("cosmos.settings.remote_target_path", None): + from cosmos.io import _configure_remote_target_path + + assert _configure_remote_target_path() == (None, None) + + +def test_construct_dest_file_path(dummy_kwargs): + """Test _construct_dest_file_path.""" + dest_target_dir = Path("/dest") + source_target_dir = Path("/project_dir/target") + file_path = "/project_dir/target/subdir/file.txt" + + expected_path = "/dest/test_dag/test_run_id/test_task/1/target/subdir/file.txt" + assert ( + _construct_dest_file_path(dest_target_dir, file_path, source_target_dir, DEFAULT_TARGET_PATH, **dummy_kwargs) + == expected_path + ) + + +def test_upload_artifacts_to_cloud_storage_no_remote_path(): + """Test upload_artifacts_to_cloud_storage with no remote path.""" + with patch("cosmos.io._configure_remote_target_path", return_value=(None, None)): + with pytest.raises(CosmosValueError): + upload_to_cloud_storage("/project_dir", **{}) + + +@pytest.mark.skipif(not AIRFLOW_IO_AVAILABLE, reason="Airflow did not have Object Storage until the 2.8 release") +def 
test_upload_artifacts_to_cloud_storage_success(dummy_kwargs): + """Test upload_artifacts_to_cloud_storage with valid setup.""" + with patch( + "cosmos.io._configure_remote_target_path", + return_value=(Path("/dest"), "conn_id"), + ) as mock_configure, patch("pathlib.Path.rglob") as mock_rglob, patch( + "airflow.io.path.ObjectStoragePath.copy" + ) as mock_copy: + mock_file1 = MagicMock(spec=Path) + mock_file1.is_file.return_value = True + mock_file1.__str__.return_value = "/project_dir/target/file1.txt" + + mock_file2 = MagicMock(spec=Path) + mock_file2.is_file.return_value = True + mock_file2.__str__.return_value = "/project_dir/target/file2.txt" + + mock_rglob.return_value = [mock_file1, mock_file2] + + upload_to_cloud_storage("/project_dir", **dummy_kwargs) + + mock_configure.assert_called_once() + assert mock_copy.call_count == 2 + + +@pytest.mark.skipif(not AIRFLOW_IO_AVAILABLE, reason="Airflow did not have Object Storage until the 2.8 release") +@patch("cosmos.io.remote_target_path") +def test_configure_remote_target_path_no_conn_id(mock_remote_target_path): + """Test when no remote_conn_id is provided, but conn_id is resolved from scheme.""" + mock_remote_target_path.return_value = "s3://bucket/path/to/file" + mock_storage_path = MagicMock() + with patch("cosmos.io.urlparse") as mock_urlparse: + mock_urlparse.return_value.scheme = "s3" + with patch("airflow.io.path.ObjectStoragePath") as mock_object_storage: + mock_object_storage.return_value = mock_storage_path + mock_storage_path.exists.return_value = True + + result = _configure_remote_target_path() + assert result == (mock_object_storage.return_value, _default_s3_conn) + + +@pytest.mark.skipif(not AIRFLOW_IO_AVAILABLE, reason="Airflow did not have Object Storage until the 2.8 release") +@patch("cosmos.io.remote_target_path") +def test_configure_remote_target_path_conn_id_is_none(mock_remote_target_path): + """Test when conn_id cannot be resolved and is None.""" + mock_remote_target_path.return_value = "abcd://bucket/path/to/file" + mock_storage_path = MagicMock() + with patch("cosmos.io.urlparse") as mock_urlparse: + mock_urlparse.return_value.scheme = "abcd" + with patch("airflow.io.path.ObjectStoragePath") as mock_object_storage: + mock_object_storage.return_value = mock_storage_path + mock_storage_path.exists.return_value = True + + result = _configure_remote_target_path() + assert result == (None, None) + + +@pytest.mark.skipif(not AIRFLOW_IO_AVAILABLE, reason="Airflow did not have Object Storage until the 2.8 release") +@patch("cosmos.settings.AIRFLOW_IO_AVAILABLE", False) +@patch("cosmos.io.remote_target_path") +def test_configure_remote_target_path_airflow_io_unavailable(mock_remote_target_path): + """Test when AIRFLOW_IO_AVAILABLE is False.""" + mock_remote_target_path.return_value = "s3://bucket/path/to/file" + mock_storage_path = MagicMock() + with patch("cosmos.io.urlparse") as mock_urlparse: + mock_urlparse.return_value.scheme = "s3" + with patch("airflow.io.path.ObjectStoragePath") as mock_object_storage: + mock_object_storage.return_value = mock_storage_path + mock_storage_path.exists.return_value = True + with pytest.raises(CosmosValueError) as exc_info: + _configure_remote_target_path() + assert "Object Storage feature is unavailable" in str(exc_info.value)