From 967189d7b14e1f7ef6f406aece96e667fd08231c Mon Sep 17 00:00:00 2001 From: ashkankzme Date: Tue, 6 Aug 2024 08:46:57 -0700 Subject: [PATCH] updating tests + updating code (not raising errors when all returned labels are invalid) --- lib/model/classycat_classify.py | 9 +++------ test/lib/model/test_classycat.py | 16 +++++++++++++--- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/lib/model/classycat_classify.py b/lib/model/classycat_classify.py index 89a1da6..d623679 100644 --- a/lib/model/classycat_classify.py +++ b/lib/model/classycat_classify.py @@ -132,12 +132,9 @@ def classify_and_store_results(self, schema_id, items): for result in final_results: result['labels'] = [label for label in result['labels'] if label in permitted_labels] - if len(final_results) == 0: - logger.info(f"The returned classifications did not produce labels from the schema: {items}") - raise Exception(f"No items were classified successfully") - - results_file_id = str(uuid.uuid4()) - upload_file_to_s3(self.output_bucket, f"{schema_id}/{results_file_id}.json", json.dumps(final_results)) + if not all([len(result['labels']) == 0 for result in final_results]): + results_file_id = str(uuid.uuid4()) + upload_file_to_s3(self.output_bucket, f"{schema_id}/{results_file_id}.json", json.dumps(final_results)) return final_results diff --git a/test/lib/model/test_classycat.py b/test/lib/model/test_classycat.py index f16426b..46d4b44 100644 --- a/test/lib/model/test_classycat.py +++ b/test/lib/model/test_classycat.py @@ -819,7 +819,7 @@ def test_classify_pass_some_out_of_schema_labels(self, file_exists_in_s3_mock, u ] } ) - openrouter_classify_mock.return_value = "\nPolitics;Communalism\nPolitico;Communism\n" + openrouter_classify_mock.return_value = "\nPolitics;Communalism\nPolitico;Communism\nPolitics;Communism\n" classify_input = { "model_name": "classycat__Model", "body": { @@ -835,6 +835,10 @@ def test_classify_pass_some_out_of_schema_labels(self, file_exists_in_s3_mock, u { "id": "12", "text": "modi and bjp are amazing politicians" + }, + { + "id": "13", + "text": "modi is an amazing politician" } ] }, @@ -845,7 +849,10 @@ def test_classify_pass_some_out_of_schema_labels(self, file_exists_in_s3_mock, u result = self.classycat_model.process(classify_message) self.assertEqual("success", result.responseMessage) - self.assertEqual(1, len(result.classification_results)) + self.assertEqual(3, len(result.classification_results)) + self.assertListEqual(["Politics", "Communalism"], result.classification_results[0]['labels']) + self.assertListEqual([], result.classification_results[1]['labels']) + self.assertListEqual(["Politics"], result.classification_results[2]['labels']) @patch('lib.model.classycat_classify.OpenRouterClient.classify') @patch('lib.model.classycat_classify.load_file_from_s3') @@ -987,7 +994,10 @@ def test_classify_fail_all_out_of_schema_labels(self, file_exists_in_s3_mock, up classify_message = schemas.parse_message(classify_input) result = self.classycat_model.process(classify_message) - self.assertIn("Error classifying items: No items were classified successfully", result.responseMessage) + self.assertEqual("success", result.responseMessage) + self.assertEqual(2, len(result.classification_results)) + self.assertListEqual([], result.classification_results[0]['labels']) + self.assertListEqual([], result.classification_results[1]['labels']) if __name__ == '__main__': unittest.main()