# synonym.py
import logging
import os
import random

import nltk
from nltk.corpus import wordnet

import nlpaug.model.word_dict as nmw  # assumption: nlpaug's word_dict module provides the Ppdb reader
from NLarge.utils.words import WordsUtil

logger = logging.getLogger(__name__)

# Cache of loaded PPDB models, keyed by dictionary file name.
PPDB_MODEL = {}
class WordNet:
"""
A class to interact with the WordNet lexical database for finding synonyms and antonyms.
:param lang: The language for the WordNet database (default is "eng").
:type lang: str
:param is_synonym: A flag to indicate whether to find synonyms (True) or antonyms (False).
:type is_synonym: bool
"""
def __init__(self, lang="eng", is_synonym=True):
"""
Initializes the WordNet class with the specified language and synonym/antonym flag.
:param lang: The language for the WordNet database (default is "eng").
:type lang: str
:param is_synonym: A flag to indicate whether to find synonyms (True) or antonyms (False).
:type is_synonym: bool
"""
self.lang = lang
self.is_synonym = is_synonym
self.model = self.read()
def read(self):
"""
Loads the WordNet corpus, downloading it if necessary.
:return: The WordNet corpus reader.
:rtype: nltk.corpus.reader.wordnet.WordNetCorpusReader
"""
        try:
            # Probe the corpus; this raises LookupError if WordNet data is missing.
            wordnet.synsets("testing")
            return wordnet
        except LookupError:
            nltk.download("wordnet")
            nltk.download("omw-1.4")
            return wordnet
def predict(self, word, pos=None):
"""
Finds synonyms or antonyms for a given word.
:param word: The word for which to find synonyms or antonyms.
:type word: str
:param pos: The part of speech tag (default is None).
:type pos: str, optional
:return: A list of synonyms or antonyms for the given word.
:rtype: list
"""
results = []
for synonym in self.model.synsets(word, pos=pos, lang=self.lang):
for lemma in synonym.lemmas(lang=self.lang):
if self.is_synonym:
results.append(lemma.name())
else:
for antonym in lemma.antonyms():
results.append(antonym.name())
return results
@classmethod
def pos_tag(cls, tokens):
"""
Tags parts of speech for a list of tokens.
:param tokens: A list of tokens to tag.
:type tokens: list
:return: A list of tuples where each tuple contains a token and its part of speech tag.
:rtype: list
"""
        try:
            results = nltk.pos_tag(tokens)
        except LookupError:
            # Newer NLTK releases renamed the tagger resource, so fetch both names.
            nltk.download("averaged_perceptron_tagger")
            nltk.download("averaged_perceptron_tagger_eng")
            results = nltk.pos_tag(tokens)
return results
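
# A minimal usage sketch for WordNet (illustrative; the exact lemmas returned
# depend on the installed WordNet data):
#
#     wn = WordNet(lang="eng", is_synonym=True)
#     wn.predict("happy", pos="a")             # e.g. ["happy", "felicitous", "glad", ...]
#     WordNet.pos_tag(["a", "quick", "test"])  # roughly [("a", "DT"), ("quick", "JJ"), ("test", "NN")]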
class PartOfSpeech:
"""
A class to handle part-of-speech (POS) tagging and mapping between POS tags and their constituents.
Attributes:
-----------
NOUN : str
Constant for noun POS.
VERB : str
Constant for verb POS.
ADJECTIVE : str
Constant for adjective POS.
ADVERB : str
Constant for adverb POS.
pos2con : dict
Dictionary mapping POS tags to their constituent tags.
con2pos : dict
Dictionary mapping constituent tags to their POS tags.
poses : list
List of all constituent tags.
Methods:
--------
pos2constituent(pos):
Maps a POS tag to its constituent tags.
constituent2pos(con):
Maps a constituent tag to its POS tags.
get_pos():
Returns a list of all constituent tags.
"""
NOUN = "noun"
VERB = "verb"
ADJECTIVE = "adjective"
ADVERB = "adverb"
pos2con = {
"n": ["NN", "NNS", "NNP", "NNPS"],
"v": ["VB", "VBD", "VBG", "VBN", "VBZ", "VBP"],
"a": ["JJ", "JJR", "JJS", "IN"],
"s": ["JJ", "JJR", "JJS", "IN"], # Adjective Satellite
"r": ["RB", "RBR", "RBS"],
}
con2pos = {}
poses = []
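    # Build the reverse mapping: each Penn Treebank tag -> its WordNet POS keys.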
for key, values in pos2con.items():
poses.extend(values)
for value in values:
if value not in con2pos:
con2pos[value] = []
con2pos[value].append(key)
@staticmethod
def pos2constituent(pos):
"""
Maps a POS tag to its constituent tags.
:param pos: The POS tag.
:type pos: str
:return: A list of constituent tags for the given POS tag.
:rtype: list
"""
return PartOfSpeech.pos2con.get(pos, [])
@staticmethod
def constituent2pos(con):
"""
Maps a constituent tag to its POS tags.
:param con: The constituent tag.
:type con: str
:return: A list of POS tags for the given constituent tag.
:rtype: list
"""
return PartOfSpeech.con2pos.get(con, [])
@staticmethod
def get_pos():
"""
Returns a list of all constituent tags.
:return: A list of all constituent tags.
:rtype: list
"""
return PartOfSpeech.poses
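
# Illustrative mappings implied by pos2con above:
#
#     PartOfSpeech.pos2constituent("n")   # ["NN", "NNS", "NNP", "NNPS"]
#     PartOfSpeech.constituent2pos("JJ")  # ["a", "s"]
#     PartOfSpeech.constituent2pos("XX")  # [] for unknown tags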
def init_ppdb_model(dict_path, force_reload=False):
"""
Initializes the PPDB model from the given dictionary path.
:param dict_path: The path to the PPDB dictionary file.
:type dict_path: str
:param force_reload: A flag to indicate whether to force reload the model (default is False).
:type force_reload: bool
:return: The initialized PPDB model.
:rtype: nmw.Ppdb
"""
global PPDB_MODEL
model_name = os.path.basename(dict_path)
if model_name in PPDB_MODEL and not force_reload:
return PPDB_MODEL[model_name]
model = nmw.Ppdb(dict_path)
PPDB_MODEL[model_name] = model
return model
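
# Sketch of PPDB initialization (assumes a locally downloaded PPDB dictionary
# file; nmw.Ppdb is assumed to come from nlpaug's word_dict module):
#
#     ppdb_model = init_ppdb_model("/path/to/ppdb-2.0-s-all")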
class SynonymAugmenter:
    """
    A class to perform synonym-based data augmentation using the WordNet lexical database.

    Methods:
    --------
    __call__(data, aug_src="wordnet", model_path=None, lang="eng", aug_min=1,
             aug_max=10, aug_p=0.3, stopwords=None, tokenizer=None,
             reverse_tokenizer=None, stopwords_regex=None, force_reload=False,
             verbose=0):
        Performs synonym-based data augmentation on the input data.
    """

    def __init__(self) -> None:
        logger.info("SynonymAugmenter initialized")
def __call__(
self,
data,
aug_src="wordnet",
model_path=None,
lang="eng",
aug_min=1,
aug_max=10,
aug_p=0.3,
stopwords=None,
tokenizer=None,
reverse_tokenizer=None,
stopwords_regex=None,
force_reload=False,
verbose=0,
):
"""
Performs synonym-based data augmentation on the input data.
:param data: The input text data to be augmented.
:type data: str
:param aug_src: The source for augmentation (default is "wordnet").
:type aug_src: str
:param model_path: The path to the model (not used in this implementation).
:type model_path: str, optional
:param lang: The language for the WordNet database (default is "eng").
:type lang: str
:param aug_min: The minimum number of words to augment (default is 1).
:type aug_min: int
:param aug_max: The maximum number of words to augment (default is 10).
:type aug_max: int
:param aug_p: The probability of a word being augmented (default is 0.3).
:type aug_p: float
:param stopwords: A list of stopwords to exclude from augmentation.
:type stopwords: list, optional
:param tokenizer: A function to tokenize the input text (default is str.split).
:type tokenizer: function, optional
:param reverse_tokenizer: A function to detokenize the augmented text (default is " ".join).
:type reverse_tokenizer: function, optional
:param stopwords_regex: A regex pattern to match stopwords (not used in this implementation).
:type stopwords_regex: str, optional
:param force_reload: A flag to indicate whether to force reload the model (default is False).
:type force_reload: bool
:param verbose: The verbosity level (default is 0).
:type verbose: int
:return: The augmented text.
:rtype: str
"""
if not data or not data.strip():
return data
model = WordNet(lang=lang) if aug_src == "wordnet" else None
if model is None:
            raise ValueError("Currently, `aug_src` only supports `wordnet`.")
change_seq = 0
tokenizer = tokenizer or str.split
reverse_tokenizer = reverse_tokenizer or " ".join
doc = WordsUtil(data, tokenizer(data))
original_tokens = doc.get_original_tokens()
pos = model.pos_tag(original_tokens)
stopwords = stopwords or []
        def skip_aug(token_idxes, tokens):
            # Drop determiners and words with no candidate replacement.
            results = []
            for token_idx in token_idxes:
                if tokens[token_idx][1] in ["DT"]:
                    continue
                word_poses = PartOfSpeech.constituent2pos(tokens[token_idx][1])
                if aug_src == "ppdb" and not word_poses:
                    continue
                # `word_pos` avoids shadowing the outer `pos` tag list.
                if word_poses and not any(
                    model.predict(tokens[token_idx][0], pos=word_pos)
                    for word_pos in word_poses
                ):
                    continue
                results.append(token_idx)
            return results
        def _get_aug_idxes(tokens):
            # Clamp the augmentation count to the documented aug_min/aug_max bounds.
            aug_cnt = int(len(tokens) * aug_p) if aug_p else aug_max
            aug_cnt = max(aug_min, min(aug_max, aug_cnt))
            # Compare the token text (not its index) against the stopword list.
            word_idxes = [
                i for i in range(len(tokens)) if tokens[i][0] not in stopwords
            ]
            word_idxes = skip_aug(word_idxes, tokens)
            if not word_idxes:
                return []
            return random.sample(word_idxes, min(aug_cnt, len(word_idxes)))
aug_idxes = _get_aug_idxes(pos)
if not aug_idxes:
return data
for aug_idx in aug_idxes:
original_token = original_tokens[aug_idx]
word_poses = PartOfSpeech.constituent2pos(pos[aug_idx][1])
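            # Pool candidate lemmas across every WordNet POS key for this tag.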
candidates = sum(
(
model.predict(pos[aug_idx][0], pos=word_pos)
for word_pos in word_poses
),
[],
)
candidates = [
c for c in candidates if c.lower() != original_token.lower()
]
if candidates:
substitute_token = random.choice(candidates).lower()
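                # Preserve sentence-initial capitalization.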
if aug_idx == 0:
substitute_token = substitute_token.capitalize()
change_seq += 1
doc.add_change_log(
aug_idx,
new_token=substitute_token,
action="substitute",
change_seq=change_seq,
)
return reverse_tokenizer(doc.get_augmented_tokens())
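

if __name__ == "__main__":
    # Minimal demonstration (assumes WordNet data is available or downloadable;
    # the output varies with the random seed and installed corpora).
    aug = SynonymAugmenter()
    sample = "The quick brown fox jumps over the lazy dog"
    print(aug(sample, aug_src="wordnet", aug_p=0.5))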