feat: handling json file containing responses from a generative model (#208)

* feat: add text-generation as new model type, handle the new model type, set schema fields as optional, edit test

* fix: ruff check

* feat: add optional schema fields to models definition (sdk)

* feat: set optional fields model schema (spark side)

* handle json response

* feat: editing json handling
dtria91 authored Dec 12, 2024
1 parent 40f0912 commit 5a82c90
Showing 16 changed files with 908 additions and 4 deletions.
2 changes: 2 additions & 0 deletions api/alembic/env.py
@@ -11,6 +11,8 @@
from app.db.tables.reference_dataset_metrics_table import *
from app.db.tables.current_dataset_table import *
from app.db.tables.current_dataset_metrics_table import *
from app.db.tables.completion_dataset_table import *
from app.db.tables.completion_dataset_metrics_table import *
from app.db.tables.commons.json_encoded_dict import JSONEncodedDict
from app.db.database import Database, BaseTable

@@ -0,0 +1,46 @@
"""add_dataset_and_metrics_completion
Revision ID: e72dc7aaa4cc
Revises: dccb82489f4d
Create Date: 2024-12-11 13:33:38.759485
"""
from typing import Sequence, Union, Text

from alembic import op
import sqlalchemy as sa
from app.db.tables.commons.json_encoded_dict import JSONEncodedDict

# revision identifiers, used by Alembic.
revision: str = 'e72dc7aaa4cc'
down_revision: Union[str, None] = 'dccb82489f4d'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('completion_dataset',
sa.Column('UUID', sa.UUID(), nullable=False),
sa.Column('MODEL_UUID', sa.UUID(), nullable=False),
sa.Column('PATH', sa.VARCHAR(), nullable=False),
sa.Column('DATE', sa.TIMESTAMP(timezone=True), nullable=False),
sa.Column('STATUS', sa.VARCHAR(), nullable=False),
sa.ForeignKeyConstraint(['MODEL_UUID'], ['model.UUID'], name=op.f('fk_completion_dataset_MODEL_UUID_model')),
sa.PrimaryKeyConstraint('UUID', name=op.f('pk_completion_dataset'))
)
op.create_table('completion_dataset_metrics',
sa.Column('UUID', sa.UUID(), nullable=False),
sa.Column('COMPLETION_UUID', sa.UUID(), nullable=False),
sa.Column('MODEL_QUALITY', JSONEncodedDict(astext_type=Text()), nullable=True),
sa.ForeignKeyConstraint(['COMPLETION_UUID'], ['completion_dataset.UUID'], name=op.f('fk_completion_dataset_metrics_COMPLETION_UUID_completion_dataset')),
sa.PrimaryKeyConstraint('UUID', name=op.f('pk_completion_dataset_metrics'))
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('completion_dataset_metrics')
op.drop_table('completion_dataset')
# ### end Alembic commands ###
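A hedged sketch of applying this revision programmatically; the alembic.ini location is an assumption about the repository layout:

```python
from alembic import command
from alembic.config import Config

# Path to alembic.ini is assumed -- adjust to your checkout.
cfg = Config('api/alembic.ini')
command.upgrade(cfg, 'e72dc7aaa4cc')  # or 'head' to apply all pending revisions
```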
1 change: 1 addition & 0 deletions api/app/core/config/config.py
@@ -56,6 +56,7 @@ class SparkConfig(BaseSettings):
    spark_image_pull_policy: str = 'IfNotPresent'
    spark_reference_app_path: str = 'local:///opt/spark/custom_jobs/reference_job.py'
    spark_current_app_path: str = 'local:///opt/spark/custom_jobs/current_job.py'
    spark_completion_app_path: str = 'local:///opt/spark/custom_jobs/completion_job.py'
    spark_namespace: str = 'spark'
    spark_service_account: str = 'spark'
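Since SparkConfig is a pydantic BaseSettings, the new path should be overridable from the environment; a small sketch, where the variable name assumes pydantic-settings' default case-insensitive field-to-env mapping (no env_prefix is visible in this diff):

```python
import os

from app.core.config.config import SparkConfig

# Assumed env var name per pydantic-settings' default field-name mapping.
os.environ['SPARK_COMPLETION_APP_PATH'] = 'local:///opt/spark/custom_jobs/my_completion_job.py'

print(SparkConfig().spark_completion_app_path)  # -> the overridden path
```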

87 changes: 87 additions & 0 deletions api/app/db/dao/completion_dataset_dao.py
@@ -0,0 +1,87 @@
import re
from typing import List, Optional
from uuid import UUID

from fastapi_pagination import Page, Params
from fastapi_pagination.ext.sqlalchemy import paginate
from sqlalchemy import asc, desc
from sqlalchemy.future import select as future_select

from app.db.database import Database
from app.db.tables.completion_dataset_table import CompletionDataset
from app.models.dataset_dto import OrderType


class CompletionDatasetDAO:
    def __init__(self, database: Database) -> None:
        self.db = database

    def insert_completion_dataset(
        self, completion_dataset: CompletionDataset
    ) -> CompletionDataset:
        with self.db.begin_session() as session:
            session.add(completion_dataset)
            session.flush()
            return completion_dataset

    def get_completion_dataset_by_model_uuid(
        self, model_uuid: UUID, completion_uuid: UUID
    ) -> Optional[CompletionDataset]:
        with self.db.begin_session() as session:
            return (
                session.query(CompletionDataset)
                .where(
                    CompletionDataset.model_uuid == model_uuid,
                    CompletionDataset.uuid == completion_uuid,
                )
                .one_or_none()
            )

    def get_latest_completion_dataset_by_model_uuid(
        self, model_uuid: UUID
    ) -> Optional[CompletionDataset]:
        with self.db.begin_session() as session:
            return (
                session.query(CompletionDataset)
                .order_by(desc(CompletionDataset.date))
                .where(CompletionDataset.model_uuid == model_uuid)
                .limit(1)
                .one_or_none()
            )

    def get_all_completion_datasets_by_model_uuid(
        self,
        model_uuid: UUID,
    ) -> List[CompletionDataset]:
        with self.db.begin_session() as session:
            return (
                session.query(CompletionDataset)
                .order_by(desc(CompletionDataset.date))
                .where(CompletionDataset.model_uuid == model_uuid)
                .all()  # materialize before the session closes; matches the List return type
            )

    def get_all_completion_datasets_by_model_uuid_paginated(
        self,
        model_uuid: UUID,
        params: Params = Params(),
        order: OrderType = OrderType.ASC,
        sort: Optional[str] = None,
    ) -> Page[CompletionDataset]:
        def order_by_column_name(column_name):
            # Map a camelCase sort key from the API onto the snake_case column attribute.
            return getattr(
                CompletionDataset, re.sub('(?=[A-Z])', '_', column_name).lower()
            )

        with self.db.begin_session() as session:
            stmt = future_select(CompletionDataset).where(
                CompletionDataset.model_uuid == model_uuid
            )

            if sort:
                stmt = (
                    stmt.order_by(asc(order_by_column_name(sort)))
                    if order == OrderType.ASC
                    else stmt.order_by(desc(order_by_column_name(sort)))
                )

            return paginate(session, stmt, params)
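The sort key arrives camelCase from the API and is mapped onto the snake_case column attribute by the lookahead regex above; a minimal standalone sketch of that transformation:

```python
import re

def to_snake(column_name: str) -> str:
    # Insert '_' before each uppercase letter, then lowercase:
    # 'modelUuid' -> 'model_uuid'; keys without capitals pass through.
    return re.sub('(?=[A-Z])', '_', column_name).lower()

assert to_snake('modelUuid') == 'model_uuid'
assert to_snake('date') == 'date'
```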
26 changes: 26 additions & 0 deletions api/app/db/tables/completion_dataset_metrics_table.py
@@ -0,0 +1,26 @@
from uuid import uuid4

from sqlalchemy import UUID, Column, ForeignKey

from app.db.dao.base_dao import BaseDAO
from app.db.database import BaseTable, Reflected
from app.db.tables.commons.json_encoded_dict import JSONEncodedDict


class CompletionDatasetMetrics(Reflected, BaseTable, BaseDAO):
    __tablename__ = 'completion_dataset_metrics'

    uuid = Column(
        'UUID',
        UUID(as_uuid=True),
        nullable=False,
        default=uuid4,
        primary_key=True,
    )
    completion_uuid = Column(
        'COMPLETION_UUID',
        UUID(as_uuid=True),
        ForeignKey('completion_dataset.UUID'),
        nullable=False,
    )
    model_quality = Column('MODEL_QUALITY', JSONEncodedDict, nullable=True)
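A sketch of building a metrics row for a completion dataset; this assumes JSONEncodedDict accepts plain dicts, and the metric keys shown are illustrative, not the job's actual output schema:

```python
from uuid import uuid4

from app.db.tables.completion_dataset_metrics_table import CompletionDatasetMetrics

# completion_uuid must reference an existing completion_dataset row;
# uuid4() here is purely illustrative.
metrics = CompletionDatasetMetrics(
    completion_uuid=uuid4(),
    model_quality={'perplexity': 12.3, 'prob_mean': 0.87},  # illustrative keys
)
```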
25 changes: 25 additions & 0 deletions api/app/db/tables/completion_dataset_table.py
@@ -0,0 +1,25 @@
from uuid import uuid4

from sqlalchemy import TIMESTAMP, UUID, VARCHAR, Column, ForeignKey

from app.db.dao.base_dao import BaseDAO
from app.db.database import BaseTable, Reflected
from app.models.job_status import JobStatus


class CompletionDataset(Reflected, BaseTable, BaseDAO):
    __tablename__ = 'completion_dataset'

    uuid = Column(
        'UUID',
        UUID(as_uuid=True),
        nullable=False,
        default=uuid4,
        primary_key=True,
    )
    model_uuid = Column(
        'MODEL_UUID', UUID(as_uuid=True), ForeignKey('model.UUID'), nullable=False
    )
    path = Column('PATH', VARCHAR, nullable=False)
    date = Column('DATE', TIMESTAMP(timezone=True), nullable=False)
    status = Column('STATUS', VARCHAR, nullable=False, default=JobStatus.IMPORTING)
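A hedged sketch of registering an uploaded completion file through the new DAO; the values are illustrative, and status falls back to JobStatus.IMPORTING via the column default:

```python
from datetime import datetime, timezone
from uuid import uuid4

from app.db.dao.completion_dataset_dao import CompletionDatasetDAO
from app.db.tables.completion_dataset_table import CompletionDataset

# 'database' is the app's Database instance, wired up as in app/main.py.
dao = CompletionDatasetDAO(database)
dataset = dao.insert_completion_dataset(
    CompletionDataset(
        model_uuid=uuid4(),                   # illustrative; use a real model UUID
        path='s3://bucket/completions.json',  # illustrative storage path
        date=datetime.now(tz=timezone.utc),
    )
)
```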
3 changes: 3 additions & 0 deletions api/app/main.py
@@ -10,6 +10,7 @@
from starlette.middleware.cors import CORSMiddleware

from app.core import get_config
from app.db.dao.completion_dataset_dao import CompletionDatasetDAO
from app.db.dao.current_dataset_dao import CurrentDatasetDAO
from app.db.dao.current_dataset_metrics_dao import CurrentDatasetMetricsDAO
from app.db.dao.model_dao import ModelDAO
@@ -54,6 +55,7 @@
reference_dataset_metrics_dao = ReferenceDatasetMetricsDAO(database)
current_dataset_dao = CurrentDatasetDAO(database)
current_dataset_metrics_dao = CurrentDatasetMetricsDAO(database)
completion_dataset_dao = CompletionDatasetDAO(database)

model_service = ModelService(
    model_dao=model_dao,
@@ -81,6 +83,7 @@
file_service = FileService(
    reference_dataset_dao,
    current_dataset_dao,
    completion_dataset_dao,
    model_service,
    s3_client,
    spark_k8s_client,
78 changes: 78 additions & 0 deletions api/app/models/completion_response.py
@@ -0,0 +1,78 @@
from typing import Dict, List, Optional

from pydantic import BaseModel, RootModel, model_validator


class TokenLogProbs(BaseModel):
    token: str
    bytes: List[int]
    logprob: float
    top_logprobs: List[Dict[str, float]]


class LogProbs(BaseModel):
    content: List[TokenLogProbs]
    refusal: Optional[str] = None


class Message(BaseModel):
    content: str
    refusal: Optional[str] = None
    role: str
    tool_calls: List = []
    parsed: Optional[dict] = None


class Choice(BaseModel):
    finish_reason: str
    index: int
    logprobs: Optional[LogProbs] = None
    message: Message

    @model_validator(mode='after')
    def validate_logprobs(self):
        if self.logprobs is None:
            raise ValueError(
                "the 'logprobs' field cannot be empty: metrics cannot be computed without it."
            )
        return self


class UsageDetails(BaseModel):
    accepted_prediction_tokens: int = 0
    reasoning_tokens: int = 0
    rejected_prediction_tokens: int = 0
    audio_tokens: Optional[int] = None
    cached_tokens: Optional[int] = None


class Usage(BaseModel):
    completion_tokens: int
    prompt_tokens: int
    total_tokens: int
    completion_tokens_details: UsageDetails
    prompt_tokens_details: Optional[UsageDetails] = None


class Completion(BaseModel):
    id: str
    choices: List[Choice]
    created: int
    model: str
    object: str
    system_fingerprint: str
    usage: Usage


class CompletionResponses(RootModel[List[Completion]]):
    @model_validator(mode='before')
    @classmethod
    def handle_single_completion(cls, data):
        """If a single object is passed instead of a list, wrap it in a list."""
        if isinstance(data, dict):
            return [data]
        if isinstance(data, list):
            return data
        raise ValueError(
            'Input file must be a list of completion JSON objects or a single completion JSON object'
        )
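A minimal sketch of validating an OpenAI-style chat completion against these models; the payload values are illustrative, and note that the Choice validator rejects responses without logprobs, since the metrics job needs them:

```python
from app.models.completion_response import CompletionResponses

payload = {
    'id': 'chatcmpl-123',  # illustrative values throughout
    'created': 1733990000,
    'model': 'gpt-4o',
    'object': 'chat.completion',
    'system_fingerprint': 'fp_abc',
    'choices': [{
        'finish_reason': 'stop',
        'index': 0,
        'message': {'content': 'Hi there', 'role': 'assistant'},
        'logprobs': {'content': [{
            'token': 'Hi',
            'bytes': [72, 105],
            'logprob': -0.01,
            'top_logprobs': [{'Hi': -0.01}],
        }]},
    }],
    'usage': {
        'completion_tokens': 2,
        'prompt_tokens': 5,
        'total_tokens': 7,
        'completion_tokens_details': {},  # all UsageDetails fields have defaults
    },
}

# A single completion dict is wrapped into a list by the 'before' validator.
responses = CompletionResponses.model_validate(payload)
assert len(responses.root) == 1
```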
23 changes: 23 additions & 0 deletions api/app/models/dataset_dto.py
@@ -5,6 +5,7 @@
from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel

from app.db.tables.completion_dataset_table import CompletionDataset
from app.db.tables.current_dataset_table import CurrentDataset
from app.db.tables.reference_dataset_table import ReferenceDataset

@@ -55,6 +56,28 @@ def from_current_dataset(cd: CurrentDataset) -> 'CurrentDatasetDTO':
        )


class CompletionDatasetDTO(BaseModel):
    uuid: UUID
    model_uuid: UUID
    path: str
    date: str
    status: str

    model_config = ConfigDict(
        populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
    )

    @staticmethod
    def from_completion_dataset(cd: CompletionDataset) -> 'CompletionDatasetDTO':
        return CompletionDatasetDTO(
            uuid=cd.uuid,
            model_uuid=cd.model_uuid,
            path=cd.path,
            date=cd.date.isoformat(),
            status=cd.status,
        )


class FileReference(BaseModel):
    file_url: str
    separator: str = ','
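Because of the to_camel alias generator, the DTO serializes with camelCase keys; a short sketch with illustrative values:

```python
from datetime import datetime, timezone
from uuid import uuid4

from app.models.dataset_dto import CompletionDatasetDTO

dto = CompletionDatasetDTO(
    uuid=uuid4(),
    model_uuid=uuid4(),
    path='s3://bucket/completions.json',  # illustrative
    date=datetime.now(tz=timezone.utc).isoformat(),
    status='IMPORTING',
)
# populate_by_name=True allows construction by field name;
# by_alias=True emits camelCase keys such as 'modelUuid'.
print(dto.model_dump(by_alias=True))
```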