Skip to content

Commit

Permalink
feat: add content message to model quality text generation (#220)
Browse files Browse the repository at this point in the history
* add content message field to model quality metrics

* add content message field to model quality metrics
  • Loading branch information
dtria91 authored Dec 19, 2024
1 parent baadd9e commit 50832ac
Show file tree
Hide file tree
Showing 8 changed files with 22 additions and 7 deletions.
1 change: 1 addition & 0 deletions api/app/models/metrics/model_quality_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ class TokenProb(BaseModel):

class TokenData(BaseModel):
id: str
message_content: str
probs: List[TokenProb]


Expand Down
1 change: 1 addition & 0 deletions api/tests/commons/db_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,7 @@ def get_sample_completion_dataset(
'tokens': [
{
'id': 'chatcmpl',
'message_content': 'Sky is blue.',
'probs': [
{'prob': 0.27718424797058105, 'token': 'Sky'},
{'prob': 0.8951022028923035, 'token': ' is'},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ class TokenProb(BaseModel):

class TokenData(BaseModel):
id: str
message_content: str
probs: List[TokenProb]


Expand Down
4 changes: 3 additions & 1 deletion sdk/tests/apis/model_completion_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ def test_text_generation_model_quality_ok(self):
{
"tokens": [
{
"id":"chatcmpl",
"id": "chatcmpl",
"message_content": "Sky is blue.",
"probs":[
{
"prob":0.27,
Expand Down Expand Up @@ -84,6 +85,7 @@ def test_text_generation_model_quality_ok(self):
metrics = model_completion_dataset.model_quality()

assert isinstance(metrics, CompletionTextGenerationModelQuality)
assert metrics.tokens[0].message_content == 'Sky is blue.'
assert metrics.tokens[0].probs[0].prob == 0.27
assert metrics.tokens[0].probs[0].token == 'Sky'
assert metrics.mean_per_file[0].prob_tot_mean == 0.71
Expand Down
17 changes: 12 additions & 5 deletions spark/jobs/metrics/completion_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,25 @@ def remove_columns(df: DataFrame) -> DataFrame:
return df

def compute_prob(self, df: DataFrame):
df = df.select(F.explode("choices").alias("element"), F.col("id"))
df = df.select(
F.col("id"), F.explode("element.logprobs.content").alias("content")
F.explode("choices").alias("element"),
F.col("id"),
)
df = df.select("id", "content.logprob", "content.token").withColumn(
"prob", self.compute_probability_udf("logprob")
df = df.select(
F.col("id"),
F.col("element.message.content").alias("message_content"),
F.explode("element.logprobs.content").alias("content"),
)
df = df.select(
"id", "message_content", "content.logprob", "content.token"
).withColumn("prob", self.compute_probability_udf("logprob"))
return df

def extract_metrics(self, df: DataFrame) -> CompletionMetricsModel:
df = self.remove_columns(df)
df = self.compute_prob(df)
df_prob = df.drop("logprob")
df_prob = df_prob.groupBy("id").agg(
df_prob = df_prob.groupBy("id", "message_content").agg(
F.collect_list(F.struct("token", "prob")).alias("probs")
)
df_mean_values = df.groupBy("id").agg(
Expand All @@ -66,9 +71,11 @@ def extract_metrics(self, df: DataFrame) -> CompletionMetricsModel:
F.mean("prob_per_phrase").alias("prob_tot_mean"),
F.mean("perplex_per_phrase").alias("perplex_tot_mean"),
)
df_prob = df_prob.orderBy("id")
tokens = [
{
"id": row["id"],
"message_content": row["message_content"],
"probs": [
{"token": prob["token"], "prob": prob["prob"]}
for prob in row["probs"]
Expand Down
1 change: 1 addition & 0 deletions spark/jobs/models/completion_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class Prob(BaseModel):

class Probs(BaseModel):
id: str
message_content: str
probs: List[Prob]

model_config = ConfigDict(ser_json_inf_nan="null")
Expand Down
2 changes: 1 addition & 1 deletion spark/tests/completion_metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_compute_prob(spark_fixture, input_file):
completion_metrics_service = CompletionMetrics()
df = completion_metrics_service.remove_columns(input_file)
df = completion_metrics_service.compute_prob(df)
assert {"id", "logprob", "token", "prob"} == set(df.columns)
assert {"id", "logprob", "message_content", "token", "prob"} == set(df.columns)
assert not df.rdd.isEmpty()


Expand Down
2 changes: 2 additions & 0 deletions spark/tests/results/completion_metrics_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"tokens": [
{
"id": "chatcmpl-AcWID2SsE5iuK6z5AhNCKv3WUcCxN",
"message_content": "Sure, go ahead. What's up?",
"probs": [
{"token": "Sure", "prob": 0.541987419128418},
{"token": ",", "prob": 0.9025230407714844},
Expand All @@ -15,6 +16,7 @@
},
{
"id": "chatcmpl-AcYMMPLnpkksCdLze3M8nnqQbfqVG",
"message_content": "Certainly! Just let me know how.",
"probs": [
{"token": "Certainly", "prob": 0.022015240043401718},
{"token": "!", "prob": 0.8896080851554871},
Expand Down

0 comments on commit 50832ac

Please sign in to comment.