Skip to content

Commit

Permalink
llm_metrics job with tests
Browse files: browse the repository at this point in the history
  • Loading branch information
carlopignatiello committed Dec 12, 2024
1 parent 7821c9a commit 5e67bd9
Show file tree
Hide file tree
Showing 5 changed files with 323 additions and 0 deletions.
85 changes: 85 additions & 0 deletions spark/jobs/llm_job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import sys
import os
import uuid

import orjson
from pyspark.sql.types import StructField, StructType, StringType

from metrics.llm_metrics import LLMMetrics
from utils.models import JobStatus
from utils.db import update_job_status, write_to_db

from pyspark.sql import SparkSession, DataFrame

import logging


def compute_metrics(df: DataFrame) -> dict:
    """Compute the LLM quality metrics for ``df`` and wrap them in a record.

    Args:
        df: DataFrame handed unchanged to ``LLMMetrics.extract_metrics``.

    Returns:
        A dict with the single key ``"MODEL_QUALITY"`` whose value is the
        extracted metrics model serialized as a JSON string.
    """
    model_quality = LLMMetrics().extract_metrics(df)
    # orjson.dumps yields bytes, so decode to get a plain str for storage.
    serialized = orjson.dumps(model_quality.model_dump()).decode("utf-8")
    return {"MODEL_QUALITY": serialized}


def main(
    spark_session: SparkSession,
    input_path: str,
    llm_uuid: str,
    table_name: str,
    dataset_table_name: str = "llm_dataset",
):
    """Run the LLM metrics job: read the input, compute metrics, persist them.

    The S3A connector is configured from the environment variables
    ``AWS_ACCESS_KEY_ID``, ``AWS_SECRET_ACCESS_KEY`` and ``AWS_REGION``, plus
    the optional ``S3_ENDPOINT_URL``.

    Args:
        spark_session: Session used to read the input and write the result.
        input_path: Location of the (multiline) JSON input to read.
        llm_uuid: Identifier stored in the ``LLM_UUID`` column and used to
            update the job status.
        table_name: Table the metrics record is written to.
        dataset_table_name: Table whose job status is updated on success.
            Defaults to ``"llm_dataset"``, the value that used to be
            hard-coded, so existing callers are unaffected.
    """
    # Resolve the Hadoop configuration once instead of once per property.
    hadoop_conf = spark_session.sparkContext._jsc.hadoopConfiguration()
    hadoop_conf.set("fs.s3a.access.key", os.getenv("AWS_ACCESS_KEY_ID"))
    hadoop_conf.set("fs.s3a.secret.key", os.getenv("AWS_SECRET_ACCESS_KEY"))
    hadoop_conf.set("fs.s3a.endpoint.region", os.getenv("AWS_REGION"))

    s3_endpoint_url = os.getenv("S3_ENDPOINT_URL")
    if s3_endpoint_url:
        # NOTE(review): path-style access and disabled SSL are assumed to
        # apply only when a custom endpoint is configured -- confirm.
        hadoop_conf.set("fs.s3a.endpoint", s3_endpoint_url)
        hadoop_conf.set("fs.s3a.path.style.access", "true")
        hadoop_conf.set("fs.s3a.connection.ssl.enabled", "false")

    df = spark_session.read.option("multiline", "true").json(input_path)
    complete_record = compute_metrics(df)
    # Each run gets a fresh record id; LLM_UUID links it back to the dataset.
    complete_record.update({"UUID": str(uuid.uuid4()), "LLM_UUID": llm_uuid})

    schema = StructType(
        [
            StructField("UUID", StringType(), True),
            StructField("LLM_UUID", StringType(), True),
            StructField("MODEL_QUALITY", StringType(), True),
        ]
    )

    write_to_db(spark_session, complete_record, schema, table_name)
    update_job_status(llm_uuid, JobStatus.SUCCEEDED, dataset_table_name)


if __name__ == "__main__":
    session_builder = SparkSession.builder.appName("radicalbit_reference_metrics")
    spark_session = session_builder.getOrCreate()

    # Positional command-line arguments:
    #   sys.argv[1] -> path of the JSON input handed to main()
    #   sys.argv[2] -> uuid used for the record and the job-status update
    #   sys.argv[3] -> name of the table the metrics record is written to
    input_path, llm_uuid, table_name = sys.argv[1], sys.argv[2], sys.argv[3]

    try:
        main(spark_session, input_path, llm_uuid, table_name)
    except Exception as e:
        # Log the failure with its traceback, then flag the job as failed.
        logging.exception(e)
        # FIXME table name should come from parameters
        update_job_status(llm_uuid, JobStatus.ERROR, "llm_dataset")
    finally:
        # Stop the session whether the job succeeded or failed.
        spark_session.stop()
70 changes: 70 additions & 0 deletions spark/jobs/metrics/llm_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import pyspark.sql.functions as F
import numpy as np
from pyspark.sql import DataFrame
from pyspark.sql.types import FloatType
from models.llm_dataset import LLMMetricsModel


