Skip to content

Commit

Permalink
Merge pull request #763 from k-ivey/sbert
Browse files Browse the repository at this point in the history
Rename BERT constraint to SBERT
  • Loading branch information
jxmorris12 authored Mar 5, 2024
2 parents 62fad01 + 37e5ae2 commit f98fbe6
Show file tree
Hide file tree
Showing 9 changed files with 16 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ textattack.constraints.semantics.sentence\_encoders package
.. toctree::
:maxdepth: 6

textattack.constraints.semantics.sentence_encoders.bert
textattack.constraints.semantics.sentence_encoders.sentence_bert
textattack.constraints.semantics.sentence_encoders.infer_sent
textattack.constraints.semantics.sentence_encoders.universal_sentence_encoder

Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
textattack.constraints.semantics.sentence\_encoders.bert package
================================================================

.. automodule:: textattack.constraints.semantics.sentence_encoders.bert
.. automodule:: textattack.constraints.semantics.sentence_encoders.sentence_bert
:members:
:undoc-members:
:show-inheritance:




.. automodule:: textattack.constraints.semantics.sentence_encoders.bert.bert
.. automodule:: textattack.constraints.semantics.sentence_encoders.sentence_bert.sbert
:members:
:undoc-members:
:show-inheritance:
2 changes: 1 addition & 1 deletion textattack/attack_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
# Semantics constraints
#
"embedding": "textattack.constraints.semantics.WordEmbeddingDistance",
"bert": "textattack.constraints.semantics.sentence_encoders.BERT",
"sbert": "textattack.constraints.semantics.sentence_encoders.SBERT",
"infer-sent": "textattack.constraints.semantics.sentence_encoders.InferSent",
"thought-vector": "textattack.constraints.semantics.sentence_encoders.ThoughtVector",
"use": "textattack.constraints.semantics.sentence_encoders.UniversalSentenceEncoder",
Expand Down
4 changes: 2 additions & 2 deletions textattack/attack_recipes/a2t_yoo_2021.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
StopwordModification,
)
from textattack.constraints.semantics import WordEmbeddingDistance
from textattack.constraints.semantics.sentence_encoders import BERT
from textattack.constraints.semantics.sentence_encoders import SBERT
from textattack.goal_functions import UntargetedClassification
from textattack.search_methods import GreedyWordSwapWIR
from textattack.transformations import WordSwapEmbedding, WordSwapMaskedLM
Expand Down Expand Up @@ -49,7 +49,7 @@ def build(model_wrapper, mlm=False):
constraints.append(input_column_modification)
constraints.append(PartOfSpeech(allow_verb_noun_swap=False))
constraints.append(MaxModificationRate(max_rate=0.1, min_threshold=4))
sent_encoder = BERT(
sent_encoder = SBERT(
model_name="stsb-distilbert-base", threshold=0.9, metric="cosine"
)
constraints.append(sent_encoder)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from .sentence_encoder import SentenceEncoder

from .bert import BERT
from .sentence_bert import SBERT
from .infer_sent import InferSent
from .thought_vector import ThoughtVector
from .universal_sentence_encoder import (
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""
sBERT
^^^^^^^
"""

from .sbert import SBERT
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)


class BERT(SentenceEncoder):
class SBERT(SentenceEncoder):
"""Constraint using similarity between sentence encodings of x and x_adv
where the text embeddings are created using BERT, trained on NLI data, and
fine- tuned on the STS benchmark dataset.
Expand Down
4 changes: 2 additions & 2 deletions textattack/metrics/quality_metrics/sentence_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
"""

from textattack.attack_results import FailedAttackResult, SkippedAttackResult
from textattack.constraints.semantics.sentence_encoders import BERT
from textattack.constraints.semantics.sentence_encoders import SBERT
from textattack.metrics import Metric


class SBERTMetric(Metric):
def __init__(self, **kwargs):
self.use_obj = BERT(model_name="all-MiniLM-L6-v2", metric="cosine")
self.use_obj = SBERT(model_name="all-MiniLM-L6-v2", metric="cosine")
self.original_candidates = []
self.successful_candidates = []
self.all_metrics = {}
Expand Down

0 comments on commit f98fbe6

Please sign in to comment.