Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

feat: add load completion file in the sdk module #215

Merged
merged 4 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions api/app/models/dataset_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,15 @@ class FileReference(BaseModel):
)


class FileCompletion(BaseModel):
    """Request payload referencing an already-uploaded completion file by its S3 URL."""

    # Full URL of the completion JSON file, e.g. 's3://bucket/key.json'.
    file_url: str

    # Accept both snake_case and camelCase field names on input/output.
    model_config = ConfigDict(
        populate_by_name=True,
        alias_generator=to_camel,
    )


class OrderType(str, Enum):
ASC = 'asc'
DESC = 'desc'
11 changes: 11 additions & 0 deletions api/app/routes/upload_dataset_route.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from app.models.dataset_dto import (
CompletionDatasetDTO,
CurrentDatasetDTO,
FileCompletion,
FileReference,
OrderType,
ReferenceDatasetDTO,
Expand Down Expand Up @@ -75,6 +76,16 @@ def upload_completion_file(
) -> CompletionDatasetDTO:
return file_service.upload_completion_file(model_uuid, json_file)

@router.post(
    '/{model_uuid}/completion/bind',
    status_code=status.HTTP_200_OK,
    response_model=CompletionDatasetDTO,
)
def bind_completion_file(
    model_uuid: UUID, file_completion: FileCompletion
) -> CompletionDatasetDTO:
    """Bind an already-uploaded S3 completion file to the given model."""
    bound_dataset = file_service.bind_completion_file(model_uuid, file_completion)
    return bound_dataset

@router.get(
'/{model_uuid}/reference',
status_code=200,
Expand Down
51 changes: 51 additions & 0 deletions api/app/services/file_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from app.models.dataset_dto import (
CompletionDatasetDTO,
CurrentDatasetDTO,
FileCompletion,
FileReference,
OrderType,
ReferenceDatasetDTO,
Expand Down Expand Up @@ -398,6 +399,56 @@ def upload_completion_file(
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e

def bind_completion_file(
    self, model_uuid: UUID, file_completion: FileCompletion
) -> CompletionDatasetDTO:
    """Bind an already-uploaded S3 completion file to a model and start processing.

    Checks that the referenced S3 object exists, persists a new
    `CompletionDataset` row in IMPORTING status, then submits the Spark
    completion job for it.

    :param model_uuid: UUID of the model the completion file belongs to.
    :param file_completion: Payload carrying the 's3://bucket/key' file URL.
    :return: DTO describing the inserted completion dataset.
    :raises ModelNotFoundError: if no model exists for `model_uuid`.
    :raises HTTPException: 404 if the S3 object is missing, 500 on credential
        or other unexpected errors.
    """
    model_out = self.model_svc.get_model_by_uuid(model_uuid)
    if not model_out:
        logger.error('Model %s not found', model_uuid)
        raise ModelNotFoundError(f'Model {model_uuid} not found')
    try:
        # removeprefix strips only the leading scheme; the previous
        # replace('s3://', '') would also mangle any later occurrence of
        # 's3://' inside the object key.
        bucket, _, key = file_completion.file_url.removeprefix(
            's3://'
        ).partition('/')
        # head_object raises ClientError with code '404' when the object
        # does not exist.
        self.s3_client.head_object(Bucket=bucket, Key=key)

        inserted_file = self.completion_dataset_dao.insert_completion_dataset(
            CompletionDataset(
                uuid=uuid4(),
                model_uuid=model_uuid,
                path=file_completion.file_url,
                date=datetime.datetime.now(tz=datetime.UTC),
                status=JobStatus.IMPORTING,
            )
        )
        logger.debug('File %s has been correctly stored in the db', inserted_file)

        spark_config = get_config().spark_config
        self.__submit_app(
            app_name=str(model_out.uuid),
            app_path=spark_config.spark_completion_app_path,
            app_arguments=[
                # Spark reads through the s3a:// filesystem scheme; swap only
                # the leading scheme, not every occurrence in the URL.
                f's3a://{file_completion.file_url.removeprefix("s3://")}',
                str(inserted_file.uuid),
                CompletionDatasetMetrics.__tablename__,
                CompletionDataset.__tablename__,
            ],
        )

        return CompletionDatasetDTO.from_completion_dataset(inserted_file)

    except NoCredentialsError as nce:
        raise HTTPException(
            status_code=500, detail='S3 credentials not available'
        ) from nce
    except ClientError as e:
        if e.response['Error']['Code'] == '404':
            raise HTTPException(
                status_code=404,
                detail=f'File {file_completion.file_url} does not exist',
            ) from None
        raise HTTPException(status_code=500, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

def get_all_reference_datasets_by_model_uuid_paginated(
self,
model_uuid: UUID,
Expand Down
3 changes: 2 additions & 1 deletion sdk/radicalbit_platform_sdk/apis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .model_current_dataset import ModelCurrentDataset
from .model_reference_dataset import ModelReferenceDataset
from .model_completion_dataset import ModelCompletionDataset
from .model import Model

__all__ = ['Model', 'ModelCurrentDataset', 'ModelReferenceDataset']
__all__ = ['Model', 'ModelCurrentDataset', 'ModelReferenceDataset', 'ModelCompletionDataset']
123 changes: 121 additions & 2 deletions sdk/radicalbit_platform_sdk/apis/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from io import BytesIO
import json
import os
from typing import List, Optional
from typing import Dict, List, Optional
from uuid import UUID

import boto3
Expand All @@ -8,12 +10,18 @@
from pydantic import TypeAdapter, ValidationError
import requests

from radicalbit_platform_sdk.apis import ModelCurrentDataset, ModelReferenceDataset
from radicalbit_platform_sdk.apis import (
ModelCompletionDataset,
ModelCurrentDataset,
ModelReferenceDataset,
)
from radicalbit_platform_sdk.commons import invoke
from radicalbit_platform_sdk.errors import ClientError
from radicalbit_platform_sdk.models import (
AwsCredentials,
ColumnDefinition,
CompletionFileUpload,
CompletionResponses,
CurrentFileUpload,
DataType,
FileReference,
Expand All @@ -24,6 +32,7 @@
OutputType,
ReferenceFileUpload,
)
from radicalbit_platform_sdk.models.file_upload_result import FileCompletion


class Model:
Expand Down Expand Up @@ -504,7 +513,117 @@ def __callback(response: requests.Response) -> ModelCurrentDataset:
data=file_ref.model_dump_json(),
)

def load_completion_dataset(
    self,
    file_name: str,
    bucket: str,
    object_name: Optional[str] = None,
    aws_credentials: Optional[AwsCredentials] = None,
) -> ModelCompletionDataset:
    """Upload completion dataset to an S3 bucket and then bind it inside the platform.

    Raises `ClientError` in case S3 upload fails.

    :param file_name: The name of the completion file.
    :param bucket: The name of the S3 bucket.
    :param object_name: The optional name of the object uploaded to S3. Default value is None.
    :param aws_credentials: AWS credentials used to connect to S3 bucket. Default value is None.
    :return: An instance of `ModelCompletionDataset` representing the completion dataset
    """

    try:
        with open(file_name, encoding='utf-8') as json_file:
            payload = json.load(json_file)
        validated_json_bytes = self.__validate_json(payload)
    except Exception as e:
        raise ClientError(
            f'Failed to validate JSON file {file_name}: {str(e)}'
        ) from e

    # Default object key: namespace the file under the model UUID.
    if object_name is None:
        object_name = f'{self.__uuid}/completion/{os.path.basename(file_name)}'

    # With no explicit credentials every client option is passed as None,
    # which makes boto3 fall back to its default credential/region chain.
    has_creds = aws_credentials is not None
    try:
        s3_client = boto3.client(
            's3',
            aws_access_key_id=(
                aws_credentials.access_key_id if has_creds else None
            ),
            aws_secret_access_key=(
                aws_credentials.secret_access_key if has_creds else None
            ),
            region_name=(
                aws_credentials.default_region if has_creds else None
            ),
            endpoint_url=(
                aws_credentials.endpoint_url if has_creds else None
            ),
        )

        extra_args = {
            'Metadata': {
                'model_uuid': str(self.__uuid),
                'model_name': self.__name,
                'file_type': 'completion',
            }
        }
        s3_client.upload_fileobj(
            validated_json_bytes,
            bucket,
            object_name,
            ExtraArgs=extra_args,
        )
    except BotoClientError as e:
        raise ClientError(
            f'Unable to upload file {file_name} to remote storage: {e}'
        ) from e
    return self.__bind_completion_dataset(f's3://{bucket}/{object_name}')

def __bind_completion_dataset(
    self,
    dataset_url: str,
) -> ModelCompletionDataset:
    """Bind an uploaded completion file to this model via the platform API.

    :param dataset_url: 's3://bucket/key' URL of the uploaded completion file.
    :return: A `ModelCompletionDataset` wrapping the created dataset.
    :raises ClientError: if the platform response cannot be parsed.
    """

    def __callback(response: requests.Response) -> ModelCompletionDataset:
        try:
            # Bind the parsed payload to a new name instead of shadowing
            # `response`: the except handler below reports `response.text`,
            # which must keep referring to the HTTP Response object.
            upload = CompletionFileUpload.model_validate(response.json())
            return ModelCompletionDataset(
                self.__base_url, self.__uuid, self.__model_type, upload
            )
        except ValidationError as e:
            raise ClientError(f'Unable to parse response: {response.text}') from e

    file_completion = FileCompletion(
        file_url=dataset_url,
    )

    return invoke(
        method='POST',
        url=f'{self.__base_url}/api/models/{str(self.__uuid)}/completion/bind',
        valid_response_code=200,
        func=__callback,
        data=file_completion.model_dump_json(),
    )

def __required_headers(self) -> List[str]:
model_columns = self.__features + self.__outputs.output
model_columns.append(self.__target)
return [model_column.name for model_column in model_columns]

@staticmethod
def __validate_json(json_data: List[Dict]) -> BytesIO:
    """Validate raw completion JSON against the `CompletionResponses` schema.

    :param json_data: Parsed JSON payload loaded from the completion file.
    :return: The validated payload re-serialized as an in-memory JSON byte stream.
    :raises ClientError: if schema validation or serialization fails.
    """
    try:
        validated = CompletionResponses.model_validate(json_data)
        buffer = BytesIO(validated.model_dump_json().encode())
    except ValidationError as e:
        raise ClientError(f'JSON validation error: {str(e)}') from e
    except Exception as e:
        raise ClientError(
            f'Unexpected error during JSON validation: {str(e)}'
        ) from e
    return buffer
103 changes: 103 additions & 0 deletions sdk/radicalbit_platform_sdk/apis/model_completion_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from typing import Optional
from uuid import UUID

from pydantic import ValidationError
import requests

from radicalbit_platform_sdk.commons import invoke
from radicalbit_platform_sdk.errors import ClientError
from radicalbit_platform_sdk.models import (
CompletionFileUpload,
CompletionTextGenerationModelQuality,
JobStatus,
ModelQuality,
ModelType,
)


class ModelCompletionDataset:
    """Client-side handle for a completion dataset bound to a model.

    Wraps the upload metadata returned by the platform and lazily fetches
    model-quality metrics for the dataset via the platform API.
    """

    def __init__(
        self,
        base_url: str,
        model_uuid: UUID,
        model_type: ModelType,
        upload: CompletionFileUpload,
    ) -> None:
        self.__base_url = base_url
        self.__model_uuid = model_uuid
        self.__model_type = model_type
        self.__uuid = upload.uuid
        self.__path = upload.path
        self.__date = upload.date
        self.__status = upload.status
        # Cached metrics; populated on demand by model_quality().
        self.__model_metrics = None

    def uuid(self) -> UUID:
        """Return the UUID of this completion dataset."""
        return self.__uuid

    def path(self) -> str:
        """Return the storage path (S3 URL) of the completion file."""
        return self.__path

    def date(self) -> str:
        """Return the upload date reported by the platform."""
        return self.__date

    def status(self) -> str:
        """Return the last known processing-job status for this dataset."""
        return self.__status

    def model_quality(self) -> Optional[ModelQuality]:
        """Get model quality metrics about the completion dataset

        :return: The `ModelQuality` if exists
        """

        def __callback(
            response: requests.Response,
        ) -> tuple[JobStatus, Optional[ModelQuality]]:
            try:
                response_json = response.json()
                job_status = JobStatus(response_json['jobStatus'])
                if 'modelQuality' in response_json:
                    match self.__model_type:
                        case ModelType.TEXT_GENERATION:
                            return (
                                job_status,
                                CompletionTextGenerationModelQuality.model_validate(
                                    response_json['modelQuality']
                                ),
                            )
                        case _:
                            # Metrics parsing is only implemented for
                            # text-generation models.
                            raise ClientError(
                                'Unable to parse metrics because of not managed model type'
                            ) from None
            except KeyError as e:
                raise ClientError(f'Unable to parse response: {response.text}') from e
            except ValidationError as e:
                raise ClientError(f'Unable to parse response: {response.text}') from e
            else:
                # No 'modelQuality' key yet: report the status with no metrics.
                return job_status, None

        match self.__status:
            case JobStatus.ERROR:
                self.__model_metrics = None
            case JobStatus.MISSING_COMPLETION:
                self.__model_metrics = None
            case JobStatus.SUCCEEDED:
                # Metrics are final once the job succeeded: fetch them once
                # and reuse the cached value on later calls.
                if self.__model_metrics is None:
                    _, metrics = invoke(
                        method='GET',
                        url=f'{self.__base_url}/api/models/{str(self.__model_uuid)}/completion/{str(self.__uuid)}/model-quality',
                        valid_response_code=200,
                        func=__callback,
                    )
                    self.__model_metrics = metrics
            case JobStatus.IMPORTING:
                # Job may still be running: refresh both status and metrics.
                status, metrics = invoke(
                    method='GET',
                    url=f'{self.__base_url}/api/models/{str(self.__model_uuid)}/completion/{str(self.__uuid)}/model-quality',
                    valid_response_code=200,
                    func=__callback,
                )
                self.__status = status
                self.__model_metrics = metrics

        return self.__model_metrics
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,14 @@ def distribution_chart(self, data: BinaryDistributionChartData) -> EChartsRawWid
if data.current_data:
assert len(data.current_data) <= 2

reference_json_data = [binary_data.model_dump() for binary_data in data.reference_data]
current_data_json = [binary_data.model_dump() for binary_data in data.current_data] if data.current_data else []
reference_json_data = [
binary_data.model_dump() for binary_data in data.reference_data
]
current_data_json = (
[binary_data.model_dump() for binary_data in data.current_data]
if data.current_data
else []
)

reference_series_data = {
'title': data.title,
Expand Down Expand Up @@ -87,7 +93,6 @@ def distribution_chart(self, data: BinaryDistributionChartData) -> EChartsRawWid
return EChartsRawWidget(option=option)

def linear_chart(self, data: BinaryLinearChartData) -> EChartsRawWidget:

reference_series_data = {
'name': 'Reference',
'type': 'line',
Expand Down
Loading