diff --git a/requirements.txt b/requirements.txt index 59db81e3..246421e3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,4 +20,3 @@ more-itertools pinyin>=0.4.0 jieba OpenHowNet -click==8.0.2 diff --git a/textattack/transformations/word_swaps/chn_transformations/chinese_word_swap_masked.py b/textattack/transformations/word_swaps/chn_transformations/chinese_word_swap_masked.py index ce217d50..4e12b41f 100644 --- a/textattack/transformations/word_swaps/chn_transformations/chinese_word_swap_masked.py +++ b/textattack/transformations/word_swaps/chn_transformations/chinese_word_swap_masked.py @@ -3,7 +3,7 @@ ------------------------------------- """ -from transformers import pipeline +import torch from . import WordSwap @@ -13,8 +13,7 @@ class ChineseWordSwapMaskedLM(WordSwap): model.""" def __init__(self, task="fill-mask", model="xlm-roberta-base", **kwargs): - from transformers import BertTokenizer, BertForMaskedLM - import torch + from transformers import BertForMaskedLM, BertTokenizer self.tt = BertTokenizer.from_pretrained(model) self.mm = BertForMaskedLM.from_pretrained(model)