diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index c5eeb5d88..d20a7de22 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -132,6 +132,27 @@ def create_test_task_metadata( ) +def _get_task_id_and_args( + node: DbtNode, + args: dict[str, Any], + use_task_group: bool, + normalize_task_id: Callable[..., Any] | None, + resource_suffix: str, +) -> tuple[str, dict[str, Any]]: + """ + Generate task ID and update args with display name if needed. + """ + args_update = args + if use_task_group: + task_id = resource_suffix + elif normalize_task_id: + task_id = normalize_task_id(node) + args_update["task_display_name"] = f"{node.name}_{resource_suffix}" + else: + task_id = f"{node.name}_{resource_suffix}" + return task_id, args_update + + def create_dbt_resource_to_class(test_behavior: TestBehavior) -> dict[str, str]: """ Return the map from dbt node type to Cosmos class prefix that should be used @@ -164,6 +185,7 @@ def create_task_metadata( dbt_dag_task_group_identifier: str, use_task_group: bool = False, source_rendering_behavior: SourceRenderingBehavior = SourceRenderingBehavior.NONE, + normalize_task_id: Callable[..., Any] | None = None, test_behavior: TestBehavior = TestBehavior.AFTER_ALL, on_warning_callback: Callable[..., Any] | None = None, ) -> TaskMetadata | None: @@ -194,10 +216,7 @@ def create_task_metadata( if test_behavior == TestBehavior.BUILD and node.resource_type in SUPPORTED_BUILD_RESOURCES: task_id = f"{node.name}_{node.resource_type.value}_build" elif node.resource_type == DbtResourceType.MODEL: - if use_task_group: - task_id = "run" - else: - task_id = f"{node.name}_run" + task_id, args = _get_task_id_and_args(node, args, use_task_group, normalize_task_id, "run") elif node.resource_type == DbtResourceType.SOURCE: args["on_warning_callback"] = on_warning_callback @@ -207,18 +226,21 @@ def create_task_metadata( and node.has_test is False ): return None - task_id = f"{node.name}_source" args["select"] = f"source:{node.resource_name}" args.pop("models") - if use_task_group is True: - task_id = node.resource_type.value + task_id, args = _get_task_id_and_args(node, args, use_task_group, normalize_task_id, "source") if node.has_freshness is False and source_rendering_behavior == SourceRenderingBehavior.ALL: # render sources without freshness as empty operators - return TaskMetadata(id=task_id, operator_class="airflow.operators.empty.EmptyOperator") + # empty operator does not accept custom parameters (e.g., profile_args). recreate the args. + if "task_display_name" in args: + args = {"task_display_name": args["task_display_name"]} + else: + args = {} + return TaskMetadata(id=task_id, operator_class="airflow.operators.empty.EmptyOperator", arguments=args) else: - task_id = f"{node.name}_{node.resource_type.value}" - if use_task_group is True: - task_id = node.resource_type.value + task_id, args = _get_task_id_and_args( + node, args, use_task_group, normalize_task_id, node.resource_type.value + ) task_metadata = TaskMetadata( id=task_id, @@ -250,6 +272,7 @@ def generate_task_or_group( source_rendering_behavior: SourceRenderingBehavior, test_indirect_selection: TestIndirectSelection, on_warning_callback: Callable[..., Any] | None, + normalize_task_id: Callable[..., Any] | None = None, **kwargs: Any, ) -> BaseOperator | TaskGroup | None: task_or_group: BaseOperator | TaskGroup | None = None @@ -267,6 +290,7 @@ def generate_task_or_group( dbt_dag_task_group_identifier=_get_dbt_dag_task_group_identifier(dag, task_group), use_task_group=use_task_group, source_rendering_behavior=source_rendering_behavior, + normalize_task_id=normalize_task_id, test_behavior=test_behavior, on_warning_callback=on_warning_callback, ) @@ -371,6 +395,7 @@ def build_airflow_graph( node_converters = render_config.node_converters or {} test_behavior = render_config.test_behavior source_rendering_behavior = render_config.source_rendering_behavior + normalize_task_id = render_config.normalize_task_id tasks_map: dict[str, Union[TaskGroup, BaseOperator]] = {} task_or_group: TaskGroup | BaseOperator @@ -392,6 +417,7 @@ def build_airflow_graph( source_rendering_behavior=source_rendering_behavior, test_indirect_selection=test_indirect_selection, on_warning_callback=on_warning_callback, + normalize_task_id=normalize_task_id, node=node, ) if task_or_group is not None: diff --git a/cosmos/config.py b/cosmos/config.py index 516a6787b..59f857114 100644 --- a/cosmos/config.py +++ b/cosmos/config.py @@ -62,6 +62,8 @@ class RenderConfig: :param dbt_ls_path: Configures the location of an output of ``dbt ls``. Required when using ``load_method=LoadMode.DBT_LS_FILE``. :param enable_mock_profile: Allows to enable/disable mocking profile. Enabled by default. Mock profiles are useful for parsing Cosmos DAGs in the CI, but should be disabled to benefit from partial parsing (since Cosmos 1.4). :param source_rendering_behavior: Determines how source nodes are rendered when using cosmos default source node rendering (ALL, NONE, WITH_TESTS_OR_FRESHNESS). Defaults to "NONE" (since Cosmos 1.6). + :param airflow_vars_to_purge_dbt_ls_cache: Specify Airflow variables that will affect the LoadMode.DBT_LS cache. + :param normalize_task_id: A callable that takes a dbt node as input and returns the task ID. This allows users to assign a custom node ID separate from the display name. """ emit_datasets: bool = True @@ -80,6 +82,7 @@ class RenderConfig: enable_mock_profile: bool = True source_rendering_behavior: SourceRenderingBehavior = SourceRenderingBehavior.NONE airflow_vars_to_purge_dbt_ls_cache: list[str] = field(default_factory=list) + normalize_task_id: Callable[..., Any] | None = None def __post_init__(self, dbt_project_path: str | Path | None) -> None: if self.env_vars: diff --git a/docs/configuration/index.rst b/docs/configuration/index.rst index 9001b4c2e..55ba51787 100644 --- a/docs/configuration/index.rst +++ b/docs/configuration/index.rst @@ -28,4 +28,5 @@ Cosmos offers a number of configuration options to customize its behavior. For m Compiled SQL Logging Caching + Task display name Callbacks diff --git a/docs/configuration/render-config.rst b/docs/configuration/render-config.rst index 068998de5..745b7018c 100644 --- a/docs/configuration/render-config.rst +++ b/docs/configuration/render-config.rst @@ -18,6 +18,8 @@ The ``RenderConfig`` class takes the following arguments: - ``env_vars``: (available in v1.2.5, use``ProjectConfig.env_vars`` for v1.3.0 onwards) A dictionary of environment variables for rendering. Only supported when using ``load_method=LoadMode.DBT_LS``. - ``dbt_project_path``: Configures the DBT project location accessible on their airflow controller for DAG rendering - Required when using ``load_method=LoadMode.DBT_LS`` or ``load_method=LoadMode.CUSTOM`` - ``airflow_vars_to_purge_cache``: (new in v1.5) Specify Airflow variables that will affect the ``LoadMode.DBT_LS`` cache. See `Caching <./caching.html>`_ for more information. +- ``source_rendering_behavior``: Determines how source nodes are rendered when using cosmos default source node rendering (ALL, NONE, WITH_TESTS_OR_FRESHNESS). Defaults to "NONE" (since Cosmos 1.6). See `Source Nodes Rendering <./source-nodes-rendering.html>`_ for more information. +- ``normalize_task_id``: A callable that takes a dbt node as input and returns the task ID. This function allows users to set a custom task_id independently of the model name, which can be specified as the task’s display_name. This way, task_id can be modified using a user-defined function, while the model name remains as the task’s display name. The display_name parameter is available in Airflow 2.9 and above. See `Task display name <./task-display-name.html>`_ for more information. Customizing how nodes are rendered (experimental) ------------------------------------------------- diff --git a/docs/configuration/task-display-name.rst b/docs/configuration/task-display-name.rst new file mode 100644 index 000000000..56c750dd2 --- /dev/null +++ b/docs/configuration/task-display-name.rst @@ -0,0 +1,33 @@ +.. _task-display-name: + +Task display name +================ + +.. note:: + This feature is only available for Airflow >= 2.9. + +In Airflow, ``task_id`` does not support non-ASCII characters. Therefore, if users wish to use non-ASCII characters (such as their native language) as display names while keeping ``task_id`` in ASCII, they can use the ``display_name`` parameter. + +To work with projects that use non-ASCII characters in model names, the ``normalize_task_id`` field of ``RenderConfig`` can be utilized. + +Example: + +You can provide a function to convert the model name to an ASCII-compatible format. The function’s output is used as the TaskID, while the display name on Airflow remains as the original model name. + +.. code-block:: python + + from slugify import slugify + + + def normalize_task_id(node): + return slugify(node.name) + + + from cosmos import DbtTaskGroup, RenderConfig + + jaffle_shop = DbtTaskGroup( + render_config=RenderConfig(normalize_task_id=normalize_task_id) + ) + +.. note:: + Although the slugify example often works, it may not be suitable for use in actual production. Since slugify performs conversions based on pronunciation, there may be cases where task_id is not unique due to homophones and similar issues. diff --git a/tests/airflow/test_graph.py b/tests/airflow/test_graph.py index fc0070e8b..c00f0cf53 100644 --- a/tests/airflow/test_graph.py +++ b/tests/airflow/test_graph.py @@ -611,6 +611,123 @@ def test_create_task_metadata_snapshot(caplog): assert metadata.arguments == {"models": "my_snapshot"} +def _normalize_task_id(node: DbtNode) -> str: + """for test_create_task_metadata_normalize_task_id""" + return f"new_task_id_{node.name}_{node.resource_type.value}" + + +@pytest.mark.skipif( + version.parse(airflow_version) < version.parse("2.9"), + reason="Airflow task did not have display_name until the 2.9 release", +) +@pytest.mark.parametrize( + "node_type,node_id,normalize_task_id,use_task_group,expected_node_id,expected_display_name", + [ + # normalize_task_id is None (default) + ( + DbtResourceType.MODEL, + f"{DbtResourceType.MODEL.value}.my_folder.test_node", + None, + False, + "test_node_run", + None, + ), + ( + DbtResourceType.SOURCE, + f"{DbtResourceType.SOURCE.value}.my_folder.test_node", + None, + False, + "test_node_source", + None, + ), + ( + DbtResourceType.SEED, + f"{DbtResourceType.SEED.value}.my_folder.test_node", + None, + False, + "test_node_seed", + None, + ), + # normalize_task_id is passed and use_task_group is False + ( + DbtResourceType.MODEL, + f"{DbtResourceType.MODEL.value}.my_folder.test_node", + _normalize_task_id, + False, + "new_task_id_test_node_model", + "test_node_run", + ), + ( + DbtResourceType.SOURCE, + f"{DbtResourceType.MODEL.value}.my_folder.test_node", + _normalize_task_id, + False, + "new_task_id_test_node_source", + "test_node_source", + ), + ( + DbtResourceType.SEED, + f"{DbtResourceType.MODEL.value}.my_folder.test_node", + _normalize_task_id, + False, + "new_task_id_test_node_seed", + "test_node_seed", + ), + # normalize_task_id is passed and use_task_group is True + ( + DbtResourceType.MODEL, + f"{DbtResourceType.MODEL.value}.my_folder.test_node", + _normalize_task_id, + True, + "run", + None, + ), + ( + DbtResourceType.SOURCE, + f"{DbtResourceType.MODEL.value}.my_folder.test_node", + _normalize_task_id, + True, + "source", + None, + ), + ( + DbtResourceType.SEED, + f"{DbtResourceType.MODEL.value}.my_folder.test_node", + _normalize_task_id, + True, + "seed", + None, + ), + ], +) +def test_create_task_metadata_normalize_task_id( + node_type, node_id, normalize_task_id, use_task_group, expected_node_id, expected_display_name +): + node = DbtNode( + unique_id=node_id, + resource_type=node_type, + depends_on=[], + file_path="", + tags=[], + config={}, + ) + args = {} + metadata = create_task_metadata( + node, + execution_mode=ExecutionMode.LOCAL, + args=args, + dbt_dag_task_group_identifier="", + use_task_group=use_task_group, + normalize_task_id=normalize_task_id, + source_rendering_behavior=SourceRenderingBehavior.ALL, + ) + assert metadata.id == expected_node_id + if expected_display_name: + assert metadata.arguments["task_display_name"] == expected_display_name + else: + assert "task_display_name" not in metadata.arguments + + @pytest.mark.parametrize( "node_type,node_unique_id,test_indirect_selection,additional_arguments", [