Skip to content

Commit

Permalink
Merge pull request #3336 from flairNLP/fix_import_error
Browse files — browse the repository at this point in the history
fix import error
  • Loading branch information
alanakbik authored Oct 12, 2023
2 parents 41b2ad4 + a290eb4 commit 42ea3f6
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 4 deletions.
2 changes: 0 additions & 2 deletions flair/trainers/plugins/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .base import BasePlugin, Pluggable, TrainerPlugin, TrainingInterrupt
from .functional.amp import AmpPlugin
from .functional.anneal_on_plateau import AnnealingPlugin
from .functional.checkpoints import CheckpointPlugin
from .functional.linear_scheduler import LinearSchedulerPlugin
Expand All @@ -11,7 +10,6 @@
from .metric_records import MetricName, MetricRecord

__all__ = [
"AmpPlugin",
"AnnealingPlugin",
"CheckpointPlugin",
"LinearSchedulerPlugin",
Expand Down
3 changes: 1 addition & 2 deletions flair/trainers/plugins/functional/anneal_on_plateau.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
self.anneal_factor = anneal_factor
self.patience = patience
self.initial_extra_patience = initial_extra_patience
self.scheduler: AnnealOnPlateau

def store_learning_rate(self):
optimizer = self.trainer.optimizer
Expand Down Expand Up @@ -117,6 +118,4 @@ def get_state(self) -> Dict[str, Any]:
"patience": self.patience,
"initial_extra_patience": self.initial_extra_patience,
"anneal_with_restarts": self.anneal_with_restarts,
"bad_epochs": self.scheduler.num_bad_epochs,
"current_best": self.scheduler.best,
}
27 changes: 27 additions & 0 deletions test_emb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Manual reproduction script: embed the same sentence with two separately
constructed ``TransformerWordEmbeddings`` instances (context enabled, context
separator disabled) and print the last token's trailing embedding values from
each run, so the two runs can be compared by eye for consistency.

NOTE(review): this requires network access to download 'roberta-base' on
first run.
"""
from flair.data import Sentence
from flair.embeddings import TransformerWordEmbeddings

# --- run 0 ------------------------------------------------------------------
phrase_0 = Sentence("a uui")
embeddings_a = TransformerWordEmbeddings(
    'roberta-base',
    use_context=True,
    use_context_separator=False,
)
# embed() mutates the Sentence in place, attaching embeddings to its tokens;
# we keep the sentence itself rather than the (redundant) return value.
embeddings_a.embed(phrase_0)
sentences_a = [phrase_0]

# --- run 1 ------------------------------------------------------------------
phrase_1 = Sentence("a uui")
embeddings_b = TransformerWordEmbeddings(
    'roberta-base',
    use_context=True,
    use_context_separator=False,
)
embeddings_b.embed(phrase_1)
sentences_b = [phrase_1]

# Print the last token of each run and the final two embedding components,
# so the two independently-built embedders can be checked for agreement.
print(
    "token run 0:", sentences_a[-1][-1], "\n",
    "embedding end run 0:", sentences_a[-1][-1].embedding.tolist()[-2:], "\n",
    "token run 1: ", sentences_b[-1][-1], "\n",
    "embedding end run 1:", sentences_b[-1][-1].embedding.tolist()[-2:]
)

0 comments on commit 42ea3f6

Please sign in to comment.