diff --git a/api/app/services/file_service.py b/api/app/services/file_service.py
index 3558baae..e08756e4 100644
--- a/api/app/services/file_service.py
+++ b/api/app/services/file_service.py
@@ -126,6 +126,7 @@ def upload_reference_file(
             path.replace('s3://', 's3a://'),
             str(inserted_file.uuid),
             ReferenceDatasetMetrics.__tablename__,
+            ReferenceDataset.__tablename__,
         ],
     )
 
@@ -175,6 +176,7 @@ def bind_reference_file(
             file_ref.file_url.replace('s3://', 's3a://'),
             str(inserted_file.uuid),
             ReferenceDatasetMetrics.__tablename__,
+            ReferenceDataset.__tablename__,
         ],
     )
 
@@ -259,6 +261,7 @@ def upload_current_file(
             str(inserted_file.uuid),
             reference_dataset.path.replace('s3://', 's3a://'),
             CurrentDatasetMetrics.__tablename__,
+            CurrentDataset.__tablename__,
         ],
     )
 
@@ -312,6 +315,7 @@ def bind_current_file(
             str(inserted_file.uuid),
             reference_dataset.path.replace('s3://', 's3a://'),
             CurrentDatasetMetrics.__tablename__,
+            CurrentDataset.__tablename__,
         ],
     )
 
@@ -377,10 +381,10 @@ def upload_completion_file(
         app_name=str(model_out.uuid),
         app_path=spark_config.spark_completion_app_path,
         app_arguments=[
-            model_out.model_dump_json(),
             path.replace('s3://', 's3a://'),
             str(inserted_file.uuid),
             CompletionDatasetMetrics.__tablename__,
+            CompletionDataset.__tablename__,
         ],
     )
 
diff --git a/spark/jobs/completion_job.py b/spark/jobs/completion_job.py
new file mode 100644
index 00000000..d601c223
--- /dev/null
+++ b/spark/jobs/completion_job.py
@@ -0,0 +1,88 @@
+import sys
+import os
+import uuid
+
+from pyspark.sql.types import StructField, StructType, StringType
+from utils.models import JobStatus
+from utils.db import update_job_status, write_to_db
+
+from pyspark.sql import SparkSession, DataFrame
+
+import logging
+
+
+def compute_metrics(df: DataFrame) -> dict:
+    complete_record = {}
+    # TODO: compute model quality metrics
+    return complete_record
+
+
+def main(
+    spark_session: SparkSession,
+    completion_dataset_path: str,
+    completion_uuid: str,
+    metrics_table_name: str,
+    dataset_table_name: str,
+):
+    spark_context = spark_session.sparkContext
+
+    spark_context._jsc.hadoopConfiguration().set(
+        "fs.s3a.access.key", os.getenv("AWS_ACCESS_KEY_ID")
+    )
+    spark_context._jsc.hadoopConfiguration().set(
+        "fs.s3a.secret.key", os.getenv("AWS_SECRET_ACCESS_KEY")
+    )
+    spark_context._jsc.hadoopConfiguration().set(
+        "fs.s3a.endpoint.region", os.getenv("AWS_REGION")
+    )
+    if os.getenv("S3_ENDPOINT_URL"):
+        spark_context._jsc.hadoopConfiguration().set(
+            "fs.s3a.endpoint", os.getenv("S3_ENDPOINT_URL")
+        )
+        spark_context._jsc.hadoopConfiguration().set("fs.s3a.path.style.access", "true")
+        spark_context._jsc.hadoopConfiguration().set(
+            "fs.s3a.connection.ssl.enabled", "false"
+        )
+    df = spark_session.read.option("multiline", "true").json(completion_dataset_path)
+    complete_record = compute_metrics(df)
+
+    complete_record.update(
+        {"UUID": str(uuid.uuid4()), "COMPLETION_UUID": completion_uuid}
+    )
+
+    schema = StructType(
+        [
+            StructField("UUID", StringType(), True),
+            StructField("COMPLETION_UUID", StringType(), True),
+            StructField("MODEL_QUALITY", StringType(), True),
+        ]
+    )
+
+    write_to_db(spark_session, complete_record, schema, metrics_table_name)
+    update_job_status(completion_uuid, JobStatus.SUCCEEDED, dataset_table_name)
+
+
+if __name__ == "__main__":
+    spark_session = SparkSession.builder.appName(
+        "radicalbit_completion_metrics"
+    ).getOrCreate()
+
+    completion_dataset_path = sys.argv[1]
+    completion_uuid = sys.argv[2]
+    metrics_table_name = sys.argv[3]
+    dataset_table_name = sys.argv[4]
+
+    try:
+        main(
+            spark_session,
+            completion_dataset_path,
+            completion_uuid,
+            metrics_table_name,
+            dataset_table_name,
+        )
+
+    except Exception as e:
+        logging.exception(e)
+        update_job_status(completion_uuid, JobStatus.ERROR, dataset_table_name)
+    finally:
+        spark_session.stop()
diff --git a/spark/jobs/current_job.py b/spark/jobs/current_job.py
index 02b37d73..ff0140da 100644
--- a/spark/jobs/current_job.py
+++ b/spark/jobs/current_job.py
@@ -124,7 +124,8 @@ def main(
     current_dataset_path: str,
     current_uuid: str,
     reference_dataset_path: str,
-    table_name: str,
+    metrics_table_name: str,
+    dataset_table_name: str,
 ):
     spark_context = spark_session.sparkContext
 
