
feat: handling json file containing responses from a generative model #208

Merged (8 commits) on Dec 12, 2024
2 changes: 2 additions & 0 deletions api/alembic/env.py
@@ -11,6 +11,8 @@
from app.db.tables.reference_dataset_metrics_table import *
from app.db.tables.current_dataset_table import *
from app.db.tables.current_dataset_metrics_table import *
from app.db.tables.completion_dataset_table import *
from app.db.tables.completion_dataset_metrics_table import *
from app.db.tables.commons.json_encoded_dict import JSONEncodedDict
from app.db.database import Database, BaseTable

46 changes: 46 additions & 0 deletions api/alembic/versions/e72dc7aaa4cc_add_dataset_and_metrics_completion.py
@@ -0,0 +1,46 @@
"""add_dataset_and_metrics_completion

Revision ID: e72dc7aaa4cc
Revises: dccb82489f4d
Create Date: 2024-12-11 13:33:38.759485

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from app.db.tables.commons.json_encoded_dict import JSONEncodedDict

# revision identifiers, used by Alembic.
revision: str = 'e72dc7aaa4cc'
down_revision: Union[str, None] = 'dccb82489f4d'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('completion_dataset',
        sa.Column('UUID', sa.UUID(), nullable=False),
        sa.Column('MODEL_UUID', sa.UUID(), nullable=False),
        sa.Column('PATH', sa.VARCHAR(), nullable=False),
        sa.Column('DATE', sa.TIMESTAMP(timezone=True), nullable=False),
        sa.Column('STATUS', sa.VARCHAR(), nullable=False),
        sa.ForeignKeyConstraint(['MODEL_UUID'], ['model.UUID'], name=op.f('fk_completion_dataset_MODEL_UUID_model')),
        sa.PrimaryKeyConstraint('UUID', name=op.f('pk_completion_dataset'))
    )
    op.create_table('completion_dataset_metrics',
        sa.Column('UUID', sa.UUID(), nullable=False),
        sa.Column('COMPLETION_UUID', sa.UUID(), nullable=False),
        sa.Column('MODEL_QUALITY', JSONEncodedDict(astext_type=sa.Text()), nullable=True),
        sa.ForeignKeyConstraint(['COMPLETION_UUID'], ['completion_dataset.UUID'], name=op.f('fk_completion_dataset_metrics_COMPLETION_UUID_completion_dataset')),
        sa.PrimaryKeyConstraint('UUID', name=op.f('pk_completion_dataset_metrics'))
    )
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table('completion_dataset_metrics')
    op.drop_table('completion_dataset')
    # ### end Alembic commands ###
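For context, applying this revision creates both new tables. A minimal sketch using Alembic's Python API, assuming the ini file sits at api/alembic.ini (the usual CLI equivalent is alembic upgrade head):

from alembic import command
from alembic.config import Config

# The ini path is an assumption about the repo layout; adjust as needed.
cfg = Config('api/alembic.ini')
command.upgrade(cfg, 'e72dc7aaa4cc')  # or 'head' to apply all pending revisions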
1 change: 1 addition & 0 deletions api/app/core/config/config.py
@@ -56,6 +56,7 @@ class SparkConfig(BaseSettings):
    spark_image_pull_policy: str = 'IfNotPresent'
    spark_reference_app_path: str = 'local:///opt/spark/custom_jobs/reference_job.py'
    spark_current_app_path: str = 'local:///opt/spark/custom_jobs/current_job.py'
    spark_completion_app_path: str = 'local:///opt/spark/custom_jobs/completion_job.py'
    spark_namespace: str = 'spark'
    spark_service_account: str = 'spark'
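Since SparkConfig extends pydantic's BaseSettings, the new path should also be overridable from the environment. A hedged sketch — the exact variable name depends on settings options such as env_prefix, which are not shown in this diff:

import os

from app.core.config.config import SparkConfig

# Assumes BaseSettings' default case-insensitive mapping of field names to env vars.
os.environ['SPARK_COMPLETION_APP_PATH'] = 'local:///custom/completion_job.py'
config = SparkConfig()  # picks up the override at instantiation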

87 changes: 87 additions & 0 deletions api/app/db/dao/completion_dataset_dao.py
@@ -0,0 +1,87 @@
import re
from typing import List, Optional
from uuid import UUID

from fastapi_pagination import Page, Params
from fastapi_pagination.ext.sqlalchemy import paginate
from sqlalchemy import asc, desc
from sqlalchemy.future import select as future_select

from app.db.database import Database
from app.db.tables.completion_dataset_table import CompletionDataset
from app.models.dataset_dto import OrderType


class CompletionDatasetDAO:
    def __init__(self, database: Database) -> None:
        self.db = database

    def insert_completion_dataset(
        self, completion_dataset: CompletionDataset
    ) -> CompletionDataset:
        with self.db.begin_session() as session:
            session.add(completion_dataset)
            session.flush()
            return completion_dataset

    def get_completion_dataset_by_model_uuid(
        self, model_uuid: UUID, completion_uuid: UUID
    ) -> Optional[CompletionDataset]:
        with self.db.begin_session() as session:
            return (
                session.query(CompletionDataset)
                .where(
                    CompletionDataset.model_uuid == model_uuid,
                    CompletionDataset.uuid == completion_uuid,
                )
                .one_or_none()
            )

    def get_latest_completion_dataset_by_model_uuid(
        self, model_uuid: UUID
    ) -> Optional[CompletionDataset]:
        with self.db.begin_session() as session:
            return (
                session.query(CompletionDataset)
                .order_by(desc(CompletionDataset.date))
                .where(CompletionDataset.model_uuid == model_uuid)
                .limit(1)
                .one_or_none()
            )

    def get_all_completion_datasets_by_model_uuid(
        self,
        model_uuid: UUID,
    ) -> List[CompletionDataset]:
        with self.db.begin_session() as session:
            return (
                session.query(CompletionDataset)
                .order_by(desc(CompletionDataset.date))
                .where(CompletionDataset.model_uuid == model_uuid)
                .all()  # materialize before the session closes; matches the declared return type
            )

    def get_all_completion_datasets_by_model_uuid_paginated(
        self,
        model_uuid: UUID,
        params: Params = Params(),
        order: OrderType = OrderType.ASC,
        sort: Optional[str] = None,
    ) -> Page[CompletionDataset]:
        def order_by_column_name(column_name):
            # map camelCase sort keys (e.g. 'modelUuid') to snake_case attributes
            return getattr(
                CompletionDataset, re.sub('(?=[A-Z])', '_', column_name).lower()
            )

        with self.db.begin_session() as session:
            stmt = future_select(CompletionDataset).where(
                CompletionDataset.model_uuid == model_uuid
            )

            if sort:
                stmt = (
                    stmt.order_by(asc(order_by_column_name(sort)))
                    if order == OrderType.ASC
                    else stmt.order_by(desc(order_by_column_name(sort)))
                )

            return paginate(session, stmt, params)
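A short usage sketch of the new DAO; the Database constructor arguments and the model UUID below are placeholders, not taken from this PR:

from uuid import UUID

from fastapi_pagination import Params

from app.db.dao.completion_dataset_dao import CompletionDatasetDAO
from app.db.database import Database
from app.models.dataset_dto import OrderType

database = Database()  # hypothetical wiring; real construction args are not in this diff
dao = CompletionDatasetDAO(database)
model_uuid = UUID('00000000-0000-0000-0000-000000000000')  # placeholder

# Second page of a model's completion datasets, newest first by upload date.
page = dao.get_all_completion_datasets_by_model_uuid_paginated(
    model_uuid=model_uuid,
    params=Params(page=2, size=20),
    order=OrderType.DESC,
    sort='date',
)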
26 changes: 26 additions & 0 deletions api/app/db/tables/completion_dataset_metrics_table.py
@@ -0,0 +1,26 @@
from uuid import uuid4

from sqlalchemy import UUID, Column, ForeignKey

from app.db.dao.base_dao import BaseDAO
from app.db.database import BaseTable, Reflected
from app.db.tables.commons.json_encoded_dict import JSONEncodedDict


class CompletionDatasetMetrics(Reflected, BaseTable, BaseDAO):
    __tablename__ = 'completion_dataset_metrics'

    uuid = Column(
        'UUID',
        UUID(as_uuid=True),
        nullable=False,
        default=uuid4,
        primary_key=True,
    )
    completion_uuid = Column(
        'COMPLETION_UUID',
        UUID(as_uuid=True),
        ForeignKey('completion_dataset.UUID'),
        nullable=False,
    )
    model_quality = Column('MODEL_QUALITY', JSONEncodedDict, nullable=True)
25 changes: 25 additions & 0 deletions api/app/db/tables/completion_dataset_table.py
@@ -0,0 +1,25 @@
from uuid import uuid4

from sqlalchemy import TIMESTAMP, UUID, VARCHAR, Column, ForeignKey

from app.db.dao.base_dao import BaseDAO
from app.db.database import BaseTable, Reflected
from app.models.job_status import JobStatus


class CompletionDataset(Reflected, BaseTable, BaseDAO):
    __tablename__ = 'completion_dataset'

    uuid = Column(
        'UUID',
        UUID(as_uuid=True),
        nullable=False,
        default=uuid4,
        primary_key=True,
    )
    model_uuid = Column(
        'MODEL_UUID', UUID(as_uuid=True), ForeignKey('model.UUID'), nullable=False
    )
    path = Column('PATH', VARCHAR, nullable=False)
    date = Column('DATE', TIMESTAMP(timezone=True), nullable=False)
    status = Column('STATUS', VARCHAR, nullable=False, default=JobStatus.IMPORTING)
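As a sketch of how a row might be created: uuid and status can be omitted, since they fall back to the column defaults (uuid4 and JobStatus.IMPORTING) when the row is flushed. The path below is illustrative only:

from datetime import datetime, timezone
from uuid import UUID

from app.db.tables.completion_dataset_table import CompletionDataset

dataset = CompletionDataset(
    model_uuid=UUID('00000000-0000-0000-0000-000000000000'),  # placeholder model UUID
    path='s3://bucket/completions/responses.json',            # illustrative S3 path
    date=datetime.now(tz=timezone.utc),
)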
3 changes: 3 additions & 0 deletions api/app/main.py
@@ -10,6 +10,7 @@
from starlette.middleware.cors import CORSMiddleware

from app.core import get_config
from app.db.dao.completion_dataset_dao import CompletionDatasetDAO
from app.db.dao.current_dataset_dao import CurrentDatasetDAO
from app.db.dao.current_dataset_metrics_dao import CurrentDatasetMetricsDAO
from app.db.dao.model_dao import ModelDAO
@@ -54,6 +55,7 @@
reference_dataset_metrics_dao = ReferenceDatasetMetricsDAO(database)
current_dataset_dao = CurrentDatasetDAO(database)
current_dataset_metrics_dao = CurrentDatasetMetricsDAO(database)
completion_dataset_dao = CompletionDatasetDAO(database)

model_service = ModelService(
    model_dao=model_dao,
@@ -81,6 +83,7 @@
file_service = FileService(
    reference_dataset_dao,
    current_dataset_dao,
    completion_dataset_dao,
    model_service,
    s3_client,
    spark_k8s_client,
78 changes: 78 additions & 0 deletions api/app/models/completion_response.py
@@ -0,0 +1,78 @@
from typing import Dict, List, Optional

from pydantic import BaseModel, RootModel, model_validator


class TokenLogProbs(BaseModel):
    token: str
    bytes: List[int]
    logprob: float
    top_logprobs: List[Dict[str, float]]


class LogProbs(BaseModel):
    content: List[TokenLogProbs]
    refusal: Optional[str] = None


class Message(BaseModel):
    content: str
    refusal: Optional[str] = None
    role: str
    tool_calls: List = []
    parsed: Optional[dict] = None


class Choice(BaseModel):
    finish_reason: str
    index: int
    logprobs: Optional[LogProbs] = None
    message: Message

    @model_validator(mode='after')
    def validate_logprobs(self):
        # logprobs are required downstream to compute model-quality metrics
        if self.logprobs is None:
            raise ValueError(
                "the 'logprobs' field cannot be empty: metrics cannot be computed without it."
            )
        return self


class UsageDetails(BaseModel):
    accepted_prediction_tokens: int = 0
    reasoning_tokens: int = 0
    rejected_prediction_tokens: int = 0
    audio_tokens: Optional[int] = None
    cached_tokens: Optional[int] = None


class Usage(BaseModel):
    completion_tokens: int
    prompt_tokens: int
    total_tokens: int
    completion_tokens_details: UsageDetails
    prompt_tokens_details: Optional[UsageDetails] = None


class Completion(BaseModel):
    id: str
    choices: List[Choice]
    created: int
    model: str
    object: str
    system_fingerprint: str
    usage: Usage


class CompletionResponses(RootModel[List[Completion]]):
    @model_validator(mode='before')
    @classmethod
    def handle_single_completion(cls, data):
        """If a single object is passed instead of a list, wrap it into a list."""
        if isinstance(data, dict):
            return [data]
        if isinstance(data, list):
            return data
        raise ValueError(
            'Input file must be a single completion JSON object or a list of them'
        )
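A quick sketch of the intended behavior: a single completion object is wrapped into a one-element list by the 'before' validator, while a choice without logprobs is rejected by the 'after' validator. The payload is a minimal hand-written example, not a real model response:

from app.models.completion_response import CompletionResponses

single_completion = {
    'id': 'cmpl-123',
    'created': 1733918018,
    'model': 'example-model',
    'object': 'chat.completion',
    'system_fingerprint': 'fp_example',
    'usage': {
        'completion_tokens': 1,
        'prompt_tokens': 3,
        'total_tokens': 4,
        'completion_tokens_details': {},  # all UsageDetails fields have defaults
    },
    'choices': [
        {
            'finish_reason': 'stop',
            'index': 0,
            'message': {'content': 'Hi', 'role': 'assistant'},
            'logprobs': {
                'content': [
                    {
                        'token': 'Hi',
                        'bytes': [72, 105],
                        'logprob': -0.01,
                        'top_logprobs': [{'Hi': -0.01}],
                    }
                ]
            },
        }
    ],
}

responses = CompletionResponses.model_validate(single_completion)
assert len(responses.root) == 1  # the single dict was wrapped into a list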
23 changes: 23 additions & 0 deletions api/app/models/dataset_dto.py
@@ -5,6 +5,7 @@
from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel

from app.db.tables.completion_dataset_table import CompletionDataset
from app.db.tables.current_dataset_table import CurrentDataset
from app.db.tables.reference_dataset_table import ReferenceDataset

@@ -55,6 +56,28 @@ def from_current_dataset(cd: CurrentDataset) -> 'CurrentDatasetDTO':
        )


class CompletionDatasetDTO(BaseModel):
    uuid: UUID
    model_uuid: UUID
    path: str
    date: str
    status: str

    model_config = ConfigDict(
        populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
    )

    @staticmethod
    def from_completion_dataset(cd: CompletionDataset) -> 'CompletionDatasetDTO':
        return CompletionDatasetDTO(
            uuid=cd.uuid,
            model_uuid=cd.model_uuid,
            path=cd.path,
            date=cd.date.isoformat(),
            status=cd.status,
        )


class FileReference(BaseModel):
    file_url: str
    separator: str = ','
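A small sketch of the DTO conversion and its camelCase serialization via to_camel; the row values are placeholders:

from datetime import datetime, timezone
from uuid import uuid4

from app.db.tables.completion_dataset_table import CompletionDataset
from app.models.dataset_dto import CompletionDatasetDTO

row = CompletionDataset(
    uuid=uuid4(),
    model_uuid=uuid4(),
    path='s3://bucket/completions/responses.json',  # illustrative path
    date=datetime.now(tz=timezone.utc),
    status='IMPORTING',
)

dto = CompletionDatasetDTO.from_completion_dataset(row)
payload = dto.model_dump(by_alias=True)  # keys are camelCase, e.g. 'modelUuid'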