diff --git a/src/ot_orchestration/task_groups/gwas_catalog/batch_processing.py b/src/ot_orchestration/task_groups/gwas_catalog/batch_processing.py index 6449c3d..72c82b1 100644 --- a/src/ot_orchestration/task_groups/gwas_catalog/batch_processing.py +++ b/src/ot_orchestration/task_groups/gwas_catalog/batch_processing.py @@ -10,7 +10,6 @@ from airflow.operators.python import get_current_context from airflow.utils.helpers import chain from ot_orchestration.utils import create_batch_job, create_task_spec -from ot_orchestration.utils import GCSPath from airflow.models import TaskInstance import logging import time @@ -23,8 +22,10 @@ def gwas_catalog_batch_processing() -> None: @task(task_id="get_manifests_from_preparation", multiple_outputs=True) def get_batch_task_inputs( task_instance: TaskInstance | None = None, - ) -> list[GCSPath]: + ) -> dict[str, str]: """Get manifests from preparation step.""" + if task_instance is None: + raise ValueError("Task instance is None") manifest_paths = task_instance.xcom_pull( task_ids="manifest_preparation.choose_manifest_paths" ) @@ -38,9 +39,7 @@ def get_batch_task_inputs( batch_inputs = get_batch_task_inputs() @task(task_id="batch_job", multiple_outputs=True) - def execute_batch_job( - manifest_paths: list[str], config_path: str - ) -> CloudBatchSubmitJobOperator: + def execute_batch_job(manifest_paths: list[str], config_path: str): """Create a harmonisation batch job.""" params = get_step_params("batch_processing") logging.info("PARAMS: %s", params) diff --git a/src/ot_orchestration/task_groups/gwas_catalog/manifest_preparation.py b/src/ot_orchestration/task_groups/gwas_catalog/manifest_preparation.py index 58b9d2c..a20107f 100644 --- a/src/ot_orchestration/task_groups/gwas_catalog/manifest_preparation.py +++ b/src/ot_orchestration/task_groups/gwas_catalog/manifest_preparation.py @@ -3,7 +3,7 @@ from airflow.decorators import task, task_group from airflow.providers.google.cloud.operators.gcs import GCSListObjectsOperator from ot_orchestration.types import Manifest_Object -from ot_orchestration.utils import GCSIOManager, get_step_params, get_full_config +from ot_orchestration.utils import IOManager, get_step_params, get_full_config from airflow.models.baseoperator import chain from ot_orchestration.utils.manifest import extract_study_id_from_path from airflow.utils.edgemodifier import Label @@ -57,6 +57,8 @@ def get_new_sumstat_paths( def collect_sumstats_and_generate_new_manifests( ti: TaskInstance | None = None, ) -> list[Manifest_Object]: + if ti is None: + raise ValueError("Task instance is None") task_id: str = ti.xcom_pull(task_ids="manifest_preparation.get_execution_mode") logging.info("TASK ID: %s", task_id) new_sumstats = ti.xcom_pull(task_ids=task_id) @@ -97,6 +99,8 @@ def amend_curation_metadata(new_manifests: list[Manifest_Object]): params = get_step_params("manifest_preparation") logging.info("USING FOLLOWING PARAMS: %s", params) curation_path = params["manual_curation_manifest"] + if not isinstance(curation_path, str): + raise ValueError("Curation path is not a string") logging.info("CURATING MANIFESTS WITH: %s", curation_path) curation_df = pd.read_csv(curation_path, sep="\t").drop( columns=["publicationTitle", "traitFromSource", "qualityControl"] @@ -116,18 +120,20 @@ def amend_curation_metadata(new_manifests: list[Manifest_Object]): def read_manifests(manifest_paths: list[str]) -> list[Manifest_Object]: """Read manifests.""" manifest_paths = [f"gs://{path}" for path in manifest_paths] - return GCSIOManager().load_many(manifest_paths) + return IOManager().load_many(manifest_paths) @task(task_id="save_config") def save_config(task_instance: TaskInstance | None = None) -> str: """Save configuration for batch processing.""" + if task_instance is None: + raise ValueError("Task instance is None") run_id = task_instance.run_id params = get_step_params("manifest_preparation") full_config = get_full_config().serialize() config_path = f"gs://{params['staging_bucket']}/{params['staging_prefix']}/{run_id}/config.yaml" logging.info("DUMPING CONFIG TO THE FOLLOWING PATH: %s", config_path) - GCSIOManager().dump(gcs_path=config_path, data=full_config) + IOManager().resolve(config_path).dump(full_config) return config_path @@ -137,6 +143,8 @@ def save_config(task_instance: TaskInstance | None = None) -> str: ) def choose_manifest_paths(ti: TaskInstance | None = None) -> list[str]: """Choose manifests to pass to the next.""" + if ti is None: + raise ValueError("Task instance is None") task_id: str = ti.xcom_pull(task_ids="manifest_preparation.get_execution_mode") logging.info("TASK ID: %s", task_id) if not task_id.endswith("read_manifests"): @@ -150,7 +158,7 @@ def save_manifests(manifests: list[Manifest_Object]) -> list[Manifest_Object]: """Write manifests to persistant storage.""" manifest_paths = [manifest["manifestPath"] for manifest in manifests] logging.info("MANIFEST PATHS: %s", manifest_paths) - GCSIOManager().dump_many(manifests, manifest_paths) + IOManager().dump_many(manifests, manifest_paths) return manifests @@ -164,7 +172,6 @@ def exit_when_no_new_sumstats(new_sumstats: dict[str, str]) -> bool: @task_group(group_id=TASK_GROUP_ID) def gwas_catalog_manifest_preparation(): """Prepare initial manifest.""" - fetch_existing_manifests = GCSListObjectsOperator( task_id="list_existing_manifests", bucket="{{ params.steps.manifest_preparation.staging_bucket }}",