Skip to content

Commit

Permalink
feat(spark): add completion spark job
Browse files Browse the repository at this point in the history
  • Loading branch information
carlopignatiello committed Dec 12, 2024
1 parent 5e67bd9 commit 088291b
Show file tree
Hide file tree
Showing 9 changed files with 125 additions and 233 deletions.
85 changes: 0 additions & 85 deletions spark/jobs/llm_job.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
from pyspark.sql import DataFrame
from pyspark.sql.types import FloatType
from models.llm_dataset import LLMMetricsModel
from models.completion_dataset import LLMMetricsModel


class LLMMetrics:
Expand Down Expand Up @@ -50,6 +50,9 @@ def extract_metrics(self, df: DataFrame) -> LLMMetricsModel:
df = self.remove_columns(df)
df = self.compute_prob(df)
df_prob = df.drop("logprob")
df_prob = df_prob.groupBy("id").agg(
F.collect_list(F.struct("token", "prob")).alias("probs")
)
df_mean_values = df.groupBy("id").agg(
self.compute_prob_mean_per_phrase(F.collect_list("prob")).alias(
"prob_per_phrase"
Expand All @@ -62,8 +65,19 @@ def extract_metrics(self, df: DataFrame) -> LLMMetricsModel:
F.mean("prob_per_phrase").alias("prob_tot_mean"),
F.mean("perplex_per_phrase").alias("perplex_tot_mean"),
)
tokens = [
{
"id": row["id"],
"probs": [
{"token": prob["token"], "prob": prob["prob"]}
for prob in row["probs"]
],
}
for row in df_prob.toLocalIterator()
]

res = {
"prob": df_prob.toPandas().to_dict(orient="records"),
"tokens": tokens,
"mean_per_phrase": df_mean_values.toPandas().to_dict(orient="records"),
"mean_per_file": df.toPandas().to_dict(orient="records"),
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,25 +1,35 @@
from pydantic import BaseModel, confloat
from pydantic import BaseModel, confloat, ConfigDict
from typing import List


class Prob(BaseModel):
id: str
token: str
prob: confloat(ge=0, le=1)


class Probs(BaseModel):
id: str
probs: List[Prob]

model_config = ConfigDict(ser_json_inf_nan="null")


class MeanPerPhrase(BaseModel):
id: str
prob_per_phrase: confloat(ge=0, le=1)
perplex_per_phrase: confloat(ge=1)

model_config = ConfigDict(ser_json_inf_nan="null")


class MeanPerFile(BaseModel):
prob_tot_mean: confloat(ge=0, le=1)
perplex_tot_mean: confloat(ge=1)

model_config = ConfigDict(ser_json_inf_nan="null")


class LLMMetricsModel(BaseModel):
prob: List[Prob]
tokens: List[Probs]
mean_per_phrase: List[MeanPerPhrase]
mean_per_file: List[MeanPerFile]
42 changes: 42 additions & 0 deletions spark/tests/completion_metrics_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import pytest
from jobs.completion_job import compute_metrics
from jobs.metrics.completion_metrics import LLMMetrics
from jobs.models.completion_dataset import LLMMetricsModel
from tests.results.completion_metrics_results import completion_metric_results


@pytest.fixture
def input_file(spark_fixture, test_data_dir):
yield spark_fixture.read.option("multiline", "true").json(
f"{test_data_dir}/completion/metrics.json"
)


def test_remove_columns(spark_fixture, input_file):
llm_metrics_service = LLMMetrics()
df = llm_metrics_service.remove_columns(input_file)
assert "id" in df.columns
assert "choices" in df.columns
assert len(df.columns) == 2


def test_compute_prob(spark_fixture, input_file):
llm_metrics_service = LLMMetrics()
df = llm_metrics_service.remove_columns(input_file)
df = llm_metrics_service.compute_prob(df)
assert {"id", "logprob", "token", "prob"} == set(df.columns)
assert not df.rdd.isEmpty()


def test_extract_metrics(spark_fixture, input_file):
llm_metrics_service = LLMMetrics()
llm_metrics_model: LLMMetricsModel = llm_metrics_service.extract_metrics(input_file)
assert len(llm_metrics_model.tokens) > 0
assert len(llm_metrics_model.mean_per_phrase) > 0
assert len(llm_metrics_model.mean_per_file) > 0


def test_compute_metrics(spark_fixture, input_file):
complete_record = compute_metrics(input_file)
model_quality = complete_record.get("MODEL_QUALITY")
assert model_quality == completion_metric_results
44 changes: 0 additions & 44 deletions spark/tests/llm_metrics_test.py

This file was deleted.

4 changes: 4 additions & 0 deletions spark/tests/resources/completion/metrics.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[
{"id": "chatcmpl-AcWID2SsE5iuK6z5AhNCKv3WUcCxN", "choices": [{"finish_reason": "stop", "index": 0, "logprobs": {"content": [{"token": "Sure", "bytes": [83, 117, 114, 101], "logprob": -0.61251247, "top_logprobs": []}, {"token": ",", "bytes": [44], "logprob": -0.102561064, "top_logprobs": []}, {"token": " go", "bytes": [32, 103, 111], "logprob": -2.5411978, "top_logprobs": []}, {"token": " ahead", "bytes": [32, 97, 104, 101, 97, 100], "logprob": -0.0014073749, "top_logprobs": []}, {"token": ".", "bytes": [46], "logprob": -2.0402107, "top_logprobs": []}, {"token": " What's", "bytes": [32, 87, 104, 97, 116, 39, 115], "logprob": -0.8943377, "top_logprobs": []}, {"token": " up", "bytes": [32, 117, 112], "logprob": -0.08216706, "top_logprobs": []}, {"token": "?", "bytes": [63], "logprob": -4.978234e-05, "top_logprobs": []}], "refusal": null}, "message": {"content": "Sure, go ahead. What's up?", "refusal": null, "role": "assistant", "tool_calls": [], "parsed": null}}], "created": 1733743961, "model": "gpt-4o-2024-08-06", "object": "chat.completion", "system_fingerprint": "afnfwuinawufwa", "usage": {"completion_tokens": 8, "prompt_tokens": 45, "total_tokens": 53, "completion_tokens_details": {"accepted_prediction_tokens": 0, "audio_tokens": 0, "reasoning_tokens": 0, "rejected_prediction_tokens": 0}, "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0}}},
{"id": "chatcmpl-AcYMMPLnpkksCdLze3M8nnqQbfqVG", "choices": [{"finish_reason": "stop", "index": 0, "logprobs": {"content": [{"token": "Certainly", "bytes": [67, 101, 114, 116, 97, 105, 110, 108, 121], "logprob": -3.8160203, "top_logprobs": []}, {"token": "!", "bytes": [33], "logprob": -0.11697425, "top_logprobs": []}, {"token": " Just", "bytes": [32, 74, 117, 115, 116], "logprob": -5.9011784, "top_logprobs": []}, {"token": " let", "bytes": [32, 108, 101, 116], "logprob": -0.666558, "top_logprobs": []}, {"token": " me", "bytes": [32, 109, 101], "logprob": -5.574252e-05, "top_logprobs": []}, {"token": " know", "bytes": [32, 107, 110, 111, 119], "logprob": -0.0008052219, "top_logprobs": []}, {"token": " how", "bytes": [32, 104, 111, 119], "logprob": -0.7132411, "top_logprobs": []}, {"token": ".", "bytes": [46], "logprob": -0.034184996, "top_logprobs": []}], "refusal": null}, "message": {"content": "Certainly! Just let me know how.", "refusal": null, "role": "assistant", "tool_calls": [], "parsed": null}}], "created": 1733751906, "model": "gpt-4o-2024-08-06", "object": "chat.completion", "system_fingerprint": "afnfwuinawufwa", "usage": {"completion_tokens": 8, "prompt_tokens": 45, "total_tokens": 53, "completion_tokens_details": {"accepted_prediction_tokens": 0, "audio_tokens": 0, "reasoning_tokens": 0, "rejected_prediction_tokens": 0}, "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0}}}
]
5 changes: 5 additions & 0 deletions spark/tests/resources/completion/metrics_one.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[
{"id": "chatcmpl-AcYPBK1t4QUvRG1oiS2JjoK0xdrGH", "choices": [{"finish_reason": "stop", "index": 0, "logprobs": {"content": [{"token": "Sure", "bytes": [83, 117, 114, 101], "logprob": -0.46845868, "top_logprobs": []}, {"token": "!", "bytes": [33], "logprob": -0.82938546, "top_logprobs": []}, {"token": " Please", "bytes": [32, 80, 108, 101, 97, 115, 101], "logprob": -2.1082883, "top_logprobs": []}, {"token": " ask", "bytes": [32, 97, 115, 107], "logprob": -0.67591065, "top_logprobs": []}, {"token": " your", "bytes": [32, 121, 111, 117, 114], "logprob": -0.022492433, "top_logprobs": []}, {"token": " question", "bytes": [32, 113, 117, 101, 115, 116, 105, 111, 110], "logprob": -0.13909559, "top_logprobs": []}, {"token": ".", "bytes": [46], "logprob": -3.3818815, "top_logprobs": []}, {"token": " I'll", "bytes": [32, 73, 39, 108, 108], "logprob": -0.4059926, "top_logprobs": []}, {"token": " be", "bytes": [32, 98, 101], "logprob": -1.799196, "top_logprobs": []}, {"token": " concise", "bytes": [32, 99, 111, 110, 99, 105, 115, 101], "logprob": -0.23712571, "top_logprobs": []}, {"token": ".", "bytes": [46], "logprob": -0.03838387, "top_logprobs": []}], "refusal": null}, "message": {"content": "Sure! Please ask your question. I'll be concise.", "refusal": null, "role": "assistant", "tool_calls": [], "parsed": null}}], "created": 1733752081, "model": "gpt-4o-2024-08-06", "object": "chat.completion", "system_fingerprint": "dwadawdwd", "usage": {"completion_tokens": 11, "prompt_tokens": 45, "total_tokens": 56, "completion_tokens_details": {"accepted_prediction_tokens": 0, "audio_tokens": 0, "reasoning_tokens": 0, "rejected_prediction_tokens": 0}, "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0}}},
{"id": "chatcmpl-AcYQfMsRAIA01ynfZAQb8Zmyc6WKp", "choices": [{"finish_reason": "stop", "index": 0, "logprobs": {"content": [{"token": "Yes", "bytes": [89, 101, 115], "logprob": -5.5577775e-06, "top_logprobs": []}, {"token": ",", "bytes": [44], "logprob": -5.2285613e-05, "top_logprobs": []}, {"token": " Python", "bytes": [32, 80, 121, 116, 104, 111, 110], "logprob": -0.00089550123, "top_logprobs": []}, {"token": " is", "bytes": [32, 105, 115], "logprob": -4.246537e-06, "top_logprobs": []}, {"token": " versatile", "bytes": [32, 118, 101, 114, 115, 97, 116, 105, 108, 101], "logprob": -5.1374755, "top_logprobs": []}, {"token": " and", "bytes": [32, 97, 110, 100], "logprob": -0.22549273, "top_logprobs": []}, {"token": " widely", "bytes": [32, 119, 105, 100, 101, 108, 121], "logprob": -0.06408858, "top_logprobs": []}, {"token": " used", "bytes": [32, 117, 115, 101, 100], "logprob": -0.0007200573, "top_logprobs": []}, {"token": " for", "bytes": [32, 102, 111, 114], "logprob": -0.24617708, "top_logprobs": []}, {"token": " data", "bytes": [32, 100, 97, 116, 97], "logprob": -2.132315, "top_logprobs": []}, {"token": " analysis", "bytes": [32, 97, 110, 97, 108, 121, 115, 105, 115], "logprob": -0.13818723, "top_logprobs": []}, {"token": ",", "bytes": [44], "logprob": -0.10035052, "top_logprobs": []}, {"token": " machine", "bytes": [32, 109, 97, 99, 104, 105, 110, 101], "logprob": -0.33921197, "top_logprobs": []}, {"token": " learning", "bytes": [32, 108, 101, 97, 114, 110, 105, 110, 103], "logprob": 0.0, "top_logprobs": []}, {"token": ",", "bytes": [44], "logprob": -9.138441e-05, "top_logprobs": []}, {"token": " and", "bytes": [32, 97, 110, 100], "logprob": -0.005160108, "top_logprobs": []}, {"token": " monitoring", "bytes": [32, 109, 111, 110, 105, 116, 111, 114, 105, 110, 103], "logprob": -1.1899015, "top_logprobs": []}, {"token": " tasks", "bytes": [32, 116, 97, 115, 107, 115], "logprob": -0.50267833, "top_logprobs": []}, {"token": " due", "bytes": [32, 100, 117, 101], "logprob": -1.4371668, "top_logprobs": []}, {"token": " to", "bytes": [32, 116, 111], "logprob": -8.418666e-06, "top_logprobs": []}, {"token": " its", "bytes": [32, 105, 116, 115], "logprob": -6.9882217e-06, "top_logprobs": []}, {"token": " ease", "bytes": [32, 101, 97, 115, 101], "logprob": -3.047081, "top_logprobs": []}, {"token": " of", "bytes": [32, 111, 102], "logprob": -5.5265704e-05, "top_logprobs": []}, {"token": " use", "bytes": [32, 117, 115, 101], "logprob": -0.000113794704, "top_logprobs": []}, {"token": " and", "bytes": [32, 97, 110, 100], "logprob": -0.0067897392, "top_logprobs": []}, {"token": " robust", "bytes": [32, 114, 111, 98, 117, 115, 116], "logprob": -2.8805246, "top_logprobs": []}, {"token": " libraries", "bytes": [32, 108, 105, 98, 114, 97, 114, 105, 101, 115], "logprob": -0.05405761, "top_logprobs": []}, {"token": ".", "bytes": [46], "logprob": -0.0008492942, "top_logprobs": []}], "refusal": null}, "message": {"content": "Yes, Python is versatile and widely used for data analysis, machine learning, and monitoring tasks due to its ease of use and robust libraries.", "refusal": null, "role": "assistant", "tool_calls": [], "parsed": null}}], "created": 1733752173, "model": "gpt-4o-2024-08-06", "object": "chat.completion", "system_fingerprint": "fawfaw", "usage": {"completion_tokens": 28, "prompt_tokens": 45, "total_tokens": 73, "completion_tokens_details": {"accepted_prediction_tokens": 0, "audio_tokens": 0, "reasoning_tokens": 0, "rejected_prediction_tokens": 0}, "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0}}},
{"id": "chatcmpl-AcYR545NY6eo8xAdKuEbE60YgrF1a", "choices": [{"finish_reason": "stop", "index": 0, "logprobs": {"content": [{"token": "Yes", "bytes": [89, 101, 115], "logprob": -0.00012582695, "top_logprobs": []}, {"token": ",", "bytes": [44], "logprob": -2.8206474e-05, "top_logprobs": []}, {"token": " Rust", "bytes": [32, 82, 117, 115, 116], "logprob": -0.0001658757, "top_logprobs": []}, {"token": " is", "bytes": [32, 105, 115], "logprob": -0.0629955, "top_logprobs": []}, {"token": " known", "bytes": [32, 107, 110, 111, 119, 110], "logprob": -0.12907438, "top_logprobs": []}, {"token": " for", "bytes": [32, 102, 111, 114], "logprob": -6.704273e-07, "top_logprobs": []}, {"token": " its", "bytes": [32, 105, 116, 115], "logprob": -0.053365633, "top_logprobs": []}, {"token": " performance", "bytes": [32, 112, 101, 114, 102, 111, 114, 109, 97, 110, 99, 101], "logprob": -0.23435384, "top_logprobs": []}, {"token": ",", "bytes": [44], "logprob": -0.0018123905, "top_logprobs": []}, {"token": " memory", "bytes": [32, 109, 101, 109, 111, 114, 121], "logprob": -0.97972125, "top_logprobs": []}, {"token": " safety", "bytes": [32, 115, 97, 102, 101, 116, 121], "logprob": -1.8908588e-05, "top_logprobs": []}, {"token": ",", "bytes": [44], "logprob": -0.0049610855, "top_logprobs": []}, {"token": " and", "bytes": [32, 97, 110, 100], "logprob": -5.3954464e-05, "top_logprobs": []}, {"token": " concurrency", "bytes": [32, 99, 111, 110, 99, 117, 114, 114, 101, 110, 99, 121], "logprob": -0.08939207, "top_logprobs": []}, {"token": " capabilities", "bytes": [32, 99, 97, 112, 97, 98, 105, 108, 105, 116, 105, 101, 115], "logprob": -2.2125401, "top_logprobs": []}, {"token": ".", "bytes": [46], "logprob": -0.38892052, "top_logprobs": []}], "refusal": null}, "message": {"content": "Yes, Rust is known for its performance, memory safety, and concurrency capabilities.", "refusal": null, "role": "assistant", "tool_calls": [], "parsed": null}}], "created": 1733752199, "model": "gpt-4o-2024-08-06", "object": "chat.completion", "system_fingerprint": "wafwafwf", "usage": {"completion_tokens": 16, "prompt_tokens": 46, "total_tokens": 62, "completion_tokens_details": {"accepted_prediction_tokens": 0, "audio_tokens": 0, "reasoning_tokens": 0, "rejected_prediction_tokens": 0}, "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0}}}
]
Loading

0 comments on commit 088291b

Please sign in to comment.