From 682fc67a6afee398a648d6350ec0e8448fe36c6e Mon Sep 17 00:00:00 2001 From: Justin Bandoro <79104794+jbandoro@users.noreply.github.com> Date: Wed, 25 Oct 2023 06:58:56 -0700 Subject: [PATCH] Add `DbtDocsGCSOperator` (#616) Adds `DbtDocsGCSOperator` so dbt docs can be uploaded to GCS. Closes: #541 ## Breaking Change? No breaking changes but standardized `DbtDocsS3LocalOperator`, `DbtDocsAzureStorageLocalOperator` to accept args for `connection_id` and `bucket_name`. The current args of `aws_conn_id` (S3), `azure_conn_id` and `container_name` (Azure) will still work with warnings to switch to `connection_id` and `bucket_name`. --- cosmos/operators/__init__.py | 2 + cosmos/operators/local.py | 141 ++++++++++++++++++------- dev/dags/dbt_docs.py | 23 +++- docs/configuration/generating-docs.rst | 30 +++++- tests/operators/test_local.py | 33 +++++- 5 files changed, 183 insertions(+), 46 deletions(-) diff --git a/cosmos/operators/__init__.py b/cosmos/operators/__init__.py index c3155a9f9..b7e36abff 100644 --- a/cosmos/operators/__init__.py +++ b/cosmos/operators/__init__.py @@ -2,6 +2,7 @@ from .local import DbtDocsAzureStorageLocalOperator as DbtDocsAzureStorageOperator from .local import DbtDocsLocalOperator as DbtDocsOperator from .local import DbtDocsS3LocalOperator as DbtDocsS3Operator +from .local import DbtDocsGCSLocalOperator as DbtDocsGCSOperator from .local import DbtLSLocalOperator as DbtLSOperator from .local import DbtRunLocalOperator as DbtRunOperator from .local import DbtRunOperationLocalOperator as DbtRunOperationOperator @@ -20,4 +21,5 @@ "DbtDocsOperator", "DbtDocsS3Operator", "DbtDocsAzureStorageOperator", + "DbtDocsGCSOperator", ] diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index c033f33f3..489a92ba9 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -7,6 +7,8 @@ from attr import define from pathlib import Path from typing import Any, Callable, Literal, Sequence, TYPE_CHECKING +from abc import ABC, abstractmethod +import warnings import airflow import yaml @@ -539,40 +541,65 @@ def __init__(self, **kwargs: Any) -> None: self.base_cmd = ["docs", "generate"] -class DbtDocsS3LocalOperator(DbtDocsLocalOperator): +class DbtDocsCloudLocalOperator(DbtDocsLocalOperator, ABC): """ - Executes `dbt docs generate` command and upload to S3 storage. Returns the S3 path to the generated documentation. - - :param aws_conn_id: S3's Airflow connection ID - :param bucket_name: S3's bucket name - :param folder_dir: This can be used to specify under which directory the generated DBT documentation should be - uploaded. + Abstract class for operators that upload the generated documentation to cloud storage. """ - ui_color = "#FF9900" - def __init__( self, - aws_conn_id: str, + connection_id: str, bucket_name: str, folder_dir: str | None = None, - **kwargs: str, + **kwargs: Any, ) -> None: "Initializes the operator." - self.aws_conn_id = aws_conn_id + self.connection_id = connection_id self.bucket_name = bucket_name self.folder_dir = folder_dir super().__init__(**kwargs) # override the callback with our own - self.callback = self.upload_to_s3 + self.callback = self.upload_to_cloud_storage + + @abstractmethod + def upload_to_cloud_storage(self, project_dir: str) -> None: + """Abstract method to upload the generated documentation to cloud storage.""" + + +class DbtDocsS3LocalOperator(DbtDocsCloudLocalOperator): + """ + Executes `dbt docs generate` command and upload to S3 storage. Returns the S3 path to the generated documentation. 
+ + :param connection_id: S3's Airflow connection ID + :param bucket_name: S3's bucket name + :param folder_dir: This can be used to specify under which directory the generated DBT documentation should be + uploaded. + """ + + ui_color = "#FF9900" + + def __init__( + self, + *args: Any, + aws_conn_id: str | None = None, + **kwargs: Any, + ) -> None: + if aws_conn_id: + warnings.warn( + "Please, use `connection_id` instead of `aws_conn_id`. The argument `aws_conn_id` will be" + " deprecated in Cosmos 2.0", + DeprecationWarning, + ) + kwargs["connection_id"] = aws_conn_id + super().__init__(*args, **kwargs) - def upload_to_s3(self, project_dir: str) -> None: + def upload_to_cloud_storage(self, project_dir: str) -> None: "Uploads the generated documentation to S3." logger.info( 'Attempting to upload generated docs to S3 using S3Hook("%s")', - self.aws_conn_id, + self.connection_id, ) from airflow.providers.amazon.aws.hooks.s3 import S3Hook @@ -580,7 +607,7 @@ def upload_to_s3(self, project_dir: str) -> None: target_dir = f"{project_dir}/target" hook = S3Hook( - self.aws_conn_id, + self.connection_id, extra_args={ "ContentType": "text/html", }, @@ -599,12 +626,12 @@ def upload_to_s3(self, project_dir: str) -> None: ) -class DbtDocsAzureStorageLocalOperator(DbtDocsLocalOperator): +class DbtDocsAzureStorageLocalOperator(DbtDocsCloudLocalOperator): """ Executes `dbt docs generate` command and upload to Azure Blob Storage. - :param azure_conn_id: Azure Blob Storage's Airflow connection ID - :param container_name: Azure Blob Storage's bucket name + :param connection_id: Azure Blob Storage's Airflow connection ID + :param bucket_name: Azure Blob Storage's bucket name :param folder_dir: This can be used to specify under which directory the generated DBT documentation should be uploaded. """ @@ -613,26 +640,32 @@ class DbtDocsAzureStorageLocalOperator(DbtDocsLocalOperator): def __init__( self, - azure_conn_id: str, - container_name: str, - folder_dir: str | None = None, - **kwargs: str, + *args: Any, + azure_conn_id: str | None = None, + container_name: str | None = None, + **kwargs: Any, ) -> None: - "Initializes the operator." - self.azure_conn_id = azure_conn_id - self.container_name = container_name - self.folder_dir = folder_dir - - super().__init__(**kwargs) - - # override the callback with our own - self.callback = self.upload_to_azure + if azure_conn_id: + warnings.warn( + "Please, use `connection_id` instead of `azure_conn_id`. The argument `azure_conn_id` will" + " be deprecated in Cosmos 2.0", + DeprecationWarning, + ) + kwargs["connection_id"] = azure_conn_id + if container_name: + warnings.warn( + "Please, use `bucket_name` instead of `container_name`. The argument `container_name` will" + " be deprecated in Cosmos 2.0", + DeprecationWarning, + ) + kwargs["bucket_name"] = container_name + super().__init__(*args, **kwargs) - def upload_to_azure(self, project_dir: str) -> None: + def upload_to_cloud_storage(self, project_dir: str) -> None: "Uploads the generated documentation to Azure Blob Storage." 
logger.info( 'Attempting to upload generated docs to Azure Blob Storage using WasbHook(conn_id="%s")', - self.azure_conn_id, + self.connection_id, ) from airflow.providers.microsoft.azure.hooks.wasb import WasbHook @@ -640,26 +673,60 @@ def upload_to_azure(self, project_dir: str) -> None: target_dir = f"{project_dir}/target" hook = WasbHook( - self.azure_conn_id, + self.connection_id, ) for filename in self.required_files: logger.info( "Uploading %s to %s", filename, - f"wasb://{self.container_name}/{filename}", + f"wasb://{self.bucket_name}/{filename}", ) blob_name = f"{self.folder_dir}/{filename}" if self.folder_dir else filename hook.load_file( file_path=f"{target_dir}/{filename}", - container_name=self.container_name, + container_name=self.bucket_name, blob_name=blob_name, overwrite=True, ) +class DbtDocsGCSLocalOperator(DbtDocsCloudLocalOperator): + """ + Executes `dbt docs generate` command and upload to GCS. + + :param connection_id: Google Cloud Storage's Airflow connection ID + :param bucket_name: Google Cloud Storage's bucket name + :param folder_dir: This can be used to specify under which directory the generated DBT documentation should be + uploaded. + """ + + ui_color = "#4772d5" + + def upload_to_cloud_storage(self, project_dir: str) -> None: + "Uploads the generated documentation to Google Cloud Storage" + logger.info( + 'Attempting to upload generated docs to Storage using GCSHook(conn_id="%s")', + self.connection_id, + ) + + from airflow.providers.google.cloud.hooks.gcs import GCSHook + + target_dir = f"{project_dir}/target" + hook = GCSHook(self.connection_id) + + for filename in self.required_files: + blob_name = f"{self.folder_dir}/{filename}" if self.folder_dir else filename + logger.info("Uploading %s to %s", filename, f"gs://{self.bucket_name}/{blob_name}") + hook.upload( + filename=f"{target_dir}/{filename}", + bucket_name=self.bucket_name, + object_name=blob_name, + ) + + class DbtDepsLocalOperator(DbtLocalBaseOperator): """ Executes a dbt core deps command. 
diff --git a/dev/dags/dbt_docs.py b/dev/dags/dbt_docs.py index 1fcd1c341..edf89bdab 100644 --- a/dev/dags/dbt_docs.py +++ b/dev/dags/dbt_docs.py @@ -20,6 +20,7 @@ from cosmos.operators import ( DbtDocsAzureStorageOperator, DbtDocsS3Operator, + DbtDocsGCSOperator, ) from cosmos.profiles import PostgresUserPasswordProfileMapping @@ -28,6 +29,7 @@ S3_CONN_ID = "aws_docs" AZURE_CONN_ID = "azure_docs" +GCS_CONN_ID = "gcs_docs" profile_config = ProfileConfig( profile_name="default", @@ -56,6 +58,11 @@ def which_upload(): downstream_tasks_to_run += ["generate_dbt_docs_azure"] except AirflowNotFoundException: pass + try: + BaseHook.get_connection(GCS_CONN_ID) + downstream_tasks_to_run += ["generate_dbt_docs_gcs"] + except AirflowNotFoundException: + pass return downstream_tasks_to_run @@ -72,7 +79,7 @@ def which_upload(): task_id="generate_dbt_docs_aws", project_dir=DBT_ROOT_PATH / "jaffle_shop", profile_config=profile_config, - aws_conn_id=S3_CONN_ID, + connection_id=S3_CONN_ID, bucket_name="cosmos-docs", ) @@ -80,8 +87,16 @@ def which_upload(): task_id="generate_dbt_docs_azure", project_dir=DBT_ROOT_PATH / "jaffle_shop", profile_config=profile_config, - azure_conn_id=AZURE_CONN_ID, - container_name="$web", + connection_id=AZURE_CONN_ID, + bucket_name="$web", + ) + + generate_dbt_docs_gcs = DbtDocsGCSOperator( + task_id="generate_dbt_docs_gcs", + project_dir=DBT_ROOT_PATH / "jaffle_shop", + profile_config=profile_config, + connection_id=GCS_CONN_ID, + bucket_name="cosmos-docs", ) - which_upload() >> [generate_dbt_docs_aws, generate_dbt_docs_azure] + which_upload() >> [generate_dbt_docs_aws, generate_dbt_docs_azure, generate_dbt_docs_gcs] diff --git a/docs/configuration/generating-docs.rst b/docs/configuration/generating-docs.rst index 925b60e04..88459fd14 100644 --- a/docs/configuration/generating-docs.rst +++ b/docs/configuration/generating-docs.rst @@ -11,9 +11,10 @@ Cosmos offers two pre-built ways of generating and uploading dbt docs and a fall - :class:`~cosmos.operators.DbtDocsS3Operator`: generates and uploads docs to a S3 bucket. - :class:`~cosmos.operators.DbtDocsAzureStorageOperator`: generates and uploads docs to an Azure Blob Storage. +- :class:`~cosmos.operators.DbtDocsGCSOperator`: generates and uploads docs to a GCS bucket. - :class:`~cosmos.operators.DbtDocsOperator`: generates docs and runs a custom callback. -The first two operators require you to have a connection to the target storage. The third operator allows you to run custom code after the docs are generated in order to upload them to a storage of your choice. +The first three operators require you to have a connection to the target storage. The last operator allows you to run custom code after the docs are generated in order to upload them to a storage of your choice. Examples @@ -36,7 +37,7 @@ You can use the :class:`~cosmos.operators.DbtDocsS3Operator` to generate and upl project_dir="path/to/jaffle_shop", profile_config=profile_config, # docs-specific arguments - aws_conn_id="test_aws", + connection_id="test_aws", bucket_name="test_bucket", ) @@ -57,8 +58,29 @@ You can use the :class:`~cosmos.operators.DbtDocsAzureStorageOperator` to genera project_dir="path/to/jaffle_shop", profile_config=profile_config, # docs-specific arguments - azure_conn_id="test_azure", - container_name="$web", + connection_id="test_azure", + bucket_name="$web", + ) + +Upload to GCS +~~~~~~~~~~~~~~~~~~~~~~~ + +GCS supports serving static files directly from a bucket. To learn more (and to set it up), check out the `official GCS documentation `_. 
+
+You can use the :class:`~cosmos.operators.DbtDocsGCSOperator` to generate and upload docs to a GCS bucket. The following code snippet shows how to do this with the default jaffle_shop project:
+
+.. code-block:: python
+
+    from cosmos.operators import DbtDocsGCSOperator
+
+    # then, in your DAG code:
+    generate_dbt_docs_gcs = DbtDocsGCSOperator(
+        task_id="generate_dbt_docs_gcs",
+        project_dir="path/to/jaffle_shop",
+        profile_config=profile_config,
+        # docs-specific arguments
+        connection_id="test_gcs",
+        bucket_name="test_bucket",
     )
 
 Custom Callback
diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py
index 07c186d4a..b883adea5 100644
--- a/tests/operators/test_local.py
+++ b/tests/operators/test_local.py
@@ -1,9 +1,10 @@
 import logging
 import os
+import sys
 import shutil
 import tempfile
 from pathlib import Path
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, patch, call
 
 import pytest
 from airflow import DAG
@@ -24,6 +25,7 @@
     DbtDocsLocalOperator,
     DbtDocsS3LocalOperator,
     DbtDocsAzureStorageLocalOperator,
+    DbtDocsGCSLocalOperator,
     DbtSeedLocalOperator,
     DbtRunOperationLocalOperator,
 )
@@ -379,6 +381,7 @@ def test_operator_execute_with_flags(mock_build_and_run_cmd, operator_class, kwa
         DbtDocsLocalOperator,
         DbtDocsS3LocalOperator,
         DbtDocsAzureStorageLocalOperator,
+        DbtDocsGCSLocalOperator,
     ),
 )
 @patch("cosmos.operators.local.DbtLocalBaseOperator.build_and_run_cmd")
@@ -386,6 +389,7 @@ def test_operator_execute_without_flags(mock_build_and_run_cmd, operator_class):
     operator_class_kwargs = {
         DbtDocsS3LocalOperator: {"aws_conn_id": "fake-conn", "bucket_name": "fake-bucket"},
         DbtDocsAzureStorageLocalOperator: {"azure_conn_id": "fake-conn", "container_name": "fake-container"},
+        DbtDocsGCSLocalOperator: {"connection_id": "fake-conn", "bucket_name": "fake-bucket"},
     }
     task = operator_class(
         profile_config=profile_config,
@@ -413,3 +417,30 @@ def test_calculate_openlineage_events_completes_openlineage_errors(mock_processo
     assert instance.parse.called
     err_msg = "Unable to parse OpenLineage events"
     assert err_msg in caplog.text
+
+
+@patch.object(DbtDocsGCSLocalOperator, "required_files", ["file1", "file2"])
+def test_dbt_docs_gcs_local_operator():
+    mock_gcs = MagicMock()
+    with patch.dict(sys.modules, {"airflow.providers.google.cloud.hooks.gcs": mock_gcs}):
+        operator = DbtDocsGCSLocalOperator(
+            task_id="fake-task",
+            project_dir="fake-dir",
+            profile_config=profile_config,
+            connection_id="fake-conn",
+            bucket_name="fake-bucket",
+            folder_dir="fake-folder",
+        )
+        operator.upload_to_cloud_storage("fake-dir")
+
+    # assert that GCSHook was called with the connection id
+    mock_gcs.GCSHook.assert_called_once_with("fake-conn")
+
+    mock_hook = mock_gcs.GCSHook.return_value
+    # assert that upload was called twice with the expected arguments
+    assert mock_hook.upload.call_count == 2
+    expected_upload_calls = [
+        call(filename="fake-dir/target/file1", bucket_name="fake-bucket", object_name="fake-folder/file1"),
+        call(filename="fake-dir/target/file2", bucket_name="fake-bucket", object_name="fake-folder/file2"),
+    ]
+    mock_hook.upload.assert_has_calls(expected_upload_calls)
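
As described in the PR summary, the legacy arguments (`aws_conn_id` for S3, `azure_conn_id` and `container_name` for Azure) keep working: the shims forward the old value to the new argument name and emit a `DeprecationWarning`. A minimal sketch of that behaviour from a user's perspective, not part of the patch; the connection ID, bucket name and profile values below are placeholders:

import warnings

from cosmos import ProfileConfig
from cosmos.operators import DbtDocsS3Operator
from cosmos.profiles import PostgresUserPasswordProfileMapping

# Placeholder profile, in the style of dev/dags/dbt_docs.py above.
profile_config = ProfileConfig(
    profile_name="default",
    target_name="dev",
    profile_mapping=PostgresUserPasswordProfileMapping(
        conn_id="airflow_db",
        profile_args={"schema": "public"},
    ),
)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    docs_task = DbtDocsS3Operator(
        task_id="generate_dbt_docs_aws",
        project_dir="path/to/jaffle_shop",
        profile_config=profile_config,
        aws_conn_id="aws_docs",  # legacy argument, still accepted
        bucket_name="cosmos-docs",
    )

# The legacy value is forwarded to the new `connection_id` attribute, and a
# DeprecationWarning points users at the new argument name.
assert docs_task.connection_id == "aws_docs"
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

Standardizing on `connection_id`/`bucket_name` is what allows the S3, Azure and GCS uploaders to share the new `DbtDocsCloudLocalOperator` base class.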