Skip to content

Commit

Permalink
feat: improved dag
Browse files Browse the repository at this point in the history
  • Loading branch information
Szymon Szyszkowski committed Jul 31, 2024
1 parent b4615fd commit 996b0d0
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from airflow.operators.python import get_current_context
from airflow.utils.helpers import chain
from ot_orchestration.utils import create_batch_job, create_task_spec
- from ot_orchestration.utils import GCSPath
from airflow.models import TaskInstance
import logging
import time
Expand All @@ -23,8 +22,10 @@ def gwas_catalog_batch_processing() -> None:
@task(task_id="get_manifests_from_preparation", multiple_outputs=True)
def get_batch_task_inputs(
task_instance: TaskInstance | None = None,
- ) -> list[GCSPath]:
+ ) -> dict[str, str]:
"""Get manifests from preparation step."""
+ if task_instance is None:
+ raise ValueError("Task instance is None")
manifest_paths = task_instance.xcom_pull(
task_ids="manifest_preparation.choose_manifest_paths"
)
Expand All @@ -38,9 +39,7 @@ def get_batch_task_inputs(
batch_inputs = get_batch_task_inputs()

@task(task_id="batch_job", multiple_outputs=True)
- def execute_batch_job(
- manifest_paths: list[str], config_path: str
- ) -> CloudBatchSubmitJobOperator:
+ def execute_batch_job(manifest_paths: list[str], config_path: str):
"""Create a harmonisation batch job."""
params = get_step_params("batch_processing")
logging.info("PARAMS: %s", params)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from airflow.decorators import task, task_group
from airflow.providers.google.cloud.operators.gcs import GCSListObjectsOperator
from ot_orchestration.types import Manifest_Object
- from ot_orchestration.utils import GCSIOManager, get_step_params, get_full_config
+ from ot_orchestration.utils import IOManager, get_step_params, get_full_config
from airflow.models.baseoperator import chain
from ot_orchestration.utils.manifest import extract_study_id_from_path
from airflow.utils.edgemodifier import Label
Expand Down Expand Up @@ -57,6 +57,8 @@ def get_new_sumstat_paths(
def collect_sumstats_and_generate_new_manifests(
ti: TaskInstance | None = None,
) -> list[Manifest_Object]:
+ if ti is None:
+ raise ValueError("Task instance is None")
task_id: str = ti.xcom_pull(task_ids="manifest_preparation.get_execution_mode")
logging.info("TASK ID: %s", task_id)
new_sumstats = ti.xcom_pull(task_ids=task_id)
Expand Down Expand Up @@ -97,6 +99,8 @@ def amend_curation_metadata(new_manifests: list[Manifest_Object]):
params = get_step_params("manifest_preparation")
logging.info("USING FOLLOWING PARAMS: %s", params)
curation_path = params["manual_curation_manifest"]
+ if not isinstance(curation_path, str):
+ raise ValueError("Curation path is not a string")
logging.info("CURATING MANIFESTS WITH: %s", curation_path)
curation_df = pd.read_csv(curation_path, sep="\t").drop(
columns=["publicationTitle", "traitFromSource", "qualityControl"]
Expand All @@ -116,18 +120,20 @@ def amend_curation_metadata(new_manifests: list[Manifest_Object]):
def read_manifests(manifest_paths: list[str]) -> list[Manifest_Object]:
"""Read manifests."""
manifest_paths = [f"gs://{path}" for path in manifest_paths]
- return GCSIOManager().load_many(manifest_paths)
+ return IOManager().load_many(manifest_paths)


@task(task_id="save_config")
def save_config(task_instance: TaskInstance | None = None) -> str:
"""Save configuration for batch processing."""
+ if task_instance is None:
+ raise ValueError("Task instance is None")
run_id = task_instance.run_id
params = get_step_params("manifest_preparation")
full_config = get_full_config().serialize()
config_path = f"gs://{params['staging_bucket']}/{params['staging_prefix']}/{run_id}/config.yaml"
logging.info("DUMPING CONFIG TO THE FOLLOWING PATH: %s", config_path)
- GCSIOManager().dump(gcs_path=config_path, data=full_config)
+ IOManager().resolve(config_path).dump(full_config)
return config_path


Expand All @@ -137,6 +143,8 @@ def save_config(task_instance: TaskInstance | None = None) -> str:
)
def choose_manifest_paths(ti: TaskInstance | None = None) -> list[str]:
"""Choose manifests to pass to the next."""
+ if ti is None:
+ raise ValueError("Task instance is None")
task_id: str = ti.xcom_pull(task_ids="manifest_preparation.get_execution_mode")
logging.info("TASK ID: %s", task_id)
if not task_id.endswith("read_manifests"):
Expand All @@ -150,7 +158,7 @@ def save_manifests(manifests: list[Manifest_Object]) -> list[Manifest_Object]:
"""Write manifests to persistent storage."""
manifest_paths = [manifest["manifestPath"] for manifest in manifests]
logging.info("MANIFEST PATHS: %s", manifest_paths)
- GCSIOManager().dump_many(manifests, manifest_paths)
+ IOManager().dump_many(manifests, manifest_paths)
return manifests


Expand All @@ -164,7 +172,6 @@ def exit_when_no_new_sumstats(new_sumstats: dict[str, str]) -> bool:
@task_group(group_id=TASK_GROUP_ID)
def gwas_catalog_manifest_preparation():
"""Prepare initial manifest."""

fetch_existing_manifests = GCSListObjectsOperator(
task_id="list_existing_manifests",
bucket="{{ params.steps.manifest_preparation.staging_bucket }}",
Expand Down

0 comments on commit 996b0d0

Please sign in to comment.