@@ -171,9 +172,8 @@ def main(
         ]
     )
 
-    write_to_db(spark_session, complete_record, schema, table_name)
-    # FIXME table name should come from parameters
-    update_job_status(current_uuid, JobStatus.SUCCEEDED, "current_dataset")
+    write_to_db(spark_session, complete_record, schema, metrics_table_name)
+    update_job_status(current_uuid, JobStatus.SUCCEEDED, dataset_table_name)
 
 
 if __name__ == "__main__":
@@ -189,8 +189,10 @@ def main(
     current_uuid = sys.argv[3]
     # Reference dataset s3 path is fourth param
     reference_dataset_path = sys.argv[4]
-    # Table name fifth param
-    table_name = sys.argv[5]
+    # Metrics table name fifth param
+    metrics_table_name = sys.argv[5]
+    # Dataset table name sixth param
+    dataset_table_name = sys.argv[6]
 
     try:
         main(
@@ -199,11 +201,11 @@ def main(
             current_dataset_path,
             current_uuid,
             reference_dataset_path,
-            table_name,
+            metrics_table_name,
+            dataset_table_name,
         )
     except Exception as e:
         logging.exception(e)
-        # FIXME table name should come from parameters
-        update_job_status(current_uuid, JobStatus.ERROR, "current_dataset")
+        update_job_status(current_uuid, JobStatus.ERROR, dataset_table_name)
     finally:
         spark_session.stop()
diff --git a/spark/jobs/reference_job.py b/spark/jobs/reference_job.py
index 59a5993a..4750ab3b 100644
--- a/spark/jobs/reference_job.py
+++ b/spark/jobs/reference_job.py
@@ -77,7 +77,8 @@ def main(
     model: ModelOut,
     reference_dataset_path: str,
     reference_uuid: str,
-    table_name: str,
+    metrics_table_name: str,
+    dataset_table_name: str,
 ):
     spark_context = spark_session.sparkContext
 
@@ -118,9 +119,8 @@ def main(
         ]
     )
 
-    write_to_db(spark_session, complete_record, schema, table_name)
-    # FIXME table name should come from parameters
-    update_job_status(reference_uuid, JobStatus.SUCCEEDED, "reference_dataset")
+    write_to_db(spark_session, complete_record, schema, metrics_table_name)
+    update_job_status(reference_uuid, JobStatus.SUCCEEDED, dataset_table_name)
 
 
 if __name__ == "__main__":
@@ -134,14 +134,22 @@ def main(
     reference_dataset_path = sys.argv[2]
     # Reference file uuid third param
     reference_uuid = sys.argv[3]
-    # Table name fourth param
-    table_name = sys.argv[4]
+    # Metrics table name fourth param
+    metrics_table_name = sys.argv[4]
+    # Dataset table name fifth param
+    dataset_table_name = sys.argv[5]
 
     try:
-        main(spark_session, model, reference_dataset_path, reference_uuid, table_name)
+        main(
+            spark_session,
+            model,
+            reference_dataset_path,
+            reference_uuid,
+            metrics_table_name,
+            dataset_table_name,
+        )
     except Exception as e:
         logging.exception(e)
-        # FIXME table name should come from parameters
-        update_job_status(reference_uuid, JobStatus.ERROR, "reference_dataset")
+        update_job_status(reference_uuid, JobStatus.ERROR, dataset_table_name)
     finally:
         spark_session.stop()
diff --git a/spark/jobs/utils/db.py b/spark/jobs/utils/db.py
index 495cd509..ae146cf9 100644
--- a/spark/jobs/utils/db.py
+++ b/spark/jobs/utils/db.py
@@ -17,7 +17,7 @@
f"jdbc:postgresql://{db_host}:{db_port}/{db_name}" -def update_job_status(file_uuid: str, status: str, table_name: str): +def update_job_status(file_uuid: str, status: str, dataset_table_name: str): # Use psycopg2 to update the job status with psycopg2.connect( host=db_host, @@ -30,7 +30,7 @@ def update_job_status(file_uuid: str, status: str, table_name: str): with conn.cursor() as cur: cur.execute( f""" - UPDATE {table_name} + UPDATE {dataset_table_name} SET "STATUS" = %s WHERE "UUID" = %s """, @@ -40,7 +40,10 @@ def update_job_status(file_uuid: str, status: str, table_name: str): def write_to_db( - spark_session: SparkSession, record: Dict, schema: StructType, table_name: str + spark_session: SparkSession, + record: Dict, + schema: StructType, + metrics_table_name: str, ): out_df = spark_session.createDataFrame(data=[record], schema=schema) @@ -49,4 +52,6 @@ def write_to_db( "stringtype", "unspecified" ).option("driver", "org.postgresql.Driver").option("user", user).option( "password", password - ).option("dbtable", f'"{postgres_schema}"."{table_name}"').mode("append").save() + ).option("dbtable", f'"{postgres_schema}"."{metrics_table_name}"').mode( + "append" + ).save()