-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
carlopignatiello
committed
Dec 12, 2024
1 parent
7821c9a
commit 5e67bd9
Showing
5 changed files
with
323 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import sys | ||
import os | ||
import uuid | ||
|
||
import orjson | ||
from pyspark.sql.types import StructField, StructType, StringType | ||
|
||
from metrics.llm_metrics import LLMMetrics | ||
from utils.models import JobStatus | ||
from utils.db import update_job_status, write_to_db | ||
|
||
from pyspark.sql import SparkSession, DataFrame | ||
|
||
import logging | ||
|
||
|
||
def compute_metrics(df: DataFrame) -> dict: | ||
complete_record = {} | ||
metrics_service = LLMMetrics() | ||
model_quality = metrics_service.extract_metrics(df) | ||
complete_record["MODEL_QUALITY"] = orjson.dumps(model_quality.model_dump()).decode( | ||
"utf-8" | ||
) | ||
return complete_record | ||
|
||
|
||
def main(spark_session: SparkSession, input_path: str, llm_uuid: str, table_name: str): | ||
spark_context = spark_session.sparkContext | ||
|
||
spark_context._jsc.hadoopConfiguration().set( | ||
"fs.s3a.access.key", os.getenv("AWS_ACCESS_KEY_ID") | ||
) | ||
spark_context._jsc.hadoopConfiguration().set( | ||
"fs.s3a.secret.key", os.getenv("AWS_SECRET_ACCESS_KEY") | ||
) | ||
spark_context._jsc.hadoopConfiguration().set( | ||
"fs.s3a.endpoint.region", os.getenv("AWS_REGION") | ||
) | ||
if os.getenv("S3_ENDPOINT_URL"): | ||
spark_context._jsc.hadoopConfiguration().set( | ||
"fs.s3a.endpoint", os.getenv("S3_ENDPOINT_URL") | ||
) | ||
spark_context._jsc.hadoopConfiguration().set("fs.s3a.path.style.access", "true") | ||
spark_context._jsc.hadoopConfiguration().set( | ||
"fs.s3a.connection.ssl.enabled", "false" | ||
) | ||
df = spark_session.read.option("multiline", "true").json(input_path) | ||
complete_record = compute_metrics(df) | ||
|
||
complete_record.update({"UUID": str(uuid.uuid4()), "LLM_UUID": llm_uuid}) | ||
|
||
schema = StructType( | ||
[ | ||
StructField("UUID", StringType(), True), | ||
StructField("LLM_UUID", StringType(), True), | ||
StructField("MODEL_QUALITY", StringType(), True), | ||
] | ||
) | ||
|
||
write_to_db(spark_session, complete_record, schema, table_name) | ||
# # FIXME table name should come from parameters | ||
update_job_status(llm_uuid, JobStatus.SUCCEEDED, "llm_dataset") | ||
|
||
|
||
if __name__ == "__main__": | ||
spark_session = SparkSession.builder.appName( | ||
"radicalbit_reference_metrics" | ||
).getOrCreate() | ||
|
||
# Reference dataset s3 path is second param | ||
input_path = sys.argv[1] | ||
# Reference file uuid third param | ||
llm_uuid = sys.argv[2] | ||
# Table name fourth param | ||
table_name = sys.argv[3] | ||
|
||
try: | ||
main(spark_session, input_path, llm_uuid, table_name) | ||
|
||
except Exception as e: | ||
logging.exception(e) | ||
# FIXME table name should come from parameters | ||
update_job_status(llm_uuid, JobStatus.ERROR, "llm_dataset") | ||
finally: | ||
spark_session.stop() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import pyspark.sql.functions as F | ||
import numpy as np | ||
from pyspark.sql import DataFrame | ||
from pyspark.sql.types import FloatType | ||
from models.llm_dataset import LLMMetricsModel | ||
|
||
|
||
class LLMMetrics: | ||
def __init__(self): | ||
pass | ||
|
||
@staticmethod | ||
@F.udf(FloatType()) | ||
def compute_probability_udf(log_prob: float) -> float: | ||
return float(np.exp(log_prob)) | ||
|
||
@staticmethod | ||
@F.udf(FloatType()) | ||
def compute_perplexity(log_probs: list[float]) -> float: | ||
return float(np.exp(-np.mean(log_probs))) | ||
|
||
@staticmethod | ||
@F.udf(FloatType()) | ||
def compute_prob_mean_per_phrase(probs: list[float]) -> float: | ||
return float(np.mean(probs)) | ||
|
||
@staticmethod | ||
def remove_columns(df: DataFrame) -> DataFrame: | ||
df = df.drop( | ||
"model", | ||
"object", | ||
"created", | ||
"system_fingerprint", | ||
"usage", | ||
"service_tier", | ||
) | ||
return df | ||
|
||
def compute_prob(self, df: DataFrame): | ||
df = df.select(F.explode("choices").alias("element"), F.col("id")) | ||
df = df.select( | ||
F.col("id"), F.explode("element.logprobs.content").alias("content") | ||
) | ||
df = df.select("id", "content.logprob", "content.token").withColumn( | ||
"prob", self.compute_probability_udf("logprob") | ||
) | ||
return df | ||
|
||
def extract_metrics(self, df: DataFrame) -> LLMMetricsModel: | ||
df = self.remove_columns(df) | ||
df = self.compute_prob(df) | ||
df_prob = df.drop("logprob") | ||
df_mean_values = df.groupBy("id").agg( | ||
self.compute_prob_mean_per_phrase(F.collect_list("prob")).alias( | ||
"prob_per_phrase" | ||
), | ||
self.compute_perplexity(F.collect_list("logprob")).alias( | ||
"perplex_per_phrase" | ||
), | ||
) | ||
df = df_mean_values.agg( | ||
F.mean("prob_per_phrase").alias("prob_tot_mean"), | ||
F.mean("perplex_per_phrase").alias("perplex_tot_mean"), | ||
) | ||
res = { | ||
"prob": df_prob.toPandas().to_dict(orient="records"), | ||
"mean_per_phrase": df_mean_values.toPandas().to_dict(orient="records"), | ||
"mean_per_file": df.toPandas().to_dict(orient="records"), | ||
} | ||
return LLMMetricsModel(**res) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from pydantic import BaseModel, confloat | ||
from typing import List | ||
|
||
|
||
class Prob(BaseModel): | ||
id: str | ||
token: str | ||
prob: confloat(ge=0, le=1) | ||
|
||
|
||
class MeanPerPhrase(BaseModel): | ||
id: str | ||
prob_per_phrase: confloat(ge=0, le=1) | ||
perplex_per_phrase: confloat(ge=1) | ||
|
||
|
||
class MeanPerFile(BaseModel): | ||
prob_tot_mean: confloat(ge=0, le=1) | ||
perplex_tot_mean: confloat(ge=1) | ||
|
||
|
||
class LLMMetricsModel(BaseModel): | ||
prob: List[Prob] | ||
mean_per_phrase: List[MeanPerPhrase] | ||
mean_per_file: List[MeanPerFile] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import pytest | ||
import orjson | ||
from jobs.llm_job import compute_metrics | ||
from jobs.metrics.llm_metrics import LLMMetrics | ||
from jobs.models.llm_dataset import LLMMetricsModel | ||
from tests.results.llm_metrics_results import llm_metric_results | ||
import logging | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@pytest.fixture | ||
def mock_response(spark_fixture, test_data_dir): | ||
yield spark_fixture.read.option("multiline", "true").json( | ||
f"{test_data_dir}/llm/metrics.json" | ||
) | ||
|
||
|
||
def test_remove_columns(spark_fixture, mock_response): | ||
llm_metrics_service = LLMMetrics() | ||
df = llm_metrics_service.remove_columns(mock_response) | ||
assert "id" in df.columns | ||
assert "choices" in df.columns | ||
assert len(df.columns) == 2 | ||
|
||
|
||
def test_compute_prob(spark_fixture, mock_response): | ||
llm_metrics_service = LLMMetrics() | ||
df = llm_metrics_service.remove_columns(mock_response) | ||
df = llm_metrics_service.compute_prob(df) | ||
assert {"id", "logprob", "token", "prob"} == set(df.columns) | ||
assert not df.rdd.isEmpty() | ||
|
||
|
||
def test_extract_metrics(spark_fixture, mock_response): | ||
llm_metrics_service = LLMMetrics() | ||
llm_metrics_model = llm_metrics_service.extract_metrics(mock_response) | ||
assert isinstance(llm_metrics_model, LLMMetricsModel) | ||
|
||
|
||
def test_compute_metrics(spark_fixture, mock_response): | ||
complete_record = compute_metrics(mock_response) | ||
model_quality = complete_record.get("MODEL_QUALITY") | ||
assert model_quality == orjson.dumps(llm_metric_results).decode("utf-8") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
llm_metric_results = { | ||
"prob": [ | ||
{ | ||
"id": "chatcmpl-AcWID2SsE5iuK6z5AhNCKv3WUcCxN", | ||
"token": "Sure", | ||
"prob": 0.541987419128418, | ||
}, | ||
{ | ||
"id": "chatcmpl-AcWID2SsE5iuK6z5AhNCKv3WUcCxN", | ||
"token": ",", | ||
"prob": 0.9025230407714844, | ||
}, | ||
{ | ||
"id": "chatcmpl-AcWID2SsE5iuK6z5AhNCKv3WUcCxN", | ||
"token": " go", | ||
"prob": 0.07877199351787567, | ||
}, | ||
{ | ||
"id": "chatcmpl-AcWID2SsE5iuK6z5AhNCKv3WUcCxN", | ||
"token": " ahead", | ||
"prob": 0.9985936284065247, | ||
}, | ||
{ | ||
"id": "chatcmpl-AcWID2SsE5iuK6z5AhNCKv3WUcCxN", | ||
"token": ".", | ||
"prob": 0.13000132143497467, | ||
}, | ||
{ | ||
"id": "chatcmpl-AcWID2SsE5iuK6z5AhNCKv3WUcCxN", | ||
"token": " What's", | ||
"prob": 0.40887829661369324, | ||
}, | ||
{ | ||
"id": "chatcmpl-AcWID2SsE5iuK6z5AhNCKv3WUcCxN", | ||
"token": " up", | ||
"prob": 0.9211180806159973, | ||
}, | ||
{ | ||
"id": "chatcmpl-AcWID2SsE5iuK6z5AhNCKv3WUcCxN", | ||
"token": "?", | ||
"prob": 0.9999502301216125, | ||
}, | ||
{ | ||
"id": "chatcmpl-AcYMMPLnpkksCdLze3M8nnqQbfqVG", | ||
"token": "Certainly", | ||
"prob": 0.022015240043401718, | ||
}, | ||
{ | ||
"id": "chatcmpl-AcYMMPLnpkksCdLze3M8nnqQbfqVG", | ||
"token": "!", | ||
"prob": 0.8896080851554871, | ||
}, | ||
{ | ||
"id": "chatcmpl-AcYMMPLnpkksCdLze3M8nnqQbfqVG", | ||
"token": " Just", | ||
"prob": 0.0027362185064703226, | ||
}, | ||
{ | ||
"id": "chatcmpl-AcYMMPLnpkksCdLze3M8nnqQbfqVG", | ||
"token": " let", | ||
"prob": 0.5134729146957397, | ||
}, | ||
{ | ||
"id": "chatcmpl-AcYMMPLnpkksCdLze3M8nnqQbfqVG", | ||
"token": " me", | ||
"prob": 0.999944269657135, | ||
}, | ||
{ | ||
"id": "chatcmpl-AcYMMPLnpkksCdLze3M8nnqQbfqVG", | ||
"token": " know", | ||
"prob": 0.9991950988769531, | ||
}, | ||
{ | ||
"id": "chatcmpl-AcYMMPLnpkksCdLze3M8nnqQbfqVG", | ||
"token": " how", | ||
"prob": 0.49005329608917236, | ||
}, | ||
{ | ||
"id": "chatcmpl-AcYMMPLnpkksCdLze3M8nnqQbfqVG", | ||
"token": ".", | ||
"prob": 0.9663926959037781, | ||
}, | ||
], | ||
"mean_per_phrase": [ | ||
{ | ||
"id": "chatcmpl-AcWID2SsE5iuK6z5AhNCKv3WUcCxN", | ||
"prob_per_phrase": 0.6227279901504517, | ||
"perplex_per_phrase": 2.190884828567505, | ||
}, | ||
{ | ||
"id": "chatcmpl-AcYMMPLnpkksCdLze3M8nnqQbfqVG", | ||
"prob_per_phrase": 0.61042720079422, | ||
"perplex_per_phrase": 4.080123424530029, | ||
}, | ||
], | ||
"mean_per_file": [ | ||
{"prob_tot_mean": 0.6165775954723358, "perplex_tot_mean": 3.135504126548767} | ||
], | ||
} |