Skip to content

Commit

Permalink
Support loading default_args from shared defaults.yml (#330)
Browse files Browse the repository at this point in the history
closes: #297

Currently, DAG Factory provides two places to configure default_args:

1. At the top of the YML file.
2. In the DAG configuration YML.

The second option overrides the first one.

With this PR, users can also keep the default_args in a shared
`defaults.yml` file. The configuration from `defaults.yml` will be
applied to all DAG Factory DAGs.

Sample `defaults.yml`
```
default_args:
  start_date: "2025-01-01"
  owner: "global_owner"
  depends_on_past: true
```

The precedence for default_args will be as follows, after this
implementation:

1. At DAG configuration YML
2. At the top of the YML file
3. `defaults.yml`

i.e., the DAG configuration YML takes the highest precedence.
  • Loading branch information
pankajastro authored Jan 3, 2025
1 parent 38555b1 commit fd685b2
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 4 deletions.
30 changes: 26 additions & 4 deletions dagfactory/dagfactory.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,27 @@ class DagFactory:
:type config: dict
"""

def __init__(self, config_filepath: Optional[str] = None, config: Optional[dict] = None) -> None:
def __init__(
    self,
    config_filepath: Optional[str] = None,
    config: Optional[dict] = None,
    default_args_config_path: str = airflow_conf.get("core", "dags_folder"),
) -> None:
    """Create a factory from either a YAML config file or an in-memory config dict.

    :param config_filepath: path of the DAG Factory YAML file; it is passed to
        ``_validate_config_filepath`` and then loaded with ``_load_config``.
    :param config: an already-built configuration dictionary, used as-is.
    :param default_args_config_path: folder in which ``defaults.yml`` is looked
        up; defaults to the Airflow ``core.dags_folder`` setting, read once
        when this function is defined.
    :raises AssertionError: unless exactly one of ``config_filepath`` and
        ``config`` is truthy.
    """
    has_filepath = bool(config_filepath)
    has_config = bool(config)
    # NOTE(review): ``assert`` is stripped under ``python -O``; the two
    # independent ``if`` statements below keep the behaviour in that mode.
    assert has_filepath ^ has_config, "Either `config_filepath` or `config` should be provided"

    self.default_args_config_path = default_args_config_path

    if has_filepath:
        DagFactory._validate_config_filepath(config_filepath=config_filepath)
        loaded_config: Dict[str, Any] = DagFactory._load_config(config_filepath=config_filepath)
        self.config = loaded_config
    if has_config:
        self.config = config

def _global_default_args(self) -> Optional[Dict[str, Any]]:
    """Load the shared ``defaults.yml`` found in ``self.default_args_config_path``.

    :returns: the parsed content of ``defaults.yml``, or ``None`` when the file
        does not exist. ``build_dags`` reads a top-level ``default_args`` key
        from the result, so the file is expected to hold a YAML mapping.
    """
    default_args_yml = Path(self.default_args_config_path) / "defaults.yml"

    if not default_args_yml.exists():
        # No shared defaults configured: callers treat ``None`` as "nothing to merge".
        return None

    # Read as UTF-8 explicitly instead of relying on the platform's default
    # locale encoding, so the same file parses identically on every host.
    with open(default_args_yml, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)

@staticmethod
def _serialise_config_md(dag_name, dag_config, default_config):
# Remove empty task_groups if it exists
Expand Down Expand Up @@ -111,8 +124,15 @@ def get_default_config(self) -> Dict[str, Any]:
def build_dags(self) -> Dict[str, DAG]:
"""Build DAGs using the config file."""
dag_configs: Dict[str, Dict[str, Any]] = self.get_dag_configs()
global_default_args = self._global_default_args()
default_config: Dict[str, Any] = self.get_default_config()

if global_default_args is not None:
if "default_args" in default_config and "default_args" in global_default_args:
default_config = {
"default_args": {**global_default_args["default_args"], **default_config["default_args"]}
}

dags: Dict[str, Any] = {}

for dag_name, dag_config in dag_configs.items():
Expand Down Expand Up @@ -179,6 +199,7 @@ def clean_dags(self, globals: Dict[str, Any]) -> None:
def load_yaml_dags(
globals_dict: Dict[str, Any],
dags_folder: str = airflow_conf.get("core", "dags_folder"),
default_args_config_path: str = airflow_conf.get("core", "dags_folder"),
suffix=None,
):
"""
Expand All @@ -189,8 +210,9 @@ def load_yaml_dags(
interesting to load only a subset by setting a different suffix.
:param globals_dict: The globals() from the file used to generate DAGs
:dags_folder: Path to the folder you want to get recursively scanned
:suffix: file suffix to filter `in` what files to scan for dags
:param dags_folder: Path to the folder you want to get recursively scanned
:param default_args_config_path: The folder path where defaults.yml exists.
:param suffix: file suffix to filter `in` what files to scan for dags
"""
# chain all file suffixes in a single iterator
logging.info("Loading DAGs from %s", dags_folder)
Expand All @@ -203,7 +225,7 @@ def load_yaml_dags(
config_file_abs_path = str(config_file_path.absolute())
logging.info("Loading %s", config_file_abs_path)
try:
factory = DagFactory(config_file_abs_path)
factory = DagFactory(config_file_abs_path, default_args_config_path=default_args_config_path)
factory.generate_dags(globals_dict)
except Exception: # pylint: disable=broad-except
logging.exception("Failed to load dag from %s", config_file_path)
Expand Down
3 changes: 3 additions & 0 deletions dev/dags/defaults.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
default_args:
  start_date: "2025-01-01"
  owner: "global_owner"
4 changes: 4 additions & 0 deletions tests/fixtures/defaults.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
default_args:
  start_date: "2025-01-01"
  owner: "global_owner"
  depends_on_past: true
9 changes: 9 additions & 0 deletions tests/test_dagfactory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
TEST_DAG_FACTORY = os.path.join(here, "fixtures/dag_factory.yml")
INVALID_YAML = os.path.join(here, "fixtures/invalid_yaml.yml")
INVALID_DAG_FACTORY = os.path.join(here, "fixtures/invalid_dag_factory.yml")
DEFAULT_ARGS_CONFIG_ROOT = os.path.join(here, "fixtures/")
DAG_FACTORY_KUBERNETES_POD_OPERATOR = os.path.join(here, "fixtures/dag_factory_kubernetes_pod_operator.yml")
DAG_FACTORY_VARIABLES_AS_ARGUMENTS = os.path.join(here, "fixtures/dag_factory_variables_as_arguments.yml")

Expand Down Expand Up @@ -448,6 +449,14 @@ def test_set_callback_after_loading_config():
td.generate_dags(globals())


def test_build_dag_with_global_default():
    """Check that values from the shared ``defaults.yml`` reach the built tasks."""
    dags = dagfactory.DagFactory(
        config=DAG_FACTORY_CONFIG, default_args_config_path=DEFAULT_ARGS_CONFIG_ROOT
    ).build_dags()

    # Index instead of ``.get`` so a missing DAG fails with a clear KeyError
    # rather than an AttributeError on ``None``.
    first_task = dags["example_dag"].tasks[0]

    # The fixtures' defaults.yml sets ``depends_on_past: true``; compare by
    # identity rather than ``== True`` (pycodestyle E712).
    assert first_task.depends_on_past is True


def test_load_invalid_yaml_logs_error(caplog):
caplog.set_level(logging.ERROR)
load_yaml_dags(
Expand Down

0 comments on commit fd685b2

Please sign in to comment.