diff --git a/sdk/radicalbit_platform_sdk/apis/model.py b/sdk/radicalbit_platform_sdk/apis/model.py
index 8a415489..b57bd461 100644
--- a/sdk/radicalbit_platform_sdk/apis/model.py
+++ b/sdk/radicalbit_platform_sdk/apis/model.py
@@ -155,6 +155,28 @@ def __callback(response: requests.Response) -> List[ModelCurrentDataset]:
             func=__callback,
         )
 
+    def get_completion_datasets(self) -> List[ModelCompletionDataset]:
+        def __callback(response: requests.Response) -> List[ModelCompletionDataset]:
+            try:
+                adapter = TypeAdapter(List[CompletionFileUpload])
+                completions = adapter.validate_python(response.json())
+
+                return [
+                    ModelCompletionDataset(
+                        self.__base_url, self.__uuid, self.__model_type, completion
+                    )
+                    for completion in completions
+                ]
+            except ValidationError as e:
+                raise ClientError(f'Unable to parse response: {response.text}') from e
+
+        return invoke(
+            method='GET',
+            url=f'{self.__base_url}/api/models/{str(self.__uuid)}/completion/all',
+            valid_response_code=200,
+            func=__callback,
+        )
+
     def load_reference_dataset(
         self,
         file_name: str,
diff --git a/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py b/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py
index 5c705542..623f51dc 100644
--- a/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py
+++ b/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py
@@ -191,5 +191,34 @@ class CurrentRegressionModelQuality(ModelQuality):
     model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
 
 
+class TokenProb(BaseModel):
+    prob: float
+    token: str
+
+
+class TokenData(BaseModel):
+    id: str
+    probs: List[TokenProb]
+
+
+class MeanPerFile(BaseModel):
+    prob_tot_mean: float
+    perplex_tot_mean: float
+
+    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
+
+
+class MeanPerPhrase(BaseModel):
+    id: str
+    prob_per_phrase: float
+    perplex_per_phrase: float
+
+    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
+
+
 class CompletionTextGenerationModelQuality(ModelQuality):
-    pass
+    tokens: List[TokenData]
+    mean_per_file: List[MeanPerFile]
+    mean_per_phrase: List[MeanPerPhrase]
+
+    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
diff --git a/sdk/tests/apis/model_completion_dataset_test.py b/sdk/tests/apis/model_completion_dataset_test.py
index c5ad752c..7b60c895 100644
--- a/sdk/tests/apis/model_completion_dataset_test.py
+++ b/sdk/tests/apis/model_completion_dataset_test.py
@@ -39,13 +39,57 @@ def test_text_generation_model_quality_ok(self):
             body="""{
                 "datetime": "something_not_used",
                 "jobStatus": "SUCCEEDED",
-                "modelQuality": {}
-            }""",
+                "modelQuality":
+                {
+                    "tokens": [
+                        {
+                            "id":"chatcmpl",
+                            "probs":[
+                                {
+                                    "prob":0.27,
+                                    "token":"Sky"
+                                },
+                                {
+                                    "prob":0.89,
+                                    "token":" is"
+                                },
+                                {
+                                    "prob":0.70,
+                                    "token":" blue"
+                                },
+                                {
+                                    "prob":0.99,
+                                    "token":"."
+                                }
+                            ]
+                        }
+                    ],
+                    "mean_per_file":[
+                        {
+                            "prob_tot_mean":0.71,
+                            "perplex_tot_mean":1.52
+                        }
+                    ],
+                    "mean_per_phrase":[
+                        {
+                            "id":"chatcmpl",
+                            "prob_per_phrase":0.71,
+                            "perplex_per_phrase":1.54
+                        }
+                    ]
+                }
+            }""",
         )
 
         metrics = model_completion_dataset.model_quality()
 
         assert isinstance(metrics, CompletionTextGenerationModelQuality)
+        assert metrics.tokens[0].probs[0].prob == 0.27
+        assert metrics.tokens[0].probs[0].token == 'Sky'
+        assert metrics.mean_per_file[0].prob_tot_mean == 0.71
+        assert metrics.mean_per_file[0].perplex_tot_mean == 1.52
+        assert metrics.mean_per_phrase[0].prob_per_phrase == 0.71
+        assert metrics.mean_per_phrase[0].perplex_per_phrase == 1.54
         assert model_completion_dataset.status() == JobStatus.SUCCEEDED
 
     @responses.activate
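For reviewers, a minimal usage sketch of the API this patch adds (not itself part of the diff). It assumes a Model handle named `model` has already been obtained from the client, and that `CompletionTextGenerationModelQuality` is importable from `radicalbit_platform_sdk.models`; only `get_completion_datasets()`, `model_quality()`, and the new quality fields come from the diff above.

    # Assumption: `model` is an existing Model instance from the SDK client.
    from radicalbit_platform_sdk.models import CompletionTextGenerationModelQuality

    # Fetches all completion datasets via GET /api/models/{uuid}/completion/all
    for dataset in model.get_completion_datasets():
        metrics = dataset.model_quality()
        if isinstance(metrics, CompletionTextGenerationModelQuality):
            # mean_per_file holds file-level means; mean_per_phrase carries
            # per-phrase probability and perplexity keyed by phrase id.
            for phrase in metrics.mean_per_phrase:
                print(phrase.id, phrase.prob_per_phrase, phrase.perplex_per_phrase)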