diff --git a/lib/model/yake_keywords.py b/lib/model/yake_keywords.py index 8fc8948..915e8fe 100644 --- a/lib/model/yake_keywords.py +++ b/lib/model/yake_keywords.py @@ -7,8 +7,35 @@ from lib import schemas import yake +import cld3 class Model(Model): + + def keep_largest_overlapped_keywords(self, keywords): + cleaned_keywords = [] + for i in range(len(keywords)): + keep_keyword = True + for j in range(len(keywords)): + current_keyword = keywords[i][0] + other_keyword = keywords[j][0] + if len(other_keyword) > len(current_keyword): + if other_keyword.find(current_keyword + " ") >= 0 or other_keyword.find(" " + current_keyword) >= 0: + keep_keyword = False + break + if keep_keyword: + cleaned_keywords.append(keywords[i]) + return cleaned_keywords + + def normalize_special_characters(self, text): + replacement = {"`": "'", + "‘": "'", + "’": "'", + "“": "\"", + "”": "\""} + for k, v in replacement.items(): + text = text.replace(k, v) + return text + def run_yake(self, text: str, language: str, max_ngram_size: int, @@ -26,15 +53,25 @@ def run_yake(self, text: str, :param num_of_keywords: int :returns: str """ + ### If language is "auto", detect it from the text with cld3. NOTE(review): .language is read directly off the get_language() result; confirm it cannot be None (e.g. for empty text). + if language == 'auto': + language = cld3.get_language(text).language + ### Normalize backticks and curly quotes to plain ASCII quotes before extraction. + text = self.normalize_special_characters(text) + ### Extract keywords with YAKE using the requested settings. custom_kw_extractor = yake.KeywordExtractor(lan=language, n=max_ngram_size, dedupLim=deduplication_threshold, dedupFunc=deduplication_algo, windowsSize=window_size, top=num_of_keywords, features=None) - return {"keywords": custom_kw_extractor.extract_keywords(text)} + + ### Keep only the longest keyword when a keyword appears inside a longer one next to a space. 
+ keywords = custom_kw_extractor.extract_keywords(text) + keywords = self.keep_largest_overlapped_keywords(keywords) + return {"keywords": keywords} def get_params(self, message: schemas.Message) -> dict: params = { "text": message.body.text, - "language": message.body.parameters.get("language", "en"), + "language": message.body.parameters.get("language", "auto"), "max_ngram_size": message.body.parameters.get("max_ngram_size", 3), "deduplication_threshold": message.body.parameters.get("deduplication_threshold", 0.25), "deduplication_algo": message.body.parameters.get("deduplication_algo", 'seqm'), diff --git a/requirements.txt b/requirements.txt index aed95a1..8beeca6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,3 +24,4 @@ numpy==1.26.4 protobuf==3.20.2 openai==1.35.6 anthropic==0.31.1 +pycld3==0.22 \ No newline at end of file diff --git a/test/lib/model/test_yake_keywords.py b/test/lib/model/test_yake_keywords.py index 47ebef5..9c36693 100644 --- a/test/lib/model/test_yake_keywords.py +++ b/test/lib/model/test_yake_keywords.py @@ -43,6 +43,16 @@ def test_run_yake_real(self): results = self.yake_model.run_yake(**self.yake_model.get_params(message)) self.assertEqual(results, {"keywords": [('love Meedan', 0.0013670273525686505)]}) + def test_keep_largest_overlapped_keywords(self): + keywords_test = [('Alegre', 0),('Alegre', 0),('Timpani', 0), ('Presto Timpani', 0), ('AlegreAlegre', 0), ('Alegre Alegre', 0), ("Presto", 0)] + expected = [('Presto Timpani', 0), ('AlegreAlegre', 0), ('Alegre Alegre', 0)] + self.assertEqual(self.yake_model.keep_largest_overlapped_keywords(keywords_test), expected) + + def test_normalize_special_characters(self): + text = "`‘’“”" + expected = "'''\"\"" + self.assertEqual(self.yake_model.normalize_special_characters(text), expected) + def test_get_params_with_defaults(self): message = schemas.parse_message({ "body": { @@ -51,7 +61,7 @@ def test_get_params_with_defaults(self): }, "model_name": "yake_keywords__Model" }) - 
expected = {'text': 'Some Text', 'language': "en", 'max_ngram_size': 3, 'deduplication_threshold': 0.25, 'deduplication_algo': 'seqm', 'window_size': 0, 'num_of_keywords': 10} + expected = {'text': 'Some Text', 'language': "auto", 'max_ngram_size': 3, 'deduplication_threshold': 0.25, 'deduplication_algo': 'seqm', 'window_size': 0, 'num_of_keywords': 10} self.assertEqual(self.yake_model.get_params(message), expected) def test_get_params_with_specifics(self):