class LLMMetrics:
    """Token-probability and perplexity metrics computed from log-probabilities."""

    def __init__(self):
        """Nothing to initialise; the instance only groups the helpers below."""

    @staticmethod
    @F.udf(FloatType())
    def compute_probability_udf(log_prob: float) -> float:
        """Spark UDF: convert a log-probability to a probability, exp(log_prob)."""
        return float(np.exp(log_prob))

    @staticmethod
    @F.udf(FloatType())
    def compute_perplexity(log_probs: list[float]) -> float:
        """Spark UDF: perplexity of a list of log-probabilities, exp(-mean)."""
        mean_log_prob = np.mean(log_probs)
        return float(np.exp(-mean_log_prob))

    @staticmethod
    @F.udf(FloatType())
    def compute_prob_mean_per_phrase(probs: list[float]) -> float:
        """Spark UDF: arithmetic mean of a list of probabilities."""
        return float(np.mean(probs))

    @staticmethod
    def remove_columns(df: DataFrame) -> DataFrame:
        """Return ``df`` without the response columns the metrics do not use."""
        unused_columns = (
            "model",
            "object",
            "created",
            "system_fingerprint",
            "usage",
            "service_tier",
        )
        return df.drop(*unused_columns)

    def compute_prob(self, df: DataFrame):
        """Flatten ``df`` to one row per token and add its probability.

        Returns a DataFrame with the columns ``id``, ``logprob``, ``token``
        and ``prob``.
        """
        # One row per entry of the ``choices`` array.
        per_choice = df.select(F.explode("choices").alias("element"), F.col("id"))
        # One row per entry of ``logprobs.content`` inside each choice.
        per_token = per_choice.select(
            F.col("id"), F.explode("element.logprobs.content").alias("content")
        )
        flattened = per_token.select("id", "content.logprob", "content.token")
        return flattened.withColumn("prob", self.compute_probability_udf("logprob"))

    def extract_metrics(self, df: DataFrame) -> LLMMetricsModel:
        """Build the per-token, per-id and whole-input metrics for ``df``.

        Returns:
            An ``LLMMetricsModel`` holding the per-token probabilities
            (``prob``), the mean probability and perplexity per ``id``
            (``mean_per_phrase``) and the mean of those per-id values
            (``mean_per_file``).
        """
        tokens = self.compute_prob(self.remove_columns(df))
        per_token = tokens.drop("logprob")
        per_phrase = tokens.groupBy("id").agg(
            self.compute_prob_mean_per_phrase(F.collect_list("prob")).alias(
                "prob_per_phrase"
            ),
            self.compute_perplexity(F.collect_list("logprob")).alias(
                "perplex_per_phrase"
            ),
        )
        # Ungrouped aggregation: a single row averaging the per-id values.
        per_file = per_phrase.agg(
            F.mean("prob_per_phrase").alias("prob_tot_mean"),
            F.mean("perplex_per_phrase").alias("perplex_tot_mean"),
        )
        return LLMMetricsModel(
            prob=per_token.toPandas().to_dict(orient="records"),
            mean_per_phrase=per_phrase.toPandas().to_dict(orient="records"),
            mean_per_file=per_file.toPandas().to_dict(orient="records"),
        )
25 changes: 25 additions & 0 deletions spark/jobs/models/llm_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from pydantic import BaseModel, confloat
from typing import List


class Prob(BaseModel):
    """Probability of a single token."""

    # Identifier shared by all tokens that come from the same input record.
    id: str
    # Text of the token.
    token: str
    # Token probability, validated to lie in [0, 1].
    prob: confloat(ge=0, le=1)


class MeanPerPhrase(BaseModel):
    """Aggregated metrics for one ``id``."""

    # Identifier the aggregated values refer to.
    id: str
    # Mean probability, validated to lie in [0, 1].
    prob_per_phrase: confloat(ge=0, le=1)
    # Perplexity, validated to be at least 1.
    perplex_per_phrase: confloat(ge=1)


class MeanPerFile(BaseModel):
    """Metrics averaged over the whole input."""

    # Overall mean probability, validated to lie in [0, 1].
    prob_tot_mean: confloat(ge=0, le=1)
    # Overall mean perplexity, validated to be at least 1.
    perplex_tot_mean: confloat(ge=1)


class LLMMetricsModel(BaseModel):
    """Container for the three metric collections: per token, per id, overall."""

    # One entry per token.
    prob: List[Prob]
    # One entry per id.
    mean_per_phrase: List[MeanPerPhrase]
    # Overall means, kept as a list of records.
    mean_per_file: List[MeanPerFile]
44 changes: 44 additions & 0 deletions spark/tests/llm_metrics_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import pytest
import orjson
from jobs.llm_job import compute_metrics
from jobs.metrics.llm_metrics import LLMMetrics
from jobs.models.llm_dataset import LLMMetricsModel
from tests.results.llm_metrics_results import llm_metric_results
import logging

# Module-level logger named after this module; none of the tests shown here use it.
logger = logging.getLogger(__name__)


@pytest.fixture
def mock_response(spark_fixture, test_data_dir):
    """Yield ``llm/metrics.json`` from the test data dir as a multiline-JSON DataFrame."""
    metrics_json_path = f"{test_data_dir}/llm/metrics.json"
    json_reader = spark_fixture.read.option("multiline", "true")
    yield json_reader.json(metrics_json_path)


