diff --git a/spark/jobs/llm_job.py b/spark/jobs/llm_job.py
new file mode 100644
index 00000000..e24a582c
--- /dev/null
+++ b/spark/jobs/llm_job.py
@@ -0,0 +1,83 @@
+import logging
+import os
+import sys
+import uuid
+
+import orjson
+from pyspark.sql import DataFrame, SparkSession
+from pyspark.sql.types import StringType, StructField, StructType
+
+from metrics.llm_metrics import LLMMetrics
+from utils.db import update_job_status, write_to_db
+from utils.models import JobStatus
+
+
+def compute_metrics(df: DataFrame) -> dict:
+    complete_record = {}
+    metrics_service = LLMMetrics()
+    model_quality = metrics_service.extract_metrics(df)
+    complete_record["MODEL_QUALITY"] = orjson.dumps(model_quality.model_dump()).decode(
+        "utf-8"
+    )
+    return complete_record
+
+
+def main(spark_session: SparkSession, input_path: str, llm_uuid: str, table_name: str):
+    spark_context = spark_session.sparkContext
+
+    spark_context._jsc.hadoopConfiguration().set(
+        "fs.s3a.access.key", os.getenv("AWS_ACCESS_KEY_ID")
+    )
+    spark_context._jsc.hadoopConfiguration().set(
+        "fs.s3a.secret.key", os.getenv("AWS_SECRET_ACCESS_KEY")
+    )
+    spark_context._jsc.hadoopConfiguration().set(
+        "fs.s3a.endpoint.region", os.getenv("AWS_REGION")
+    )
+    if os.getenv("S3_ENDPOINT_URL"):
+        spark_context._jsc.hadoopConfiguration().set(
+            "fs.s3a.endpoint", os.getenv("S3_ENDPOINT_URL")
+        )
+        spark_context._jsc.hadoopConfiguration().set("fs.s3a.path.style.access", "true")
+        spark_context._jsc.hadoopConfiguration().set(
+            "fs.s3a.connection.ssl.enabled", "false"
+        )
+    df = spark_session.read.option("multiline", "true").json(input_path)
+    complete_record = compute_metrics(df)
+
+    complete_record.update({"UUID": str(uuid.uuid4()), "LLM_UUID": llm_uuid})
+
+    schema = StructType(
+        [
+            StructField("UUID", StringType(), True),
+            StructField("LLM_UUID", StringType(), True),
+            StructField("MODEL_QUALITY", StringType(), True),
+        ]
+    )
+
+    write_to_db(spark_session, complete_record, schema, table_name)
+    # FIXME table name should come from parameters
+    update_job_status(llm_uuid, JobStatus.SUCCEEDED, "llm_dataset")
+
+
+if __name__ == "__main__":
+    spark_session = SparkSession.builder.appName(
+        "radicalbit_reference_metrics"
+    ).getOrCreate()
+
+    # LLM dataset s3 path is the first CLI argument
+    input_path = sys.argv[1]
+    # LLM file uuid is the second CLI argument
+    llm_uuid = sys.argv[2]
+    # Table name is the third CLI argument
+    table_name = sys.argv[3]
+
+    try:
+        main(spark_session, input_path, llm_uuid, table_name)
+
+    except Exception as e:
+        logging.exception(e)
+        # FIXME table name should come from parameters
+        update_job_status(llm_uuid, JobStatus.ERROR, "llm_dataset")
+    finally:
+        spark_session.stop()
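
For reviewers who want to poke at this entry point without a cluster, here is a minimal local-run sketch; the bucket path, uuid, and app name are made-up placeholders, and the `fs.s3a` credentials in `main` are still read from the environment, so those variables would need to be set first:

```python
# Hypothetical local invocation of the job above -- all argument values
# are placeholders, not part of this PR. In production the script is
# launched via spark-submit with three positional arguments:
#   <input_path> <llm_uuid> <table_name>
from pyspark.sql import SparkSession

from jobs.llm_job import main  # module path as used by the tests in this PR

spark = SparkSession.builder.appName("llm_job_local").getOrCreate()
try:
    main(
        spark,
        "s3a://example-bucket/llm/metrics.json",  # placeholder dataset path
        "00000000-0000-0000-0000-000000000000",   # placeholder llm uuid
        "llm_dataset",                            # table name hardcoded in the job's FIXMEs
    )
finally:
    spark.stop()
```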
diff --git a/spark/jobs/metrics/llm_metrics.py b/spark/jobs/metrics/llm_metrics.py
new file mode 100644
index 00000000..7950a129
--- /dev/null
+++ b/spark/jobs/metrics/llm_metrics.py
@@ -0,0 +1,70 @@
+import numpy as np
+import pyspark.sql.functions as F
+from pyspark.sql import DataFrame
+from pyspark.sql.types import FloatType
+from models.llm_dataset import LLMMetricsModel
+
+
+class LLMMetrics:
+    def __init__(self):
+        pass
+
+    @staticmethod
+    @F.udf(FloatType())
+    def compute_probability_udf(log_prob: float) -> float:
+        return float(np.exp(log_prob))
+
+    @staticmethod
+    @F.udf(FloatType())
+    def compute_perplexity(log_probs: list[float]) -> float:
+        return float(np.exp(-np.mean(log_probs)))
+
+    @staticmethod
+    @F.udf(FloatType())
+    def compute_prob_mean_per_phrase(probs: list[float]) -> float:
+        return float(np.mean(probs))
+
+    @staticmethod
+    def remove_columns(df: DataFrame) -> DataFrame:
+        df = df.drop(
+            "model",
+            "object",
+            "created",
+            "system_fingerprint",
+            "usage",
+            "service_tier",
+        )
+        return df
+
+    def compute_prob(self, df: DataFrame):
+        df = df.select(F.explode("choices").alias("element"), F.col("id"))
+        df = df.select(
+            F.col("id"), F.explode("element.logprobs.content").alias("content")
+        )
+        df = df.select("id", "content.logprob", "content.token").withColumn(
+            "prob", self.compute_probability_udf("logprob")
+        )
+        return df
+
+    def extract_metrics(self, df: DataFrame) -> LLMMetricsModel:
+        df = self.remove_columns(df)
+        df = self.compute_prob(df)
+        df_prob = df.drop("logprob")
+        df_mean_values = df.groupBy("id").agg(
+            self.compute_prob_mean_per_phrase(F.collect_list("prob")).alias(
+                "prob_per_phrase"
+            ),
+            self.compute_perplexity(F.collect_list("logprob")).alias(
+                "perplex_per_phrase"
+            ),
+        )
+        df = df_mean_values.agg(
+            F.mean("prob_per_phrase").alias("prob_tot_mean"),
+            F.mean("perplex_per_phrase").alias("perplex_tot_mean"),
+        )
+        res = {
+            "prob": df_prob.toPandas().to_dict(orient="records"),
+            "mean_per_phrase": df_mean_values.toPandas().to_dict(orient="records"),
+            "mean_per_file": df.toPandas().to_dict(orient="records"),
+        }
+        return LLMMetricsModel(**res)
diff --git a/spark/jobs/models/llm_dataset.py b/spark/jobs/models/llm_dataset.py
new file mode 100644
index 00000000..2e2618a0
--- /dev/null
+++ b/spark/jobs/models/llm_dataset.py
@@ -0,0 +1,25 @@
+from typing import List
+from pydantic import BaseModel, confloat
+
+
+class Prob(BaseModel):
+    id: str
+    token: str
+    prob: confloat(ge=0, le=1)
+
+
+class MeanPerPhrase(BaseModel):
+    id: str
+    prob_per_phrase: confloat(ge=0, le=1)
+    perplex_per_phrase: confloat(ge=1)
+
+
+class MeanPerFile(BaseModel):
+    prob_tot_mean: confloat(ge=0, le=1)
+    perplex_tot_mean: confloat(ge=1)
+
+
+class LLMMetricsModel(BaseModel):
+    prob: List[Prob]
+    mean_per_phrase: List[MeanPerPhrase]
+    mean_per_file: List[MeanPerFile]
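
The pydantic bounds above are not arbitrary: token log-probabilities are always <= 0, so `exp(logprob)` lands in [0, 1], and `exp(-mean(logprobs))` (the per-phrase perplexity computed by `LLMMetrics.compute_perplexity`) is always >= 1. A small numpy sketch of the same arithmetic, on made-up log-probabilities rather than fixture data:

```python
# Illustrative only: the logprob values below are invented, not taken
# from the test fixture.
import numpy as np

logprobs = np.array([-0.105, -0.223, -1.897])  # per-token log-probabilities (<= 0)

probs = np.exp(logprobs)                                # per token, as compute_probability_udf
prob_per_phrase = float(np.mean(probs))                 # as compute_prob_mean_per_phrase
perplex_per_phrase = float(np.exp(-np.mean(logprobs)))  # as compute_perplexity

assert 0.0 <= prob_per_phrase <= 1.0  # matches confloat(ge=0, le=1)
assert perplex_per_phrase >= 1.0      # matches confloat(ge=1)
```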
diff --git a/spark/tests/llm_metrics_test.py b/spark/tests/llm_metrics_test.py
new file mode 100644
index 00000000..a6ce19a0
--- /dev/null
+++ b/spark/tests/llm_metrics_test.py
@@ -0,0 +1,44 @@
+import logging
+import orjson
+import pytest
+from jobs.llm_job import compute_metrics
+from jobs.metrics.llm_metrics import LLMMetrics
+from jobs.models.llm_dataset import LLMMetricsModel
+from tests.results.llm_metrics_results import llm_metric_results
+
+logger = logging.getLogger(__name__)
+
+
+@pytest.fixture
+def mock_response(spark_fixture, test_data_dir):
+    yield spark_fixture.read.option("multiline", "true").json(
+        f"{test_data_dir}/llm/metrics.json"
+    )
+
+
+def test_remove_columns(spark_fixture, mock_response):
+    llm_metrics_service = LLMMetrics()
+    df = llm_metrics_service.remove_columns(mock_response)
+    assert "id" in df.columns
+    assert "choices" in df.columns
+    assert len(df.columns) == 2
+
+
+def test_compute_prob(spark_fixture, mock_response):
+    llm_metrics_service = LLMMetrics()
+    df = llm_metrics_service.remove_columns(mock_response)
+    df = llm_metrics_service.compute_prob(df)
+    assert {"id", "logprob", "token", "prob"} == set(df.columns)
+    assert not df.rdd.isEmpty()
+
+
+def test_extract_metrics(spark_fixture, mock_response):
+    llm_metrics_service = LLMMetrics()
+    llm_metrics_model = llm_metrics_service.extract_metrics(mock_response)
+    assert isinstance(llm_metrics_model, LLMMetricsModel)
+
+
+def test_compute_metrics(spark_fixture, mock_response):
+    complete_record = compute_metrics(mock_response)
+    model_quality = complete_record.get("MODEL_QUALITY")
+    assert model_quality == orjson.dumps(llm_metric_results).decode("utf-8")
diff --git a/spark/tests/results/llm_metrics_results.py b/spark/tests/results/llm_metrics_results.py
new file mode 100644
index 00000000..1f236f66
--- /dev/null
+++ b/spark/tests/results/llm_metrics_results.py
@@ -0,0 +1,99 @@
+llm_metric_results = {
+    "prob": [
+        {
+            "id": "chatcmpl-AcWID2SsE5iuK6z5AhNCKv3WUcCxN",
+            "token": "Sure",
+            "prob": 0.541987419128418,
+        },
+        {
+            "id": "chatcmpl-AcWID2SsE5iuK6z5AhNCKv3WUcCxN",
+            "token": ",",
+            "prob": 0.9025230407714844,
+        },
+        {
+            "id": "chatcmpl-AcWID2SsE5iuK6z5AhNCKv3WUcCxN",
+            "token": " go",
+            "prob": 0.07877199351787567,
+        },
+        {
+            "id": "chatcmpl-AcWID2SsE5iuK6z5AhNCKv3WUcCxN",
+            "token": " ahead",
+            "prob": 0.9985936284065247,
+        },
+        {
+            "id": "chatcmpl-AcWID2SsE5iuK6z5AhNCKv3WUcCxN",
+            "token": ".",
+            "prob": 0.13000132143497467,
+        },
+        {
+            "id": "chatcmpl-AcWID2SsE5iuK6z5AhNCKv3WUcCxN",
+            "token": " What's",
+            "prob": 0.40887829661369324,
+        },
+        {
+            "id": "chatcmpl-AcWID2SsE5iuK6z5AhNCKv3WUcCxN",
+            "token": " up",
+            "prob": 0.9211180806159973,
+        },
+        {
+            "id": "chatcmpl-AcWID2SsE5iuK6z5AhNCKv3WUcCxN",
+            "token": "?",
+            "prob": 0.9999502301216125,
+        },
+        {
+            "id": "chatcmpl-AcYMMPLnpkksCdLze3M8nnqQbfqVG",
+            "token": "Certainly",
+            "prob": 0.022015240043401718,
+        },
+        {
+            "id": "chatcmpl-AcYMMPLnpkksCdLze3M8nnqQbfqVG",
+            "token": "!",
+            "prob": 0.8896080851554871,
+        },
+        {
+            "id": "chatcmpl-AcYMMPLnpkksCdLze3M8nnqQbfqVG",
+            "token": " Just",
+            "prob": 0.0027362185064703226,
+        },
+        {
+            "id": "chatcmpl-AcYMMPLnpkksCdLze3M8nnqQbfqVG",
+            "token": " let",
+            "prob": 0.5134729146957397,
+        },
+        {
+            "id": "chatcmpl-AcYMMPLnpkksCdLze3M8nnqQbfqVG",
+            "token": " me",
+            "prob": 0.999944269657135,
+        },
+        {
+            "id": "chatcmpl-AcYMMPLnpkksCdLze3M8nnqQbfqVG",
+            "token": " know",
+            "prob": 0.9991950988769531,
+        },
+        {
+            "id": "chatcmpl-AcYMMPLnpkksCdLze3M8nnqQbfqVG",
+            "token": " how",
+            "prob": 0.49005329608917236,
+        },
+        {
+            "id": "chatcmpl-AcYMMPLnpkksCdLze3M8nnqQbfqVG",
+            "token": ".",
+            "prob": 0.9663926959037781,
+        },
+    ],
+    "mean_per_phrase": [
+        {
+            "id": "chatcmpl-AcWID2SsE5iuK6z5AhNCKv3WUcCxN",
+            "prob_per_phrase": 0.6227279901504517,
+            "perplex_per_phrase": 2.190884828567505,
+        },
+        {
+            "id": "chatcmpl-AcYMMPLnpkksCdLze3M8nnqQbfqVG",
+            "prob_per_phrase": 0.61042720079422,
+            "perplex_per_phrase": 4.080123424530029,
+        },
+    ],
+    "mean_per_file": [
+        {"prob_tot_mean": 0.6165775954723358, "perplex_tot_mean": 3.135504126548767}
+    ],
+}
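
As a quick sanity check on the fixture, the single `mean_per_file` row is exactly the arithmetic mean of the two `mean_per_phrase` rows, mirroring the final `agg` step in `extract_metrics`. A sketch one could run against the expected-results module:

```python
# Consistency check over the expected results above; uses only values
# already present in the fixture.
from statistics import mean

from tests.results.llm_metrics_results import llm_metric_results

phrases = llm_metric_results["mean_per_phrase"]
per_file = llm_metric_results["mean_per_file"][0]

# (0.6227279901504517 + 0.61042720079422) / 2 == 0.6165775954723358
assert abs(per_file["prob_tot_mean"] - mean(p["prob_per_phrase"] for p in phrases)) < 1e-9
# (2.190884828567505 + 4.080123424530029) / 2 == 3.135504126548767
assert abs(per_file["perplex_tot_mean"] - mean(p["perplex_per_phrase"] for p in phrases)) < 1e-9
```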