diff --git a/src/ot_orchestration/cli/process_in_batch.py b/src/ot_orchestration/cli/process_in_batch.py
index daead87..a66ff93 100644
--- a/src/ot_orchestration/cli/process_in_batch.py
+++ b/src/ot_orchestration/cli/process_in_batch.py
@@ -26,28 +26,104 @@ def harmonise(manifest: Manifest_Object) -> Manifest_Object:
         "+step.session.extended_spark_conf={spark.kryoserializer.buffer.max:'500m'}",
         "+step.session.extended_spark_conf={spark.driver.maxResultSize:'5g'}",
     ]
+    if GCSIOManager().exists(harmonised_path):
+        logging.info("Harmonisation result exists for %s. Skipping", study_id)
+        manifest["passHarmonisation"] = True
+        return manifest
+
     result = subprocess.run(args=command, capture_output=True)
     if result.returncode != 0:
         logging.error("Harmonisation for study %s failed!", study_id)
-        logging.error(result.stderr.decode())
+        error_msg = result.stderr.decode()
+        logging.error(error_msg)
         manifest["passHarmonisation"] = False
         logging.info("Dumping manifest to %s", manifest["manifestPath"])
         GCSIOManager().dump(manifest["manifestPath"], manifest)
         exit(1)
-    else:
-        logging.info("Harmonisation for study %s completed successfully!", study_id)
-        manifest["passHarmonisation"] = True
+
+    logging.info("Harmonisation for study %s succeeded!", study_id)
+    manifest["passHarmonisation"] = True
     return manifest


 def qc(manifest: Manifest_Object) -> Manifest_Object:
     """Run QC."""
+    harmonised_path = manifest["harmonisedPath"]
+    qc_path = manifest["qcPath"]
+    study_id = manifest["studyId"]
+    command = [
+        "poetry",
+        "run",
+        "gentropy",
+        "step=summary_statistics_qc",
+        f'step.gwas_path="{harmonised_path}"',
+        f'step.output_path="{qc_path}"',
+        f'step.study_id="{study_id}"',
+        "+step.session.extended_spark_conf={spark.jars:'https://storage.googleapis.com/hadoop-lib/gcs/gcs-connector-hadoop3-latest.jar'}",
+        "+step.session.extended_spark_conf={spark.dynamicAllocation.enabled:'false'}",
+        "+step.session.extended_spark_conf={spark.driver.memory:'30g'}",
+        "+step.session.extended_spark_conf={spark.kryoserializer.buffer.max:'500m'}",
+        "+step.session.extended_spark_conf={spark.driver.maxResultSize:'5g'}",
+    ]
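+    # Same skip-if-output-exists guard as in harmonise above: it keeps the
+    # step idempotent when the batch job is retried.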
Skipping", study_id) + manifest["passQC"] = True + return manifest + + result = subprocess.run(args=command, capture_output=True) + if result.returncode != 0: + logging.error("QC for study %s failed!", study_id) + error_msg = result.stderr.decode() + logging.error(error_msg) + manifest["passQC"] = False + logging.info("Dumping manifest to %s", manifest["manifestPath"]) + GCSIOManager().dump(manifest["manifestPath"], manifest) + exit(1) + + logging.info("QC for study %s succeded!", study_id) + manifest["passQC"] = True return manifest +def qc_consolidation(manifest: Manifest_Object) -> Manifest_Object: + pass + + def clumping(manifest: Manifest_Object) -> Manifest_Object: """Run Clumping.""" + harmonised_path = manifest["harmonisedPath"] + clumping_path = manifest["clumpingPath"] + study_id = manifest["studyId"] + command = [ + "poetry", + "run", + "gentropy", + "step=clumping", + f'step.gwas_path="{harmonised_path}"', + f'step.output_path="{clumping_path}"', + f'step.study_id="{study_id}"', + "+step.session.extended_spark_conf={spark.jars:'https://storage.googleapis.com/hadoop-lib/gcs/gcs-connector-hadoop3-latest.jar'}", + "+step.session.extended_spark_conf={spark.dynamicAllocation.enabled:'false'}", + "+step.session.extended_spark_conf={spark.driver.memory:'30g'}", + "+step.session.extended_spark_conf={spark.kryoserializer.buffer.max:'500m'}", + "+step.session.extended_spark_conf={spark.driver.maxResultSize:'5g'}", + ] + if GCSIOManager().exists(clumping_path): + logging.info("Clumping result exists for %s. Skipping", study_id) + manifest["passClumping"] = True + return manifest + + result = subprocess.run(args=command, capture_output=True) + if result.returncode != 0: + logging.error("Clumping for study %s failed!", study_id) + error_msg = result.stderr.decode() + logging.error(error_msg) + manifest["passClumping"] = False + logging.info("Dumping manifest to %s", manifest["manifestPath"]) + GCSIOManager().dump(manifest["manifestPath"], manifest) + exit(1) return manifest diff --git a/src/ot_orchestration/dags/branching.py b/src/ot_orchestration/dags/branching.py new file mode 100644 index 0000000..4918461 --- /dev/null +++ b/src/ot_orchestration/dags/branching.py @@ -0,0 +1,80 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Example DAG demonstrating the usage of the BranchPythonOperator.""" + +import random +from datetime import datetime + +from airflow import DAG +from airflow.decorators import task +from airflow.operators.empty import EmptyOperator +from airflow.utils.edgemodifier import Label +from airflow.utils.trigger_rule import TriggerRule + +with DAG( + dag_id="example_branch_python_operator_decorator", + start_date=datetime(2021, 1, 1), + catchup=False, + schedule_interval="@daily", + tags=["example", "example2"], +) as dag: + run_this_first = EmptyOperator( + task_id="run_this_first", + ) + + options = ["branch_a", "branch_b", "branch_c", "branch_d"] + + @task.branch(task_id="branching") + def random_choice(): + return random.choice(options) + + random_choice_instance = random_choice() + + run_this_first >> random_choice_instance + + join = EmptyOperator( + task_id="join", + trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS, + ) + + for option in options: + if option == "branch_a": + empty_follow_3 = EmptyOperator( + task_id="follow_" + option, + ) + random_choice_instance >> Label(option) >> empty_follow_3 >> join + if option == "branch_b": + empty_follow_2 = EmptyOperator( + task_id="follow_" + option, + ) + random_choice_instance >> Label(option) >> empty_follow_2 >> join + if option == "branch_c": + random_choice_instance >> Label(option) >> join + + if option == "branch_d": + t = EmptyOperator( + task_id=option, + ) + + empty_follow = EmptyOperator( + task_id="follow_" + option, + ) + + # Label is optional here, but it can help identify more complex branches + random_choice_instance >> Label(option) >> t >> empty_follow diff --git a/src/ot_orchestration/dags/gwas_catalog_dag.py b/src/ot_orchestration/dags/gwas_catalog_dag.py index 9fd4566..7f19a64 100644 --- a/src/ot_orchestration/dags/gwas_catalog_dag.py +++ b/src/ot_orchestration/dags/gwas_catalog_dag.py @@ -12,20 +12,19 @@ from airflow.utils.helpers import chain +from ot_orchestration.utils.common import shared_dag_kwargs -RUN_DATE = datetime.today() - config_path = "/opt/airflow/config/config.yaml" config = QRCP.from_file(config_path).serialize() -@dag(start_date=RUN_DATE, dag_id="GWAS_Catalog_dag", schedule="@once", params=config) +@dag(dag_id="GWAS_Catalog_dag", params=config, **shared_dag_kwargs) def gwas_catalog_dag() -> None: """GWAS catalog DAG.""" chain( gwas_catalog_manifest_preparation(), - gwas_catalog_batch_processing(), + # gwas_catalog_batch_processing(), ) diff --git a/src/ot_orchestration/task_groups/gwas_catalog/manifest_preparation.py b/src/ot_orchestration/task_groups/gwas_catalog/manifest_preparation.py index 1419a00..5b0ef61 100644 --- a/src/ot_orchestration/task_groups/gwas_catalog/manifest_preparation.py +++ b/src/ot_orchestration/task_groups/gwas_catalog/manifest_preparation.py @@ -4,12 +4,13 @@ from airflow.providers.google.cloud.operators.gcs import GCSListObjectsOperator from ot_orchestration.types import Manifest_Object from ot_orchestration.utils import GCSIOManager, get_step_params, get_full_config -from airflow.utils.helpers import chain +from airflow.models.baseoperator import chain from ot_orchestration.utils.manifest import extract_study_id_from_path import logging import pandas as pd from airflow.models.taskinstance import TaskInstance - +from airflow.utils.trigger_rule import TriggerRule +from airflow.operators.empty import EmptyOperator FILTER_FILE = "/opt/airflow/config/filter.csv" TASK_GROUP_ID = "manifest_preparation" @@ -18,8 +19,6 @@ @task_group(group_id=TASK_GROUP_ID) def 
-    options = ["FORCE", "RESUME", "CONTINUE"]
-    print(options)
     existing_manifest_paths = GCSListObjectsOperator(
         task_id="list_existing_manifests",
         bucket="{{ params.steps.manifest_preparation.staging_bucket }}",
@@ -34,93 +33,32 @@ def gwas_catalog_manifest_preparation():
         match_glob="**/*.h.tsv.gz",
     ).output

-    @task(task_id="get_new_sumstats")
-    def get_new_sumstats(
-        raw_sumstats_paths: list[str], existing_manifest_paths: list[str]
-    ) -> dict[str, str]:
-        """Get new sumstats."""
-        processed = {extract_study_id_from_path(p) for p in existing_manifest_paths}
-        logging.info("ALREADY PROCESSED STUDIES: %s", len(processed))
-        all = {extract_study_id_from_path(p): p for p in raw_sumstats_paths}
-        logging.info("ALL STUDIES (INCLUDING NOT PROCESSED): %s", len(all))
-        new = {key: val for key, val in all.items() if key not in processed}
-        logging.info("NEW STUDIES UNPROCESSED: %s", len(new))
-        return new
+    @task.branch(task_id="get_execution_mode")
+    def get_execution_mode():
+        """Get execution mode."""
+        mode = get_full_config().config.mode.lower()
+        return f"manifest_preparation.{mode}"

-    new_sumstats = get_new_sumstats(raw_sumstats_paths, existing_manifest_paths)
+    execution_mode = get_execution_mode()
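+    # @task.branch runs the callable and follows only the task whose id is
+    # returned; the other mode branches are skipped, which is why downstream
+    # tasks need a trigger rule that tolerates skipped upstreams.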
-    @task(task_id="generate_new_manifests")
-    def generate_new_manifests(new_sumstats: dict[str, str]) -> list[Manifest_Object]:
-        """Task to generate manifest files for the new studies."""
-        params = get_step_params("manifest_preparation")
-        # params from the configuration
-        logging.info("USING FOLLOWING PARAMS: %s", params)
-        raw_sumstat_bucket = params["raw_sumstats_bucket"]
-        staging_bucket = params["staging_bucket"]
-        staging_prefix = params["staging_prefix"]
-        harmonised_prefix = params["harmonised_result_path_prefix"]
-        qc_prefix = params["qc_result_path_prefix"]
-
-        # prepare manifests for the new studies
-        manifests = []
-        for study_id, raw_sumstat_path in new_sumstats.items():
-            staging_path = f"{staging_bucket}/{staging_prefix}/{study_id}"
-            partial_manifest = {
-                "studyId": study_id,
-                "rawPath": f"gs://{raw_sumstat_bucket}/{raw_sumstat_path}",
-                "manifestPath": f"gs://{staging_path}/manifest.json",
-                "harmonisedPath": f"gs://{staging_path}/{harmonised_prefix}",
-                "qcPath": f"gs://{staging_path}/{qc_prefix}",
-                "passHarmonisation": False,
-                "passQC": False,
-                "passClumping": False,
-            }
-            manifests.append(partial_manifest)
-            logging.info(partial_manifest)
-        return manifests
-
-    new_manifests = generate_new_manifests(new_sumstats)
-
-    @task(task_id="amend_curation_metadata")
-    def amend_curation_metadata(new_manifests: list[Manifest_Object]):
-        """Read curation file and add it to the partial manifests."""
-        params = get_step_params("manifest_preparation")
-        logging.info("USING FOLLOWING PARAMS: %s", params)
-        curation_path = params["manual_curation_manifest"]
-        logging.info("CURATING MANIFESTS WITH: %s", curation_path)
-        curation_df = pd.read_csv(curation_path, sep="\t").drop(
-            columns=["publicationTitle", "traitFromSource", "qualityControl"]
-        )
-        new_manifests = (
-            pd.DataFrame.from_records(new_manifests)
-            .merge(curation_df, how="left", on="studyId")
-            .replace({float("nan"): None})
-            .to_dict("records")
-        )
-        for new_manifest in new_manifests:
-            logging.info("NEW MANIFESTS WITH CURATION METADATA: %s", new_manifest)
-        return new_manifests
-
-    new_manifests_with_curation = amend_curation_metadata(new_manifests)
-
-    @task(task_id="save_manifests")
-    def save_manifests(manifests: list[Manifest_Object]) -> list[Manifest_Object]:
-        """Write manifests to persistant storage."""
-        manifest_paths = [manifest["manifestPath"] for manifest in manifests]
-        logging.info("MANIFEST PATHS: %s", manifest_paths)
-        GCSIOManager().dump_many(manifests, manifest_paths)
-        return manifests
-
-    saved_manifests = save_manifests(new_manifests_with_curation)
-
-    @task(task_id="choose_manifest_paths")
-    def choose_manifest_paths(manifests: list[Manifest_Object]) -> list[str]:
+    @task(
+        task_id="choose_manifest_paths",
+        trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS,
+    )
+    def choose_manifest_paths(ti: TaskInstance | None = None) -> list[str]:
         """Choose manifests to pass to the next."""
+        branch_name = ti.xcom_pull(task_ids="manifest_preparation.get_execution_mode")
+        logging.info("BRANCH NAME: %s", branch_name)
+        manifest_generation_task = f"{branch_name}.save_manifests"
+        if branch_name == "manifest_preparation.resume":
+            manifest_generation_task = "manifest_preparation.read_manifests"
+        logging.info("MANIFEST GENERATION TASK: %s", manifest_generation_task)
+        manifests = ti.xcom_pull(task_ids=manifest_generation_task)
         return [
             manifest["manifestPath"] for manifest in manifests if manifest["isCurated"]
         ]

-    filtered_manifests = choose_manifest_paths(saved_manifests)
+    filtered_manifests = choose_manifest_paths()

     @task(task_id="save_config")
     def save_config(task_instance: TaskInstance | None = None) -> str:
@@ -135,14 +73,146 @@ def save_config(task_instance: TaskInstance | None = None) -> str:

     saved_config_path = save_config()

-    chain(
-        [existing_manifest_paths, raw_sumstats_paths],
-        new_manifests,
-        new_manifests_with_curation,
-        saved_manifests,
-        filtered_manifests,
-        saved_config_path,
-    )
+    @task(task_id="get_all_sumstats")
+    def get_all_sumstats(
+        raw_sumstats_paths: list[str],
+    ) -> dict[str, str]:
+        """Get all sumstats."""
+        return {extract_study_id_from_path(p): p for p in raw_sumstats_paths}
+
+    @task(task_id="get_new_sumstats")
+    def get_new_sumstats(
+        raw_sumstats_paths: list[str],
+        existing_manifest_paths: list[str],
+    ) -> dict[str, str]:
+        """Get new sumstats."""
+        processed = {extract_study_id_from_path(p) for p in existing_manifest_paths}
+        logging.info("ALREADY PROCESSED STUDIES: %s", len(processed))
+        all_studies = {extract_study_id_from_path(p): p for p in raw_sumstats_paths}
+        logging.info("ALL STUDIES (INCLUDING NOT PROCESSED): %s", len(all_studies))
+        new = {key: val for key, val in all_studies.items() if key not in processed}
+        logging.info("NEW STUDIES UNPROCESSED: %s", len(new))
+        return new
+
+    @task(task_id="read_manifests")
+    def read_manifests(manifest_paths: list[str]) -> list[Manifest_Object]:
+        """Read manifests."""
+        manifest_paths = [f"gs://{path}" for path in manifest_paths]
+        return GCSIOManager().load_many(manifest_paths)
+
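+    # One EmptyOperator per mode serves as the branch entry point; its
+    # task_id (prefixed with the group id) is what get_execution_mode
+    # returns.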
how="left", on="studyId") + .replace({float("nan"): None}) + .to_dict("records") + ) + for new_manifest in new_manifests: + logging.info("NEW MANIFESTS WITH CURATION METADATA: %s", new_manifest) + return new_manifests + + @task(task_id=f"{mode}.genereate_new_manifests") + def generate_new_manifests( + new_sumstats: dict[str, str], + ) -> list[Manifest_Object]: + """Task to generate manifest files for the new studies.""" + params = get_step_params("manifest_preparation") + # params from the configuration + logging.info("USING FOLLOWING PARAMS: %s", params) + raw_sumstat_bucket = params["raw_sumstats_bucket"] + staging_bucket = params["staging_bucket"] + staging_prefix = params["staging_prefix"] + harmonised_prefix = params["harmonised_result_path_prefix"] + qc_prefix = params["qc_result_path_prefix"] + + # prepare manifests for the new studies + manifests = [] + for study_id, raw_sumstat_path in new_sumstats.items(): + staging_path = f"{staging_bucket}/{staging_prefix}/{study_id}" + partial_manifest = { + "studyId": study_id, + "rawPath": f"gs://{raw_sumstat_bucket}/{raw_sumstat_path}", + "manifestPath": f"gs://{staging_path}/manifest.json", + "harmonisedPath": f"gs://{staging_path}/{harmonised_prefix}", + "qcPath": f"gs://{staging_path}/{qc_prefix}", + "passHarmonisation": False, + "passQC": False, + "passClumping": False, + } + manifests.append(partial_manifest) + logging.info(partial_manifest) + return manifests + + @task(task_id=f"{mode}.save_manifests") + def save_manifests(manifests: list[Manifest_Object]) -> list[Manifest_Object]: + """Write manifests to persistant storage.""" + manifest_paths = [manifest["manifestPath"] for manifest in manifests] + logging.info("MANIFEST PATHS: %s", manifest_paths) + GCSIOManager().dump_many(manifests, manifest_paths) + return manifests + + if mode == "resume": + manifests = read_manifests(existing_manifest_paths) + chain( + execution_mode, + branching_start, + manifests, + filtered_manifests, + ) + + if mode == "force": + new_sumstats = get_all_sumstats(raw_sumstats_paths) + new_manifests = generate_new_manifests(new_sumstats) + new_manifests_with_curation = amend_curation_metadata(new_manifests) + manifests = save_manifests(new_manifests_with_curation) + chain( + execution_mode, + branching_start, + new_sumstats, + new_manifests, + new_manifests_with_curation, + manifests, + filtered_manifests, + ) + if mode == "continue": + + @task.short_circuit(task_id=f"{mode}.exit_when_no_new_sumstats") + def exit_when_no_new_sumstats(new_sumstats: dict[str, str]) -> bool: + """Exit when no new sumstats.""" + logging.info("NEW SUMSTATS: %s", new_sumstats) + return new_sumstats + + new_sumstats = get_new_sumstats(raw_sumstats_paths, existing_manifest_paths) + new_manifests = generate_new_manifests(new_sumstats) + new_manifests_with_curation = amend_curation_metadata(new_manifests) + manifests = save_manifests(new_manifests_with_curation) + no_new_sumstats = exit_when_no_new_sumstats(new_sumstats) + chain( + execution_mode, + branching_start, + new_sumstats, + no_new_sumstats, + new_manifests, + new_manifests_with_curation, + manifests, + filtered_manifests, + ) + + chain(filtered_manifests, saved_config_path) __all__ = ["gwas_catalog_manifest_preparation"] diff --git a/src/ot_orchestration/utils/gcs_path.py b/src/ot_orchestration/utils/gcs_path.py index 511d44c..732b44e 100644 --- a/src/ot_orchestration/utils/gcs_path.py +++ b/src/ot_orchestration/utils/gcs_path.py @@ -52,6 +52,14 @@ def __init__(self): self.client._http.mount("https://", adapter) 
+    def exists(self, gcs_path: str) -> bool:
+        """Check if file exists in Google Cloud Storage."""
+        gcs_path = GCSPath(gcs_path)
+        bucket_name, file_name = gcs_path.split()
+        bucket = Bucket(client=self.client, name=bucket_name)
+        blob = Blob(name=file_name, bucket=bucket)
+        return blob.exists()
+
     def dump(self, gcs_path: str, data: dict) -> None:
         """Write data to Google Cloud Storage."""
         gcs_path = GCSPath(gcs_path)