From 60d7c9f5af281e8a2e9f5082b7d681068b18611f Mon Sep 17 00:00:00 2001 From: Daniele Tria Date: Thu, 19 Dec 2024 14:11:43 +0100 Subject: [PATCH] add content message field to model quality metrics --- api/app/models/metrics/model_quality_dto.py | 1 + api/tests/commons/db_mock.py | 1 + spark/jobs/metrics/completion_metrics.py | 17 ++++++++++++----- spark/jobs/models/completion_dataset.py | 1 + spark/tests/completion_metrics_test.py | 2 +- .../tests/results/completion_metrics_results.py | 2 ++ 6 files changed, 18 insertions(+), 6 deletions(-) diff --git a/api/app/models/metrics/model_quality_dto.py b/api/app/models/metrics/model_quality_dto.py index 3478a229..677de3af 100644 --- a/api/app/models/metrics/model_quality_dto.py +++ b/api/app/models/metrics/model_quality_dto.py @@ -199,6 +199,7 @@ class TokenProb(BaseModel): class TokenData(BaseModel): id: str + message_content: str probs: List[TokenProb] diff --git a/api/tests/commons/db_mock.py b/api/tests/commons/db_mock.py index b1d6188b..abe2d980 100644 --- a/api/tests/commons/db_mock.py +++ b/api/tests/commons/db_mock.py @@ -550,6 +550,7 @@ def get_sample_completion_dataset( 'tokens': [ { 'id': 'chatcmpl', + 'message_content': 'Sky is blue.', 'probs': [ {'prob': 0.27718424797058105, 'token': 'Sky'}, {'prob': 0.8951022028923035, 'token': ' is'}, diff --git a/spark/jobs/metrics/completion_metrics.py b/spark/jobs/metrics/completion_metrics.py index 51ab9c0f..415f583b 100644 --- a/spark/jobs/metrics/completion_metrics.py +++ b/spark/jobs/metrics/completion_metrics.py @@ -38,20 +38,25 @@ def remove_columns(df: DataFrame) -> DataFrame: return df def compute_prob(self, df: DataFrame): - df = df.select(F.explode("choices").alias("element"), F.col("id")) df = df.select( - F.col("id"), F.explode("element.logprobs.content").alias("content") + F.explode("choices").alias("element"), + F.col("id"), ) - df = df.select("id", "content.logprob", "content.token").withColumn( - "prob", self.compute_probability_udf("logprob") + df = df.select( + F.col("id"), + F.col("element.message.content").alias("message_content"), + F.explode("element.logprobs.content").alias("content"), ) + df = df.select( + "id", "message_content", "content.logprob", "content.token" + ).withColumn("prob", self.compute_probability_udf("logprob")) return df def extract_metrics(self, df: DataFrame) -> CompletionMetricsModel: df = self.remove_columns(df) df = self.compute_prob(df) df_prob = df.drop("logprob") - df_prob = df_prob.groupBy("id").agg( + df_prob = df_prob.groupBy("id", "message_content").agg( F.collect_list(F.struct("token", "prob")).alias("probs") ) df_mean_values = df.groupBy("id").agg( @@ -66,9 +71,11 @@ def extract_metrics(self, df: DataFrame) -> CompletionMetricsModel: F.mean("prob_per_phrase").alias("prob_tot_mean"), F.mean("perplex_per_phrase").alias("perplex_tot_mean"), ) + df_prob = df_prob.orderBy("id") tokens = [ { "id": row["id"], + "message_content": row["message_content"], "probs": [ {"token": prob["token"], "prob": prob["prob"]} for prob in row["probs"] diff --git a/spark/jobs/models/completion_dataset.py b/spark/jobs/models/completion_dataset.py index 3169b1de..420c5ff6 100644 --- a/spark/jobs/models/completion_dataset.py +++ b/spark/jobs/models/completion_dataset.py @@ -9,6 +9,7 @@ class Prob(BaseModel): class Probs(BaseModel): id: str + message_content: str probs: List[Prob] model_config = ConfigDict(ser_json_inf_nan="null") diff --git a/spark/tests/completion_metrics_test.py b/spark/tests/completion_metrics_test.py index 19dcb446..7eb387c6 100644 --- a/spark/tests/completion_metrics_test.py +++ b/spark/tests/completion_metrics_test.py @@ -25,7 +25,7 @@ def test_compute_prob(spark_fixture, input_file): completion_metrics_service = CompletionMetrics() df = completion_metrics_service.remove_columns(input_file) df = completion_metrics_service.compute_prob(df) - assert {"id", "logprob", "token", "prob"} == set(df.columns) + assert {"id", "logprob", "message_content", "token", "prob"} == set(df.columns) assert not df.rdd.isEmpty() diff --git a/spark/tests/results/completion_metrics_results.py b/spark/tests/results/completion_metrics_results.py index b98aa14b..51bac88b 100644 --- a/spark/tests/results/completion_metrics_results.py +++ b/spark/tests/results/completion_metrics_results.py @@ -2,6 +2,7 @@ "tokens": [ { "id": "chatcmpl-AcWID2SsE5iuK6z5AhNCKv3WUcCxN", + "message_content": "Sure, go ahead. What's up?", "probs": [ {"token": "Sure", "prob": 0.541987419128418}, {"token": ",", "prob": 0.9025230407714844}, @@ -15,6 +16,7 @@ }, { "id": "chatcmpl-AcYMMPLnpkksCdLze3M8nnqQbfqVG", + "message_content": "Certainly! Just let me know how.", "probs": [ {"token": "Certainly", "prob": 0.022015240043401718}, {"token": "!", "prob": 0.8896080851554871},