feat: add text generation metrics in the sdk module (#219)
* add text generation metrics

* add get all completion models
dtria91 authored Dec 18, 2024
1 parent 1af295c commit 2c713d5
Showing 3 changed files with 98 additions and 3 deletions.
22 changes: 22 additions & 0 deletions sdk/radicalbit_platform_sdk/apis/model.py
@@ -155,6 +155,28 @@ def __callback(response: requests.Response) -> List[ModelCurrentDataset]:
func=__callback,
)

def get_completion_datasets(self) -> List[ModelCompletionDataset]:
def __callback(response: requests.Response) -> List[ModelCompletionDataset]:
try:
adapter = TypeAdapter(List[CompletionFileUpload])
completions = adapter.validate_python(response.json())

return [
ModelCompletionDataset(
self.__base_url, self.__uuid, self.__model_type, completion
)
for completion in completions
]
except ValidationError as e:
raise ClientError(f'Unable to parse response: {response.text}') from e

return invoke(
method='GET',
url=f'{self.__base_url}/api/models/{str(self.__uuid)}/completion/all',
valid_response_code=200,
func=__callback,
)

def load_reference_dataset(
self,
file_name: str,
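For orientation, a minimal usage sketch of the new accessor. It assumes an already-constructed Model instance (obtaining one is outside this diff); the status() and model_quality() calls on each returned dataset are the same ones exercised in the test below.

    # Sketch only: `model` is assumed to be an existing radicalbit_platform_sdk Model instance.
    completion_datasets = model.get_completion_datasets()
    for dataset in completion_datasets:
        # Each ModelCompletionDataset wraps one uploaded completion file returned by
        # GET /api/models/{uuid}/completion/all.
        print(dataset.status())
        print(dataset.model_quality())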
31 changes: 30 additions & 1 deletion sdk/radicalbit_platform_sdk/models/dataset_model_quality.py
@@ -191,5 +191,34 @@ class CurrentRegressionModelQuality(ModelQuality):
model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class TokenProb(BaseModel):
prob: float
token: str


class TokenData(BaseModel):
id: str
probs: List[TokenProb]


class MeanPerFile(BaseModel):
prob_tot_mean: float
perplex_tot_mean: float

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class MeanPerPhrase(BaseModel):
id: str
prob_per_phrase: float
perplex_per_phrase: float

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class CompletionTextGenerationModelQuality(ModelQuality):
tokens: List[TokenData]
mean_per_file: List[MeanPerFile]
mean_per_phrase: List[MeanPerPhrase]

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
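Because the new quality models set populate_by_name=True with a camelCase alias generator, they accept both snake_case field names (as in the test fixture below) and camelCase keys from the API. A minimal validation sketch, assuming the public import path and that the ModelQuality base adds no required fields (consistent with the fixture that parses successfully in the test):

    # Import path is assumed; adjust to wherever the SDK re-exports the model.
    from radicalbit_platform_sdk.models import CompletionTextGenerationModelQuality

    payload = {
        'tokens': [{'id': 'chatcmpl', 'probs': [{'prob': 0.27, 'token': 'Sky'}]}],
        'meanPerFile': [{'probTotMean': 0.71, 'perplexTotMean': 1.52}],  # camelCase aliases
        'meanPerPhrase': [{'id': 'chatcmpl', 'probPerPhrase': 0.71, 'perplexPerPhrase': 1.54}],
    }

    quality = CompletionTextGenerationModelQuality.model_validate(payload)
    assert quality.mean_per_file[0].perplex_tot_mean == 1.52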
48 changes: 46 additions & 2 deletions sdk/tests/apis/model_completion_dataset_test.py
@@ -39,13 +39,57 @@ def test_text_generation_model_quality_ok(self):
body="""{
"datetime": "something_not_used",
"jobStatus": "SUCCEEDED",
"modelQuality": {}
}""",
"modelQuality":
{
"tokens": [
{
"id":"chatcmpl",
"probs":[
{
"prob":0.27,
"token":"Sky"
},
{
"prob":0.89,
"token":" is"
},
{
"prob":0.70,
"token":" blue"
},
{
"prob":0.99,
"token":"."
}
]
}
],
"mean_per_file":[
{
"prob_tot_mean":0.71,
"perplex_tot_mean":1.52
}
],
"mean_per_phrase":[
{
"id":"chatcmpl",
"prob_per_phrase":0.71,
"perplex_per_phrase":1.54
}
]
}
}""",
)

metrics = model_completion_dataset.model_quality()

assert isinstance(metrics, CompletionTextGenerationModelQuality)
assert metrics.tokens[0].probs[0].prob == 0.27
assert metrics.tokens[0].probs[0].token == 'Sky'
assert metrics.mean_per_file[0].prob_tot_mean == 0.71
assert metrics.mean_per_file[0].perplex_tot_mean == 1.52
assert metrics.mean_per_phrase[0].prob_per_phrase == 0.71
assert metrics.mean_per_phrase[0].perplex_per_phrase == 1.54
assert model_completion_dataset.status() == JobStatus.SUCCEEDED

@responses.activate
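For intuition on the fixture's numbers: prob_per_phrase and prob_tot_mean match the arithmetic mean of the four token probabilities ((0.27 + 0.89 + 0.70 + 0.99) / 4 = 0.7125 ≈ 0.71). The perplexity fields are conventionally exp of the negative mean log-probability; the sketch below uses that standard definition as an assumption about, not a confirmation of, the platform's computation, and for these four tokens it yields ≈1.57 rather than the fixture's 1.54, so the platform may round or aggregate differently.

    import math

    # Token probabilities taken from the test fixture above.
    probs = [0.27, 0.89, 0.70, 0.99]

    mean_prob = sum(probs) / len(probs)  # 0.7125, i.e. ~0.71
    # Standard perplexity: exp of the negative mean log-probability (~1.57 here).
    perplexity = math.exp(-sum(math.log(p) for p in probs) / len(probs))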
