
Commit: fix transformers special tokens
helpmefindaname committed Dec 6, 2024
1 parent 1f1a45b commit c1c7954
Showing 3 changed files with 13 additions and 10 deletions.
flair/embeddings/transformer.py (13 changes: 7 additions & 6 deletions)
@@ -278,11 +278,11 @@ def _reconstruct_word_ids_from_subtokens(embedding, tokens: list[str], subtokens

     special_tokens = []
     # check if special tokens exist to circumvent error message
-    if embedding.tokenizer._bos_token:
+    if embedding.tokenizer.bos_token is not None:
         special_tokens.append(embedding.tokenizer.bos_token)
-    if embedding.tokenizer._cls_token:
+    if embedding.tokenizer.cls_token is not None:
         special_tokens.append(embedding.tokenizer.cls_token)
-    if embedding.tokenizer._sep_token:
+    if embedding.tokenizer.sep_token is not None:
         special_tokens.append(embedding.tokenizer.sep_token)
 
     # iterate over subtokens and reconstruct tokens
@@ -1354,9 +1354,10 @@ def from_params(cls, params):
     def to_params(self):
         config_dict = self.model.config.to_dict()
 
-        # do not switch the attention implementation upon reload.
-        config_dict["attn_implementation"] = self.model.config._attn_implementation
-        config_dict.pop("_attn_implementation_autoset", None)
+        if hasattr(self.model.config, "_attn_implementation"):
+            # do not switch the attention implementation upon reload.
+            config_dict["attn_implementation"] = self.model.config._attn_implementation
+            config_dict.pop("_attn_implementation_autoset", None)
 
         super_params = super().to_params()

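In transformer.py, the special-token checks now go through the tokenizer's public bos_token / cls_token / sep_token properties, which simply return None when a token is not defined, instead of the private _bos_token / _cls_token / _sep_token attributes; the attn_implementation round-trip is additionally guarded with hasattr for configs that lack that private field. A minimal sketch of the public-property pattern, assuming a standard Hugging Face tokenizer (the checkpoint name is only an example):

from transformers import AutoTokenizer

# Example checkpoint only; any Hugging Face tokenizer exposes the same properties.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

special_tokens = []
# The public properties return None when a special token is not set, so the
# private attributes such as _bos_token are never touched.
for token in (tokenizer.bos_token, tokenizer.cls_token, tokenizer.sep_token):
    if token is not None:
        special_tokens.append(token)

print(special_tokens)  # for BERT this prints ['[CLS]', '[SEP]'], as no BOS token is defined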
flair/models/tars_model.py (8 changes: 5 additions & 3 deletions)
@@ -718,9 +718,11 @@ def __init__(
         )
 
         # transformer separator
-        self.separator = str(self.tars_embeddings.tokenizer.sep_token)
-        if self.tars_embeddings.tokenizer._bos_token:
-            self.separator += str(self.tars_embeddings.tokenizer.bos_token)
+        self.separator = (
+            self.tars_embeddings.tokenizer.sep_token if self.tars_embeddings.tokenizer.sep_token is not None else ""
+        )
+        if self.tars_embeddings.tokenizer.bos_token is not None:
+            self.separator += self.tars_embeddings.tokenizer.bos_token
 
         self.prefix = prefix
         self.num_negative_labels_to_sample = num_negative_labels_to_sample
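In tars_model.py, the separator is likewise built from the public properties with explicit None checks; with the old str(...) conversion, a tokenizer without a SEP token would have contributed the literal string "None". A hedged sketch of the same construction as a standalone helper (the helper name is illustrative, not part of the codebase):

def build_separator(tokenizer) -> str:
    # Concatenate SEP and BOS tokens only when they are actually defined,
    # mirroring the __init__ change above.
    separator = tokenizer.sep_token if tokenizer.sep_token is not None else ""
    if tokenizer.bos_token is not None:
        separator += tokenizer.bos_token
    return separator

For a RoBERTa-style tokenizer this yields "</s><s>"; for a tokenizer without a BOS token (e.g. BERT) it yields just the SEP token.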
requirements.txt (2 changes: 1 addition & 1 deletion)
@@ -20,6 +20,6 @@ tabulate>=0.8.10
 torch>=1.5.0,!=1.8
 tqdm>=4.63.0
 transformer-smaller-training-vocab>=0.2.3
-transformers[sentencepiece]>=4.18.0,<5.0.0
+transformers[sentencepiece]>=4.25.0,<5.0.0
 wikipedia-api>=0.5.7
 bioc<3.0.0,>=2.0.0
