fix(spark): renamed completion classes
carlopignatiello committed Dec 13, 2024
1 parent fe80945 commit 844d7b1
Showing 4 changed files with 24 additions and 20 deletions.
4 changes: 2 additions & 2 deletions spark/jobs/completion_job.py
@@ -4,7 +4,7 @@
 import orjson
 from pyspark.sql.types import StructField, StructType, StringType
 
-from metrics.completion_metrics import LLMMetrics
+from metrics.completion_metrics import CompletionMetrics
 from utils.models import JobStatus
 from utils.db import update_job_status, write_to_db
 
@@ -15,7 +15,7 @@
 
 def compute_metrics(df: DataFrame) -> dict:
     complete_record = {}
-    completion_service = LLMMetrics()
+    completion_service = CompletionMetrics()
     model_quality = completion_service.extract_metrics(df)
     complete_record["MODEL_QUALITY"] = orjson.dumps(model_quality.model_dump(serialize_as_any=True)).decode(
         "utf-8"
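For context on the job side: compute_metrics dumps the returned Pydantic model to a JSON string before storing it under MODEL_QUALITY. A minimal, runnable sketch of that serialization pattern, with a hypothetical stand-in model (only the orjson/model_dump call chain mirrors the diff):

import orjson
from pydantic import BaseModel

class TinyMetrics(BaseModel):
    # Hypothetical stand-in for CompletionMetricsModel
    tokens: list

record = TinyMetrics(tokens=[{"token": "hi", "prob": 0.9}])
# Same chain as compute_metrics: Pydantic dump -> orjson bytes -> UTF-8 string
payload = orjson.dumps(record.model_dump(serialize_as_any=True)).decode("utf-8")
print(payload)  # {"tokens":[{"token":"hi","prob":0.9}]}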
9 changes: 5 additions & 4 deletions spark/jobs/metrics/completion_metrics.py
@@ -2,10 +2,11 @@
 import numpy as np
 from pyspark.sql import DataFrame
 from pyspark.sql.types import FloatType
-from models.completion_dataset import LLMMetricsModel
+
+from models.completion_dataset import CompletionMetricsModel
 
 
-class LLMMetrics:
+class CompletionMetrics:
     def __init__(self):
         pass
 
@@ -46,7 +47,7 @@ def compute_prob(self, df: DataFrame):
         )
         return df
 
-    def extract_metrics(self, df: DataFrame) -> LLMMetricsModel:
+    def extract_metrics(self, df: DataFrame) -> CompletionMetricsModel:
         df = self.remove_columns(df)
         df = self.compute_prob(df)
         df_prob = df.drop("logprob")
 
@@ -81,4 +82,4 @@ def extract_metrics(self, df: DataFrame) -> LLMMetricsModel:
             "mean_per_phrase": df_mean_values.toPandas().to_dict(orient="records"),
             "mean_per_file": df.toPandas().to_dict(orient="records"),
         }
-        return LLMMetricsModel(**res)
+        return CompletionMetricsModel(**res)
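Because this commit is a straight rename, any call site still importing LLMMetrics now fails at import time. If a transition period were needed, a module-level __getattr__ (PEP 562) at the bottom of completion_metrics.py could keep the old name importable with a warning; this shim is a hypothetical sketch, not part of the commit:

import warnings

def __getattr__(name):
    # Module-level fallback: resolves completion_metrics.LLMMetrics lazily
    if name == "LLMMetrics":
        warnings.warn(
            "LLMMetrics was renamed to CompletionMetrics",
            DeprecationWarning,
            stacklevel=2,
        )
        return CompletionMetrics
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")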
2 changes: 1 addition & 1 deletion spark/jobs/models/completion_dataset.py
@@ -29,7 +29,7 @@ class MeanPerFile(BaseModel):
     model_config = ConfigDict(ser_json_inf_nan="null")
 
 
-class LLMMetricsModel(BaseModel):
+class CompletionMetricsModel(BaseModel):
     tokens: List[Probs]
     mean_per_phrase: List[MeanPerPhrase]
     mean_per_file: List[MeanPerFile]
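For reference, the renamed model aggregates three row shapes. A runnable sketch of the hierarchy as far as this commit shows it; the inner fields of Probs, MeanPerPhrase, and MeanPerFile do not appear in the diff, so the placeholders below are hypothetical:

from typing import List

from pydantic import BaseModel, ConfigDict

class Probs(BaseModel):
    token: str   # hypothetical field
    prob: float  # hypothetical field

class MeanPerPhrase(BaseModel):
    prob_per_phrase: float  # hypothetical field

class MeanPerFile(BaseModel):
    prob_tot_mean: float  # hypothetical field
    # From the diff: serialize inf/NaN floats as JSON null rather than erroring
    model_config = ConfigDict(ser_json_inf_nan="null")

class CompletionMetricsModel(BaseModel):
    tokens: List[Probs]
    mean_per_phrase: List[MeanPerPhrase]
    mean_per_file: List[MeanPerFile]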
29 changes: 16 additions & 13 deletions spark/tests/completion_metrics_test.py
@@ -1,7 +1,8 @@
 import pytest
+import orjson
 from jobs.completion_job import compute_metrics
-from jobs.metrics.completion_metrics import LLMMetrics
-from jobs.models.completion_dataset import LLMMetricsModel
+from jobs.metrics.completion_metrics import CompletionMetrics
+from jobs.models.completion_dataset import CompletionMetricsModel
 from tests.results.completion_metrics_results import completion_metric_results
 
 
@@ -13,30 +14,32 @@ def input_file(spark_fixture, test_data_dir):
 
 
 def test_remove_columns(spark_fixture, input_file):
-    llm_metrics_service = LLMMetrics()
-    df = llm_metrics_service.remove_columns(input_file)
+    completion_metrics_service = CompletionMetrics()
+    df = completion_metrics_service.remove_columns(input_file)
     assert "id" in df.columns
     assert "choices" in df.columns
     assert len(df.columns) == 2
 
 
 def test_compute_prob(spark_fixture, input_file):
-    llm_metrics_service = LLMMetrics()
-    df = llm_metrics_service.remove_columns(input_file)
-    df = llm_metrics_service.compute_prob(df)
+    completion_metrics_service = CompletionMetrics()
+    df = completion_metrics_service.remove_columns(input_file)
+    df = completion_metrics_service.compute_prob(df)
     assert {"id", "logprob", "token", "prob"} == set(df.columns)
     assert not df.rdd.isEmpty()
 
 
 def test_extract_metrics(spark_fixture, input_file):
-    llm_metrics_service = LLMMetrics()
-    llm_metrics_model: LLMMetricsModel = llm_metrics_service.extract_metrics(input_file)
-    assert len(llm_metrics_model.tokens) > 0
-    assert len(llm_metrics_model.mean_per_phrase) > 0
-    assert len(llm_metrics_model.mean_per_file) > 0
+    completion_metrics_service = CompletionMetrics()
+    completion_metrics_model: CompletionMetricsModel = completion_metrics_service.extract_metrics(input_file)
+    assert len(completion_metrics_model.tokens) > 0
+    assert len(completion_metrics_model.mean_per_phrase) > 0
+    assert len(completion_metrics_model.mean_per_file) > 0
 
 
 def test_compute_metrics(spark_fixture, input_file):
     complete_record = compute_metrics(input_file)
     model_quality = complete_record.get("MODEL_QUALITY")
-    assert model_quality == completion_metric_results
+    assert model_quality == orjson.dumps(completion_metric_results).decode(
+        "utf-8"
+    )
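The tests depend on spark_fixture, test_data_dir, and input_file fixtures defined outside this diff. A hypothetical conftest.py sketch of what they plausibly look like (a local-mode session is conventional; the data path and JSON layout are assumptions):

import pytest
from pyspark.sql import SparkSession

@pytest.fixture(scope="session")
def spark_fixture():
    spark = (
        SparkSession.builder.master("local[1]")
        .appName("completion-metrics-tests")
        .getOrCreate()
    )
    yield spark
    spark.stop()

@pytest.fixture
def test_data_dir():
    return "tests/resources"  # assumed location of sample completion payloads

@pytest.fixture
def input_file(spark_fixture, test_data_dir):
    # Assumed: a multiline JSON file holding a completion response document
    return spark_fixture.read.option("multiline", "true").json(
        f"{test_data_dir}/completion_dataset.json"
    )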
