Skip to content

Commit

Permalink
feat: completion spark job (#210)
Browse files Browse the repository at this point in the history
* llm_metrics job with tests

* feat(spark): add completion spark job

* feat: add spark-test service

* fix(spark): completion job entry point

* fix(spark): minor fix

* fix(spark): renamed completion classes

---------

Co-authored-by: carlopignatiello <[email protected]>
  • Loading branch information
carlo-pignatiello and carlopignatiello authored Dec 13, 2024
1 parent 7821c9a commit ff2f911
Show file tree
Hide file tree
Showing 12 changed files with 300 additions and 2 deletions.
10 changes: 10 additions & 0 deletions docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,16 @@ services:
start_period: 5s
retries: 2

spark-test:
profiles: ["spark-test"]
build: ./spark-test
environment:
JOB_NAME: "completion"
KUBECONFIG_FILE_PATH: "/opt/kubeconfig/kubeconfig.yaml"
SPARK_IMAGE: "radicalbit-spark-py:develop"
volumes:
- ./docker/k3s_data/kubeconfig/kubeconfig.yaml:/opt/kubeconfig/kubeconfig.yaml

dind:
image: docker:dind
privileged: true
Expand Down
11 changes: 11 additions & 0 deletions spark-test/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
FROM python:3.11.8-slim

WORKDIR /spark-test

COPY requirements.txt requirements.txt

RUN pip install --no-cache-dir -r requirements.txt

COPY . .

CMD ["python3", "main.py"]
14 changes: 14 additions & 0 deletions spark-test/conf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
def create_secrets():
return {
"AWS_ACCESS_KEY_ID": "minio",
"AWS_SECRET_ACCESS_KEY": "minio123",
"AWS_REGION": "us-east-1",
"S3_ENDPOINT_URL": "http://minio:9000",
"POSTGRES_URL": "jdbc:postgresql://postgres:5432/radicalbit",
"POSTGRES_DB": "radicalbit",
"POSTGRES_HOST": "postgres",
"POSTGRES_PORT": "5432",
"POSTGRES_USER": "postgres",
"POSTGRES_PASSWORD": "postgres",
"POSTGRES_SCHEMA": "public",
}
37 changes: 37 additions & 0 deletions spark-test/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import os
from conf import create_secrets
from uuid import uuid4
from spark_on_k8s.k8s.sync_client import KubernetesClientManager
from spark_on_k8s.client import SparkOnK8S

envs = ["KUBECONFIG_FILE_PATH", "JOB_NAME", "SPARK_IMAGE"]

for var in envs:
if var not in os.environ:
raise EnvironmentError("Failed because {} is not set.".format(var))

kube_conf = os.environ["KUBECONFIG_FILE_PATH"]
job_name = os.environ["JOB_NAME"]
spark_image = os.environ["SPARK_IMAGE"]

k8s_client_manager = KubernetesClientManager(kube_conf)
spark_k8s_client = SparkOnK8S(k8s_client_manager=k8s_client_manager)

path = "s3a://test-bucket/metrics_one.json"

spark_k8s_client.submit_app(
image=spark_image,
app_path=f"local:///opt/spark/custom_jobs/{job_name}_job.py",
app_arguments=[
path,
str(uuid4()),
"completion_dataset_metrics",
"completion_dataset"
],
app_name=f"{spark_image}-completion-job",
namespace="spark",
service_account="spark",
app_waiter="no_wait",
image_pull_policy="IfNotPresent",
secret_values=create_secrets(),
)
1 change: 1 addition & 0 deletions spark-test/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
spark-on-k8s==0.10.1
10 changes: 8 additions & 2 deletions spark/jobs/completion_job.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import sys
import os
import uuid

import orjson
from pyspark.sql.types import StructField, StructType, StringType

from metrics.completion_metrics import CompletionMetrics
from utils.models import JobStatus
from utils.db import update_job_status, write_to_db

Expand All @@ -13,7 +15,11 @@

def compute_metrics(df: DataFrame) -> dict:
complete_record = {}
# TODO: compute model quality metrics
completion_service = CompletionMetrics()
model_quality = completion_service.extract_metrics(df)
complete_record["MODEL_QUALITY"] = orjson.dumps(model_quality.model_dump(serialize_as_any=True)).decode(
"utf-8"
)
return complete_record


Expand Down
85 changes: 85 additions & 0 deletions spark/jobs/metrics/completion_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import pyspark.sql.functions as F
import numpy as np
from pyspark.sql import DataFrame
from pyspark.sql.types import FloatType

from models.completion_dataset import CompletionMetricsModel


class CompletionMetrics:
def __init__(self):
pass

@staticmethod
@F.udf(FloatType())
def compute_probability_udf(log_prob: float) -> float:
return float(np.exp(log_prob))

@staticmethod
@F.udf(FloatType())
def compute_perplexity(log_probs: list[float]) -> float:
return float(np.exp(-np.mean(log_probs)))

@staticmethod
@F.udf(FloatType())
def compute_prob_mean_per_phrase(probs: list[float]) -> float:
return float(np.mean(probs))

@staticmethod
def remove_columns(df: DataFrame) -> DataFrame:
df = df.drop(
"model",
"object",
"created",
"system_fingerprint",
"usage",
"service_tier",
)
return df

def compute_prob(self, df: DataFrame):
df = df.select(F.explode("choices").alias("element"), F.col("id"))
df = df.select(
F.col("id"), F.explode("element.logprobs.content").alias("content")
)
df = df.select("id", "content.logprob", "content.token").withColumn(
"prob", self.compute_probability_udf("logprob")
)
return df

def extract_metrics(self, df: DataFrame) -> CompletionMetricsModel:
df = self.remove_columns(df)
df = self.compute_prob(df)
df_prob = df.drop("logprob")
df_prob = df_prob.groupBy("id").agg(
F.collect_list(F.struct("token", "prob")).alias("probs")
)
df_mean_values = df.groupBy("id").agg(
self.compute_prob_mean_per_phrase(F.collect_list("prob")).alias(
"prob_per_phrase"
),
self.compute_perplexity(F.collect_list("logprob")).alias(
"perplex_per_phrase"
),
)
df = df_mean_values.agg(
F.mean("prob_per_phrase").alias("prob_tot_mean"),
F.mean("perplex_per_phrase").alias("perplex_tot_mean"),
)
tokens = [
{
"id": row["id"],
"probs": [
{"token": prob["token"], "prob": prob["prob"]}
for prob in row["probs"]
],
}
for row in df_prob.toLocalIterator()
]

res = {
"tokens": tokens,
"mean_per_phrase": df_mean_values.toPandas().to_dict(orient="records"),
"mean_per_file": df.toPandas().to_dict(orient="records"),
}
return CompletionMetricsModel(**res)
35 changes: 35 additions & 0 deletions spark/jobs/models/completion_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from pydantic import BaseModel, confloat, ConfigDict
from typing import List


class Prob(BaseModel):
token: str
prob: confloat(ge=0, le=1)


class Probs(BaseModel):
id: str
probs: List[Prob]

model_config = ConfigDict(ser_json_inf_nan="null")


class MeanPerPhrase(BaseModel):
id: str
prob_per_phrase: confloat(ge=0, le=1)
perplex_per_phrase: confloat(ge=1)

model_config = ConfigDict(ser_json_inf_nan="null")


class MeanPerFile(BaseModel):
prob_tot_mean: confloat(ge=0, le=1)
perplex_tot_mean: confloat(ge=1)

model_config = ConfigDict(ser_json_inf_nan="null")


class CompletionMetricsModel(BaseModel):
tokens: List[Probs]
mean_per_phrase: List[MeanPerPhrase]
mean_per_file: List[MeanPerFile]
45 changes: 45 additions & 0 deletions spark/tests/completion_metrics_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pytest
import orjson
from jobs.completion_job import compute_metrics
from jobs.metrics.completion_metrics import CompletionMetrics
from jobs.models.completion_dataset import CompletionMetricsModel
from tests.results.completion_metrics_results import completion_metric_results


@pytest.fixture
def input_file(spark_fixture, test_data_dir):
yield spark_fixture.read.option("multiline", "true").json(
f"{test_data_dir}/completion/metrics.json"
)


def test_remove_columns(spark_fixture, input_file):
completion_metrics_service = CompletionMetrics()
df = completion_metrics_service.remove_columns(input_file)
assert "id" in df.columns
assert "choices" in df.columns
assert len(df.columns) == 2


def test_compute_prob(spark_fixture, input_file):
completion_metrics_service = CompletionMetrics()
df = completion_metrics_service.remove_columns(input_file)
df = completion_metrics_service.compute_prob(df)
assert {"id", "logprob", "token", "prob"} == set(df.columns)
assert not df.rdd.isEmpty()


def test_extract_metrics(spark_fixture, input_file):
completion_metrics_service = CompletionMetrics()
completion_metrics_model: CompletionMetricsModel = completion_metrics_service.extract_metrics(input_file)
assert len(completion_metrics_model.tokens) > 0
assert len(completion_metrics_model.mean_per_phrase) > 0
assert len(completion_metrics_model.mean_per_file) > 0


def test_compute_metrics(spark_fixture, input_file):
complete_record = compute_metrics(input_file)
model_quality = complete_record.get("MODEL_QUALITY")
assert model_quality == orjson.dumps(completion_metric_results).decode(
"utf-8"
)
4 changes: 4 additions & 0 deletions spark/tests/resources/completion/metrics.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[
{"id": "chatcmpl-AcWID2SsE5iuK6z5AhNCKv3WUcCxN", "choices": [{"finish_reason": "stop", "index": 0, "logprobs": {"content": [{"token": "Sure", "bytes": [83, 117, 114, 101], "logprob": -0.61251247, "top_logprobs": []}, {"token": ",", "bytes": [44], "logprob": -0.102561064, "top_logprobs": []}, {"token": " go", "bytes": [32, 103, 111], "logprob": -2.5411978, "top_logprobs": []}, {"token": " ahead", "bytes": [32, 97, 104, 101, 97, 100], "logprob": -0.0014073749, "top_logprobs": []}, {"token": ".", "bytes": [46], "logprob": -2.0402107, "top_logprobs": []}, {"token": " What's", "bytes": [32, 87, 104, 97, 116, 39, 115], "logprob": -0.8943377, "top_logprobs": []}, {"token": " up", "bytes": [32, 117, 112], "logprob": -0.08216706, "top_logprobs": []}, {"token": "?", "bytes": [63], "logprob": -4.978234e-05, "top_logprobs": []}], "refusal": null}, "message": {"content": "Sure, go ahead. What's up?", "refusal": null, "role": "assistant", "tool_calls": [], "parsed": null}}], "created": 1733743961, "model": "gpt-4o-2024-08-06", "object": "chat.completion", "system_fingerprint": "afnfwuinawufwa", "usage": {"completion_tokens": 8, "prompt_tokens": 45, "total_tokens": 53, "completion_tokens_details": {"accepted_prediction_tokens": 0, "audio_tokens": 0, "reasoning_tokens": 0, "rejected_prediction_tokens": 0}, "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0}}},
{"id": "chatcmpl-AcYMMPLnpkksCdLze3M8nnqQbfqVG", "choices": [{"finish_reason": "stop", "index": 0, "logprobs": {"content": [{"token": "Certainly", "bytes": [67, 101, 114, 116, 97, 105, 110, 108, 121], "logprob": -3.8160203, "top_logprobs": []}, {"token": "!", "bytes": [33], "logprob": -0.11697425, "top_logprobs": []}, {"token": " Just", "bytes": [32, 74, 117, 115, 116], "logprob": -5.9011784, "top_logprobs": []}, {"token": " let", "bytes": [32, 108, 101, 116], "logprob": -0.666558, "top_logprobs": []}, {"token": " me", "bytes": [32, 109, 101], "logprob": -5.574252e-05, "top_logprobs": []}, {"token": " know", "bytes": [32, 107, 110, 111, 119], "logprob": -0.0008052219, "top_logprobs": []}, {"token": " how", "bytes": [32, 104, 111, 119], "logprob": -0.7132411, "top_logprobs": []}, {"token": ".", "bytes": [46], "logprob": -0.034184996, "top_logprobs": []}], "refusal": null}, "message": {"content": "Certainly! Just let me know how.", "refusal": null, "role": "assistant", "tool_calls": [], "parsed": null}}], "created": 1733751906, "model": "gpt-4o-2024-08-06", "object": "chat.completion", "system_fingerprint": "afnfwuinawufwa", "usage": {"completion_tokens": 8, "prompt_tokens": 45, "total_tokens": 53, "completion_tokens_details": {"accepted_prediction_tokens": 0, "audio_tokens": 0, "reasoning_tokens": 0, "rejected_prediction_tokens": 0}, "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0}}}
]
5 changes: 5 additions & 0 deletions spark/tests/resources/completion/metrics_one.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[
{"id": "chatcmpl-AcYPBK1t4QUvRG1oiS2JjoK0xdrGH", "choices": [{"finish_reason": "stop", "index": 0, "logprobs": {"content": [{"token": "Sure", "bytes": [83, 117, 114, 101], "logprob": -0.46845868, "top_logprobs": []}, {"token": "!", "bytes": [33], "logprob": -0.82938546, "top_logprobs": []}, {"token": " Please", "bytes": [32, 80, 108, 101, 97, 115, 101], "logprob": -2.1082883, "top_logprobs": []}, {"token": " ask", "bytes": [32, 97, 115, 107], "logprob": -0.67591065, "top_logprobs": []}, {"token": " your", "bytes": [32, 121, 111, 117, 114], "logprob": -0.022492433, "top_logprobs": []}, {"token": " question", "bytes": [32, 113, 117, 101, 115, 116, 105, 111, 110], "logprob": -0.13909559, "top_logprobs": []}, {"token": ".", "bytes": [46], "logprob": -3.3818815, "top_logprobs": []}, {"token": " I'll", "bytes": [32, 73, 39, 108, 108], "logprob": -0.4059926, "top_logprobs": []}, {"token": " be", "bytes": [32, 98, 101], "logprob": -1.799196, "top_logprobs": []}, {"token": " concise", "bytes": [32, 99, 111, 110, 99, 105, 115, 101], "logprob": -0.23712571, "top_logprobs": []}, {"token": ".", "bytes": [46], "logprob": -0.03838387, "top_logprobs": []}], "refusal": null}, "message": {"content": "Sure! Please ask your question. I'll be concise.", "refusal": null, "role": "assistant", "tool_calls": [], "parsed": null}}], "created": 1733752081, "model": "gpt-4o-2024-08-06", "object": "chat.completion", "system_fingerprint": "dwadawdwd", "usage": {"completion_tokens": 11, "prompt_tokens": 45, "total_tokens": 56, "completion_tokens_details": {"accepted_prediction_tokens": 0, "audio_tokens": 0, "reasoning_tokens": 0, "rejected_prediction_tokens": 0}, "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0}}},
{"id": "chatcmpl-AcYQfMsRAIA01ynfZAQb8Zmyc6WKp", "choices": [{"finish_reason": "stop", "index": 0, "logprobs": {"content": [{"token": "Yes", "bytes": [89, 101, 115], "logprob": -5.5577775e-06, "top_logprobs": []}, {"token": ",", "bytes": [44], "logprob": -5.2285613e-05, "top_logprobs": []}, {"token": " Python", "bytes": [32, 80, 121, 116, 104, 111, 110], "logprob": -0.00089550123, "top_logprobs": []}, {"token": " is", "bytes": [32, 105, 115], "logprob": -4.246537e-06, "top_logprobs": []}, {"token": " versatile", "bytes": [32, 118, 101, 114, 115, 97, 116, 105, 108, 101], "logprob": -5.1374755, "top_logprobs": []}, {"token": " and", "bytes": [32, 97, 110, 100], "logprob": -0.22549273, "top_logprobs": []}, {"token": " widely", "bytes": [32, 119, 105, 100, 101, 108, 121], "logprob": -0.06408858, "top_logprobs": []}, {"token": " used", "bytes": [32, 117, 115, 101, 100], "logprob": -0.0007200573, "top_logprobs": []}, {"token": " for", "bytes": [32, 102, 111, 114], "logprob": -0.24617708, "top_logprobs": []}, {"token": " data", "bytes": [32, 100, 97, 116, 97], "logprob": -2.132315, "top_logprobs": []}, {"token": " analysis", "bytes": [32, 97, 110, 97, 108, 121, 115, 105, 115], "logprob": -0.13818723, "top_logprobs": []}, {"token": ",", "bytes": [44], "logprob": -0.10035052, "top_logprobs": []}, {"token": " machine", "bytes": [32, 109, 97, 99, 104, 105, 110, 101], "logprob": -0.33921197, "top_logprobs": []}, {"token": " learning", "bytes": [32, 108, 101, 97, 114, 110, 105, 110, 103], "logprob": 0.0, "top_logprobs": []}, {"token": ",", "bytes": [44], "logprob": -9.138441e-05, "top_logprobs": []}, {"token": " and", "bytes": [32, 97, 110, 100], "logprob": -0.005160108, "top_logprobs": []}, {"token": " monitoring", "bytes": [32, 109, 111, 110, 105, 116, 111, 114, 105, 110, 103], "logprob": -1.1899015, "top_logprobs": []}, {"token": " tasks", "bytes": [32, 116, 97, 115, 107, 115], "logprob": -0.50267833, "top_logprobs": []}, {"token": " due", "bytes": [32, 100, 117, 101], "logprob": -1.4371668, "top_logprobs": []}, {"token": " to", "bytes": [32, 116, 111], "logprob": -8.418666e-06, "top_logprobs": []}, {"token": " its", "bytes": [32, 105, 116, 115], "logprob": -6.9882217e-06, "top_logprobs": []}, {"token": " ease", "bytes": [32, 101, 97, 115, 101], "logprob": -3.047081, "top_logprobs": []}, {"token": " of", "bytes": [32, 111, 102], "logprob": -5.5265704e-05, "top_logprobs": []}, {"token": " use", "bytes": [32, 117, 115, 101], "logprob": -0.000113794704, "top_logprobs": []}, {"token": " and", "bytes": [32, 97, 110, 100], "logprob": -0.0067897392, "top_logprobs": []}, {"token": " robust", "bytes": [32, 114, 111, 98, 117, 115, 116], "logprob": -2.8805246, "top_logprobs": []}, {"token": " libraries", "bytes": [32, 108, 105, 98, 114, 97, 114, 105, 101, 115], "logprob": -0.05405761, "top_logprobs": []}, {"token": ".", "bytes": [46], "logprob": -0.0008492942, "top_logprobs": []}], "refusal": null}, "message": {"content": "Yes, Python is versatile and widely used for data analysis, machine learning, and monitoring tasks due to its ease of use and robust libraries.", "refusal": null, "role": "assistant", "tool_calls": [], "parsed": null}}], "created": 1733752173, "model": "gpt-4o-2024-08-06", "object": "chat.completion", "system_fingerprint": "fawfaw", "usage": {"completion_tokens": 28, "prompt_tokens": 45, "total_tokens": 73, "completion_tokens_details": {"accepted_prediction_tokens": 0, "audio_tokens": 0, "reasoning_tokens": 0, "rejected_prediction_tokens": 0}, "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0}}},
{"id": "chatcmpl-AcYR545NY6eo8xAdKuEbE60YgrF1a", "choices": [{"finish_reason": "stop", "index": 0, "logprobs": {"content": [{"token": "Yes", "bytes": [89, 101, 115], "logprob": -0.00012582695, "top_logprobs": []}, {"token": ",", "bytes": [44], "logprob": -2.8206474e-05, "top_logprobs": []}, {"token": " Rust", "bytes": [32, 82, 117, 115, 116], "logprob": -0.0001658757, "top_logprobs": []}, {"token": " is", "bytes": [32, 105, 115], "logprob": -0.0629955, "top_logprobs": []}, {"token": " known", "bytes": [32, 107, 110, 111, 119, 110], "logprob": -0.12907438, "top_logprobs": []}, {"token": " for", "bytes": [32, 102, 111, 114], "logprob": -6.704273e-07, "top_logprobs": []}, {"token": " its", "bytes": [32, 105, 116, 115], "logprob": -0.053365633, "top_logprobs": []}, {"token": " performance", "bytes": [32, 112, 101, 114, 102, 111, 114, 109, 97, 110, 99, 101], "logprob": -0.23435384, "top_logprobs": []}, {"token": ",", "bytes": [44], "logprob": -0.0018123905, "top_logprobs": []}, {"token": " memory", "bytes": [32, 109, 101, 109, 111, 114, 121], "logprob": -0.97972125, "top_logprobs": []}, {"token": " safety", "bytes": [32, 115, 97, 102, 101, 116, 121], "logprob": -1.8908588e-05, "top_logprobs": []}, {"token": ",", "bytes": [44], "logprob": -0.0049610855, "top_logprobs": []}, {"token": " and", "bytes": [32, 97, 110, 100], "logprob": -5.3954464e-05, "top_logprobs": []}, {"token": " concurrency", "bytes": [32, 99, 111, 110, 99, 117, 114, 114, 101, 110, 99, 121], "logprob": -0.08939207, "top_logprobs": []}, {"token": " capabilities", "bytes": [32, 99, 97, 112, 97, 98, 105, 108, 105, 116, 105, 101, 115], "logprob": -2.2125401, "top_logprobs": []}, {"token": ".", "bytes": [46], "logprob": -0.38892052, "top_logprobs": []}], "refusal": null}, "message": {"content": "Yes, Rust is known for its performance, memory safety, and concurrency capabilities.", "refusal": null, "role": "assistant", "tool_calls": [], "parsed": null}}], "created": 1733752199, "model": "gpt-4o-2024-08-06", "object": "chat.completion", "system_fingerprint": "wafwafwf", "usage": {"completion_tokens": 16, "prompt_tokens": 46, "total_tokens": 62, "completion_tokens_details": {"accepted_prediction_tokens": 0, "audio_tokens": 0, "reasoning_tokens": 0, "rejected_prediction_tokens": 0}, "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0}}}
]
Loading

0 comments on commit ff2f911

Please sign in to comment.