def test_remove_columns(spark_fixture, mock_response):
    """remove_columns must leave exactly the ``id`` and ``choices`` columns."""
    remaining = LLMMetrics().remove_columns(mock_response)
    assert "id" in remaining.columns
    assert "choices" in remaining.columns
    assert len(remaining.columns) == 2


def test_compute_prob(spark_fixture, mock_response):
    """compute_prob must return non-empty rows with id, logprob, token and prob."""
    service = LLMMetrics()
    token_df = service.compute_prob(service.remove_columns(mock_response))
    assert {"id", "logprob", "token", "prob"} == set(token_df.columns)
    assert not token_df.rdd.isEmpty()


def test_extract_metrics(spark_fixture, mock_response):
    """extract_metrics must return an LLMMetricsModel instance."""
    result = LLMMetrics().extract_metrics(mock_response)
    assert isinstance(result, LLMMetricsModel)


def test_compute_metrics(spark_fixture, mock_response):
    """The MODEL_QUALITY JSON must match the stored expected results exactly."""
    expected_json = orjson.dumps(llm_metric_results).decode("utf-8")
    actual_json = compute_metrics(mock_response).get("MODEL_QUALITY")
    assert actual_json == expected_json
99 changes: 99 additions & 0 deletions spark/tests/results/llm_metrics_results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Expected metrics for the ``llm/metrics.json`` test input.
# ``test_compute_metrics`` serializes this dict with orjson and compares it
# for exact string equality with the MODEL_QUALITY value from
# ``compute_metrics``, so key order, entry order and float digits all matter.
llm_metric_results = {
"prob": [
{
"id": "chatcmpl-AcWID2SsE5iuK6z5AhNCKv3WUcCxN",
"token": "Sure",
"prob": 0.541987419128418,
},
{
"id": "chatcmpl-AcWID2SsE5iuK6z5AhNCKv3WUcCxN",
"token": ",",
"prob": 0.9025230407714844,
},
{
"id": "chatcmpl-AcWID2SsE5iuK6z5AhNCKv3WUcCxN",
"token": " go",
"prob": 0.07877199351787567,
},
{
"id": "chatcmpl-AcWID2SsE5iuK6z5AhNCKv3WUcCxN",
"token": " ahead",
"prob": 0.9985936284065247,
},
{
"id": "chatcmpl-AcWID2SsE5iuK6z5AhNCKv3WUcCxN",
"token": ".",
"prob": 0.13000132143497467,
},
{
"id": "chatcmpl-AcWID2SsE5iuK6z5AhNCKv3WUcCxN",
"token": " What's",
"prob": 0.40887829661369324,
},
{
"id": "chatcmpl-AcWID2SsE5iuK6z5AhNCKv3WUcCxN",
"token": " up",
"prob": 0.9211180806159973,
},
{
"id": "chatcmpl-AcWID2SsE5iuK6z5AhNCKv3WUcCxN",
"token": "?",
"prob": 0.9999502301216125,
},
{
"id": "chatcmpl-AcYMMPLnpkksCdLze3M8nnqQbfqVG",
"token": "Certainly",
"prob": 0.022015240043401718,
},
{
"id": "chatcmpl-AcYMMPLnpkksCdLze3M8nnqQbfqVG",
"token": "!",
"prob": 0.8896080851554871,
},
{
"id": "chatcmpl-AcYMMPLnpkksCdLze3M8nnqQbfqVG",
"token": " Just",
"prob": 0.0027362185064703226,
},
{
"id": "chatcmpl-AcYMMPLnpkksCdLze3M8nnqQbfqVG",
"token": " let",
"prob": 0.5134729146957397,
},
{
"id": "chatcmpl-AcYMMPLnpkksCdLze3M8nnqQbfqVG",
"token": " me",
"prob": 0.999944269657135,
},
{
"id": "chatcmpl-AcYMMPLnpkksCdLze3M8nnqQbfqVG",
"token": " know",
"prob": 0.9991950988769531,
},
{
"id": "chatcmpl-AcYMMPLnpkksCdLze3M8nnqQbfqVG",
"token": " how",
"prob": 0.49005329608917236,
},
{
"id": "chatcmpl-AcYMMPLnpkksCdLze3M8nnqQbfqVG",
"token": ".",
"prob": 0.9663926959037781,
},
],
"mean_per_phrase": [
{
"id": "chatcmpl-AcWID2SsE5iuK6z5AhNCKv3WUcCxN",
"prob_per_phrase": 0.6227279901504517,
"perplex_per_phrase": 2.190884828567505,
},
{
"id": "chatcmpl-AcYMMPLnpkksCdLze3M8nnqQbfqVG",
"prob_per_phrase": 0.61042720079422,
"perplex_per_phrase": 4.080123424530029,
},
],
"mean_per_file": [
{"prob_tot_mean": 0.6165775954723358, "perplex_tot_mean": 3.135504126548767}
],
}

0 comments on commit 5e67bd9

Please sign in to comment.