Fix KGX export #7

Merged
merged 10 commits on Jun 21, 2024
95 changes: 84 additions & 11 deletions DAG/targeted-export.py
@@ -4,10 +4,11 @@
from airflow.operators.python_operator import PythonOperator
from airflow.operators.bash_operator import BashOperator
from datetime import datetime, timedelta
from airflow import models
from airflow import models, XComArg
from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import (
KubernetesPodOperator,
)

from kubernetes.client import models as k8s_models

MYSQL_DATABASE_PASSWORD=os.environ.get('MYSQL_DATABASE_PASSWORD')
@@ -17,17 +18,23 @@
UNI_BUCKET = os.environ.get('UNI_BUCKET')
TMP_BUCKET = os.environ.get('TMP_BUCKET')
FAILURE_EMAIL = os.environ.get('FAILURE_EMAIL')
START_DATE=datetime(2024, 3, 29, 0, 0)
CHUNK_SIZE = '100000'
EVIDENCE_LIMIT = '5'
STEP_SIZE = 75000
START_DATE=datetime(2024, 6, 20, 0, 0)
EVIDENCE_LIMIT = 5
# STEP_SIZE = 75000 ### STEP_SIZE doesn't seem to be used
ASSERTION_LIMIT = 100000 # This is the default in Edgar's original implementation, so keeping it for now
CHUNK_SIZE = '25000'

# # for testing
# ASSERTION_LIMIT = 25000
# CHUNK_SIZE = 5000



default_args = {
'owner': 'airflow',
'depends_on_past': False,
'start_date': START_DATE,
'schedule_interval': '0 23 * * 6',
# 'schedule_interval': '0 23 * * 6', # KubernetesPodOperator did not like this argument
'email': [FAILURE_EMAIL],
'email_on_failure': True,
'email_on_retry': True,
@@ -70,8 +77,39 @@ def output_operations(**kwargs):
with open(kwargs['output_filename'], 'w') as outfile:
x = outfile.write(json.dumps(operations_dict))

# TODO: we are able to read the assertion count after querying for it, but it is not clear how to retrieve it from XComArg as an integer for use in generate_edge_export_arguments
def read_assertion_count_from_file(ti, **kwargs):
file_path = kwargs['file_path']
with open(file_path, "r") as count_file:
assertion_count = int(count_file.readline().strip())
print(f"===================== ASSERTION COUNT {assertion_count}")
ti.xcom_push(key='assertion_count', value=assertion_count)

# TODO: get the assertion count dynamically. It is pushed to XCom above, but I have not figured out how to retrieve it here as an integer
def get_assertion_count():
# return 50000
return 3261384
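# A possible direction for the two TODOs above (an untested sketch, not part of this PR):
# with Airflow 2.3+ dynamic task mapping, the argument lists could be built inside a task at
# runtime so the count pushed to XCom is actually used instead of the hardcoded value. The
# total_count parameter added to generate_edge_export_arguments here is hypothetical.
#
#   from airflow.decorators import task
#
#   @task
#   def build_edge_export_arguments(ti=None):
#       count = ti.xcom_pull(task_ids='read_assertion_count', key='assertion_count')
#       return generate_edge_export_arguments(ASSERTION_LIMIT, CHUNK_SIZE, EVIDENCE_LIMIT,
#                                             TMP_BUCKET, total_count=count)
#
#   export_edges = KubernetesPodOperator.partial(...).expand(
#       arguments=build_edge_export_arguments())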

def generate_edge_export_arguments(assertion_limit, chunk_size, evidence_limit, bucket):
arguments_list = []

with models.DAG(dag_id='targeted-export', default_args=default_args, catchup=False) as dag:
total_assertion_count = get_assertion_count()
# total_assertion_count = total_assertion_count# get_assertion_count()
incremental_assertion_count = 0

while incremental_assertion_count < int(total_assertion_count):
arguments_list.append(['-t', 'edges',
'-b', bucket,
'--chunk_size', str(chunk_size),
'--limit', str(evidence_limit),
'--assertion_offset', str(incremental_assertion_count),
'--assertion_limit', str(assertion_limit)
])
incremental_assertion_count += assertion_limit

return arguments_list
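# With the current constants (3,261,384 total assertions, ASSERTION_LIMIT = 100000) this
# helper returns 33 argument lists, with --assertion_offset stepping 0, 100000, ..., 3200000.
# Example element:
#   ['-t', 'edges', '-b', TMP_BUCKET, '--chunk_size', '25000', '--limit', '5',
#    '--assertion_offset', '0', '--assertion_limit', '100000']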

with models.DAG(dag_id='targeted-export', schedule_interval= '0 23 * * 6', default_args=default_args, catchup=False) as dag:
filename_list = []
export_task_list = []

@@ -88,14 +126,37 @@ def output_operations(**kwargs):
'MYSQL_DATABASE_INSTANCE': MYSQL_DATABASE_INSTANCE,
},
image='gcr.io/translator-text-workflow-dev/kgx-export:latest')
export_edges = KubernetesPodOperator(

export_assertion_count = KubernetesPodOperator(
task_id='count-assertions',
name='count-assertions',
config_file="/home/airflow/composer_kube_config",
namespace='composer-user-workloads',
image_pull_policy='Always',
arguments=['-t', 'count', '-b', TMP_BUCKET],
env_vars={
'MYSQL_DATABASE_PASSWORD': MYSQL_DATABASE_PASSWORD,
'MYSQL_DATABASE_USER': MYSQL_DATABASE_USER,
'MYSQL_DATABASE_INSTANCE': MYSQL_DATABASE_INSTANCE,
},
image='gcr.io/translator-text-workflow-dev/kgx-export:latest')

read_assertion_count = PythonOperator(
task_id='read_assertion_count',
python_callable=read_assertion_count_from_file,
provide_context=True,
op_kwargs={'file_path': '/home/airflow/gcs/data/kgx-build/assertion.count'},
dag=dag)


export_edges = KubernetesPodOperator.partial(
task_id=f'targeted-edges',
name=f'edge-export',
config_file="/home/airflow/composer_kube_config",
namespace='composer-user-workloads',
image_pull_policy='Always',
startup_timeout_seconds=1200,
arguments=['-t', 'edges', '-b', TMP_BUCKET, '--chunk_size', CHUNK_SIZE, '--limit', EVIDENCE_LIMIT],
# arguments=['-t', 'edges', '-b', TMP_BUCKET, '--chunk_size', CHUNK_SIZE, '--limit', EVIDENCE_LIMIT],
env_vars={
'MYSQL_DATABASE_PASSWORD': MYSQL_DATABASE_PASSWORD,
'MYSQL_DATABASE_USER': MYSQL_DATABASE_USER,
@@ -106,7 +167,12 @@
),
retries=1,
image='gcr.io/translator-text-workflow-dev/kgx-export:latest'
)
).expand(arguments=generate_edge_export_arguments(ASSERTION_LIMIT, CHUNK_SIZE, EVIDENCE_LIMIT, TMP_BUCKET))
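# .partial() holds the kwargs shared by every mapped task, and .expand(arguments=...) creates
# one mapped KubernetesPodOperator per argument list from generate_edge_export_arguments
# (33 mapped pods with the current constants). This relies on Airflow 2.3+ dynamic task mapping.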

cat_edge_files = BashOperator(
task_id='targeted-cat-edge-files',
bash_command=f"cd /home/airflow/gcs/data/kgx-build/ && cat edges*.tsv > edges.tsv && cp edges.tsv /home/airflow/gcs/data/kgx-export/")

generate_metadata = KubernetesPodOperator(
task_id='targeted-metadata',
name='targeted-metadata',
@@ -115,18 +181,25 @@
image_pull_policy='Always',
arguments=['-t', 'metadata', '-b', TMP_BUCKET],
image='gcr.io/translator-text-workflow-dev/kgx-export:latest')

generate_bte_operations = PythonOperator(
task_id='generate_bte_operations',
python_callable=output_operations,
provide_context=True,
op_kwargs={'edges_filename': '/home/airflow/gcs/data/kgx-export/edges.tsv',
'output_filename': '/home/airflow/gcs/data/kgx-export/operations.json'},
dag=dag)

compress_edge_file = BashOperator(
task_id='targeted-compress',
bash_command=f"cd /home/airflow/gcs/data/kgx-export/ && gzip -f edges.tsv")

publish_files = BashOperator(
task_id='targeted-publish',
bash_command=f"gsutil cp gs://{TMP_BUCKET}/data/kgx-export/* gs://{UNI_BUCKET}/kgx/UniProt/")

clean_up = BashOperator(
task_id='clean-up',
bash_command=f"cd /home/airflow/gcs/data/kgx-build/ && rm *.tsv")

export_nodes >> export_edges >> generate_bte_operations >> compress_edge_file >> generate_metadata >> publish_files
export_nodes >> export_assertion_count >> read_assertion_count >> export_edges >> cat_edge_files >> generate_bte_operations >> compress_edge_file >> generate_metadata >> publish_files >> clean_up
8 changes: 5 additions & 3 deletions exporter.py
@@ -11,7 +11,7 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

GCP_BLOB_PREFIX = 'kgx/UniProt/'
GCP_BLOB_PREFIX = 'data/kgx-export/'

def export_metadata(bucket):
"""
@@ -95,9 +95,11 @@ def get_conn() -> pymysql.connections.Connection:
logging.info("Exporting UniProt")
if args.target == 'nodes':
targeted.export_nodes(session_maker(), bucket, GCP_BLOB_PREFIX)
else:
elif args.target == 'edges':
nodes = get_valid_nodes(bucket)
targeted.export_edges(session_maker(), nodes, bucket, GCP_BLOB_PREFIX,
targeted.export_edges(session_maker(), nodes, bucket, "data/kgx-build/",
assertion_start=args.assertion_offset, assertion_limit=args.assertion_limit,
chunk_size=args.chunk_size, edge_limit=args.limit)
elif args.target == 'count':
targeted.export_assertion_count(session_maker(), bucket, "data/kgx-build/")
logging.info("End Main")
4 changes: 2 additions & 2 deletions services.py
@@ -67,7 +67,7 @@ def upload_to_gcp(bucket_name: str, source_file_name: str, destination_blob_name
"""
client = storage.Client()
bucket = client.bucket(bucket_name)
logging.info(f'Uploading {source_file_name} to {destination_blob_name}')
logging.info(f'Uploading {source_file_name} to bucket: {bucket_name} path: {destination_blob_name}')
blob = bucket.blob(destination_blob_name)
blob.upload_from_filename(source_file_name, timeout=300, num_retries=2)
if blob.exists() and os.path.isfile(source_file_name) and delete_source_file:
@@ -269,7 +269,7 @@ def get_assertion_json(rows):
},
{
"attribute_type_id": "biolink:agent_type",
"value": "text-mining agent"
"value": "text_mining_agent"
},
{
"attribute_type_id": "biolink:primary_knowledge_source",
42 changes: 42 additions & 0 deletions targeted.py
@@ -133,6 +133,34 @@ def get_assertion_ids(session, limit=600000, offset=0):
})]



def get_assertion_count(session):
"""
Count the number of assertions that will be exported

:param session: the database session
:returns a single-element list containing the assertion count
"""
count_query = text('SELECT count(assertion_id) FROM targeted.assertion WHERE assertion_id NOT IN '
'(SELECT DISTINCT(assertion_id) '
'FROM assertion_evidence_feedback af '
'INNER JOIN evidence_feedback_answer ef '
'INNER JOIN evidence e ON e.evidence_id = af.evidence_id '
'INNER JOIN evidence_version ev ON ev.evidence_id = e.evidence_id '
'WHERE ef.prompt_text = \'Assertion Correct\' AND ef.response = 0 AND ev.version = 2) '
'AND subject_curie NOT IN :ex1 AND object_curie NOT IN :ex2 '
'AND subject_curie NOT IN :ex3 AND object_curie NOT IN :ex4 '
)
return [row[0] for row in session.execute(count_query, {
'ex1': EXCLUDED_FIG_CURIES,
'ex2': EXCLUDED_FIG_CURIES,
'ex3': EXCLUDE_LIST,
'ex4': EXCLUDE_LIST
})]


def get_edge_data(session: Session, id_list, chunk_size=1000, edge_limit=5) -> list[str]:
"""
Generate edge data for the given list of ids
@@ -286,3 +314,17 @@ def export_edges(session: Session, nodes: set, bucket: str, blob_prefix: str,
uniquify_edge_dict(edge_dict)
services.write_edges(edge_dict, nodes, output_filename)
services.upload_to_gcp(bucket, output_filename, f'{blob_prefix}{output_filename}')

def export_assertion_count(session: Session, bucket: str, blob_prefix: str) -> None:
"""
Count the number of assertions to be exported and save the number to a file

:param session: the database session
:param bucket: the output GCP bucket name
:param blob_prefix: the directory prefix for the file to be saved
"""
output_filename = f'assertion.count'
assertion_count = get_assertion_count(session)
with open(output_filename, 'a') as outfile:
outfile.write(str(assertion_count[0]))
services.upload_to_gcp(bucket, output_filename, f'{blob_prefix}{output_filename}')
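# Note: the DAG's read_assertion_count task reads this file back from
# /home/airflow/gcs/data/kgx-build/assertion.count, which assumes TMP_BUCKET is the Composer
# environment bucket (its data/ prefix is mounted at /home/airflow/gcs/data on the workers).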