Skip to content

Commit

Permalink
updating tests + updating code (not raising errors when all returned …
Browse files Browse the repository at this point in the history
…labels are invalid)
  • Loading branch information
ashkankzme committed Aug 6, 2024
1 parent 870b491 commit 967189d
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
9 changes: 3 additions & 6 deletions lib/model/classycat_classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,9 @@ def classify_and_store_results(self, schema_id, items):
for result in final_results:
result['labels'] = [label for label in result['labels'] if label in permitted_labels]

if len(final_results) == 0:
logger.info(f"The returned classifications did not produce labels from the schema: {items}")
raise Exception(f"No items were classified successfully")

results_file_id = str(uuid.uuid4())
upload_file_to_s3(self.output_bucket, f"{schema_id}/{results_file_id}.json", json.dumps(final_results))
if not all([len(result['labels']) == 0 for result in final_results]):
results_file_id = str(uuid.uuid4())
upload_file_to_s3(self.output_bucket, f"{schema_id}/{results_file_id}.json", json.dumps(final_results))

return final_results

Expand Down
16 changes: 13 additions & 3 deletions test/lib/model/test_classycat.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,7 +819,7 @@ def test_classify_pass_some_out_of_schema_labels(self, file_exists_in_s3_mock, u
]
}
)
openrouter_classify_mock.return_value = "<OUTPUT>\n<CATEGORIES_0>Politics;Communalism</CATEGORIES_0>\n<CATEGORIES_1>Politico;Communism</CATEGORIES_1>\n</OUTPUT>"
openrouter_classify_mock.return_value = "<OUTPUT>\n<CATEGORIES_0>Politics;Communalism</CATEGORIES_0>\n<CATEGORIES_1>Politico;Communism</CATEGORIES_1>\n<CATEGORIES_2>Politics;Communism</CATEGORIES_2>\n</OUTPUT>"
classify_input = {
"model_name": "classycat__Model",
"body": {
Expand All @@ -835,6 +835,10 @@ def test_classify_pass_some_out_of_schema_labels(self, file_exists_in_s3_mock, u
{
"id": "12",
"text": "modi and bjp are amazing politicians"
},
{
"id": "13",
"text": "modi is an amazing politician"
}
]
},
Expand All @@ -845,7 +849,10 @@ def test_classify_pass_some_out_of_schema_labels(self, file_exists_in_s3_mock, u
result = self.classycat_model.process(classify_message)

self.assertEqual("success", result.responseMessage)
self.assertEqual(1, len(result.classification_results))
self.assertEqual(3, len(result.classification_results))
self.assertListEqual(["Politics", "Communalism"], result.classification_results[0]['labels'])
self.assertListEqual([], result.classification_results[1]['labels'])
self.assertListEqual(["Politics"], result.classification_results[2]['labels'])

@patch('lib.model.classycat_classify.OpenRouterClient.classify')
@patch('lib.model.classycat_classify.load_file_from_s3')
Expand Down Expand Up @@ -987,7 +994,10 @@ def test_classify_fail_all_out_of_schema_labels(self, file_exists_in_s3_mock, up
classify_message = schemas.parse_message(classify_input)
result = self.classycat_model.process(classify_message)

self.assertIn("Error classifying items: No items were classified successfully", result.responseMessage)
self.assertEqual("success", result.responseMessage)
self.assertEqual(2, len(result.classification_results))
self.assertListEqual([], result.classification_results[0]['labels'])
self.assertListEqual([], result.classification_results[1]['labels'])

if __name__ == '__main__':
unittest.main()

0 comments on commit 967189d

Please sign in to comment.