diff --git a/docs/apidoc/textattack.constraints.semantics.sentence_encoders.rst b/docs/apidoc/textattack.constraints.semantics.sentence_encoders.rst index 0cace50f4..9233e333c 100644 --- a/docs/apidoc/textattack.constraints.semantics.sentence_encoders.rst +++ b/docs/apidoc/textattack.constraints.semantics.sentence_encoders.rst @@ -11,7 +11,7 @@ textattack.constraints.semantics.sentence\_encoders package .. toctree:: :maxdepth: 6 - textattack.constraints.semantics.sentence_encoders.bert + textattack.constraints.semantics.sentence_encoders.sentence_bert textattack.constraints.semantics.sentence_encoders.infer_sent textattack.constraints.semantics.sentence_encoders.universal_sentence_encoder diff --git a/docs/apidoc/textattack.constraints.semantics.sentence_encoders.bert.rst b/docs/apidoc/textattack.constraints.semantics.sentence_encoders.sentence_bert.rst similarity index 90% rename from docs/apidoc/textattack.constraints.semantics.sentence_encoders.bert.rst rename to docs/apidoc/textattack.constraints.semantics.sentence_encoders.sentence_bert.rst index 5dd389281..2e9094e68 100644 --- a/docs/apidoc/textattack.constraints.semantics.sentence_encoders.bert.rst +++ b/docs/apidoc/textattack.constraints.semantics.sentence_encoders.sentence_bert.rst @@ -1,7 +1,7 @@ textattack.constraints.semantics.sentence\_encoders.bert package ================================================================ -.. automodule:: textattack.constraints.semantics.sentence_encoders.bert +.. automodule:: textattack.constraints.semantics.sentence_encoders.sentence_bert :members: :undoc-members: :show-inheritance: @@ -9,7 +9,7 @@ textattack.constraints.semantics.sentence\_encoders.bert package -.. automodule:: textattack.constraints.semantics.sentence_encoders.bert.bert +.. automodule:: textattack.constraints.semantics.sentence_encoders.sentence_bert.sbert :members: :undoc-members: :show-inheritance: diff --git a/textattack/attack_args.py b/textattack/attack_args.py index 38a7bd25d..b99f6dc58 100644 --- a/textattack/attack_args.py +++ b/textattack/attack_args.py @@ -67,7 +67,7 @@ # Semantics constraints # "embedding": "textattack.constraints.semantics.WordEmbeddingDistance", - "bert": "textattack.constraints.semantics.sentence_encoders.BERT", + "sbert": "textattack.constraints.semantics.sentence_encoders.SBERT", "infer-sent": "textattack.constraints.semantics.sentence_encoders.InferSent", "thought-vector": "textattack.constraints.semantics.sentence_encoders.ThoughtVector", "use": "textattack.constraints.semantics.sentence_encoders.UniversalSentenceEncoder", diff --git a/textattack/attack_recipes/a2t_yoo_2021.py b/textattack/attack_recipes/a2t_yoo_2021.py index faf24e95f..bf6202961 100644 --- a/textattack/attack_recipes/a2t_yoo_2021.py +++ b/textattack/attack_recipes/a2t_yoo_2021.py @@ -13,7 +13,7 @@ StopwordModification, ) from textattack.constraints.semantics import WordEmbeddingDistance -from textattack.constraints.semantics.sentence_encoders import BERT +from textattack.constraints.semantics.sentence_encoders import SBERT from textattack.goal_functions import UntargetedClassification from textattack.search_methods import GreedyWordSwapWIR from textattack.transformations import WordSwapEmbedding, WordSwapMaskedLM @@ -49,7 +49,7 @@ def build(model_wrapper, mlm=False): constraints.append(input_column_modification) constraints.append(PartOfSpeech(allow_verb_noun_swap=False)) constraints.append(MaxModificationRate(max_rate=0.1, min_threshold=4)) - sent_encoder = BERT( + sent_encoder = SBERT( model_name="stsb-distilbert-base", threshold=0.9, metric="cosine" ) constraints.append(sent_encoder) diff --git a/textattack/constraints/semantics/sentence_encoders/__init__.py b/textattack/constraints/semantics/sentence_encoders/__init__.py index 6c8e88b38..d60166b4e 100644 --- a/textattack/constraints/semantics/sentence_encoders/__init__.py +++ b/textattack/constraints/semantics/sentence_encoders/__init__.py @@ -6,7 +6,7 @@ from .sentence_encoder import SentenceEncoder -from .bert import BERT +from .sentence_bert import SBERT from .infer_sent import InferSent from .thought_vector import ThoughtVector from .universal_sentence_encoder import ( diff --git a/textattack/constraints/semantics/sentence_encoders/bert/__init__.py b/textattack/constraints/semantics/sentence_encoders/bert/__init__.py deleted file mode 100644 index e0f312aa3..000000000 --- a/textattack/constraints/semantics/sentence_encoders/bert/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -""" -BERT -^^^^^^^ -""" - -from .bert import BERT diff --git a/textattack/constraints/semantics/sentence_encoders/sentence_bert/__init__.py b/textattack/constraints/semantics/sentence_encoders/sentence_bert/__init__.py new file mode 100644 index 000000000..ca45a4756 --- /dev/null +++ b/textattack/constraints/semantics/sentence_encoders/sentence_bert/__init__.py @@ -0,0 +1,6 @@ +""" +sBERT +^^^^^^^ +""" + +from .sbert import SBERT diff --git a/textattack/constraints/semantics/sentence_encoders/bert/bert.py b/textattack/constraints/semantics/sentence_encoders/sentence_bert/sbert.py similarity index 97% rename from textattack/constraints/semantics/sentence_encoders/bert/bert.py rename to textattack/constraints/semantics/sentence_encoders/sentence_bert/sbert.py index 0972411b0..4f1b5d713 100644 --- a/textattack/constraints/semantics/sentence_encoders/bert/bert.py +++ b/textattack/constraints/semantics/sentence_encoders/sentence_bert/sbert.py @@ -11,7 +11,7 @@ ) -class BERT(SentenceEncoder): +class SBERT(SentenceEncoder): """Constraint using similarity between sentence encodings of x and x_adv where the text embeddings are created using BERT, trained on NLI data, and fine- tuned on the STS benchmark dataset. diff --git a/textattack/metrics/quality_metrics/sentence_bert.py b/textattack/metrics/quality_metrics/sentence_bert.py index f96660af6..380365ed3 100644 --- a/textattack/metrics/quality_metrics/sentence_bert.py +++ b/textattack/metrics/quality_metrics/sentence_bert.py @@ -7,13 +7,13 @@ """ from textattack.attack_results import FailedAttackResult, SkippedAttackResult -from textattack.constraints.semantics.sentence_encoders import BERT +from textattack.constraints.semantics.sentence_encoders import SBERT from textattack.metrics import Metric class SBERTMetric(Metric): def __init__(self, **kwargs): - self.use_obj = BERT(model_name="all-MiniLM-L6-v2", metric="cosine") + self.use_obj = SBERT(model_name="all-MiniLM-L6-v2", metric="cosine") self.original_candidates = [] self.successful_candidates = [] self.all_metrics = {}