Merge branch 'master' into mauryaland-crf_score
Benedikt Fuchs committed Oct 12, 2023
2 parents 27a952a + 42ea3f6 commit b554291
Showing 23 changed files with 127 additions and 74 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/ci.yml
@@ -17,6 +17,8 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: 3.8
- name: Install Torch cpu
run: pip install torch --index-url https://download.pytorch.org/whl/cpu
- name: Install Flair dependencies
run: pip install -e .
- name: Install unittest dependencies
@@ -31,4 +33,4 @@ jobs:
- name: Run tests
run: |
python -c 'import flair'
pytest --runintegration --durations=0 -vv
pytest --runintegration -vv
2 changes: 1 addition & 1 deletion .github/workflows/issues.yml
@@ -3,7 +3,7 @@ on: issue_comment
jobs:
issue_commented:
name: Issue comment
if: ${{ !github.event.issue.pull_request && github.event.issue.author == github.even.issue_comment.author }}
if: ${{ github.event.issue.pull_request && github.event.issue.author == github.even.issue_comment.author }}
runs-on: ubuntu-latest
steps:
- uses: actions-ecosystem/action-remove-labels@v1
2 changes: 2 additions & 0 deletions .github/workflows/publish-docs.yml
@@ -18,6 +18,8 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: ${{ env.python-version }}
- name: Install Torch cpu
run: pip install torch --index-url https://download.pytorch.org/whl/cpu
- name: Install Flair dependencies
run: pip install -e .
- name: Install unittest dependencies
1 change: 0 additions & 1 deletion flair/datasets/base.py
@@ -156,7 +156,6 @@ def __init__(
log.warning('ATTENTION! The library "pymongo" is not installed!')
log.warning('To use MongoDataset, please first install with "pip install pymongo"')
log.warning("-" * 100)
pass

self.in_memory = in_memory
self.tokenizer = tokenizer
1 change: 0 additions & 1 deletion flair/datasets/biomedical.py
@@ -132,7 +132,6 @@ def filter_and_map_entities(
new_entities.append(new_entity)
else:
logging.debug(f"Skip entity type {entity.type}")
pass
mapped_entities_per_document[id] = new_entities

return InternalBioNerDataset(documents=dataset.documents, entities_per_document=mapped_entities_per_document)
1 change: 0 additions & 1 deletion flair/embeddings/document.py
@@ -556,7 +556,6 @@ def __init__(
log.warning('ATTENTION! The library "sentence-transformers" is not installed!')
log.warning('To use Sentence Transformers, please first install with "pip install sentence-transformers"')
log.warning("-" * 100)
pass

self.model_name = model
self.model = SentenceTransformer(
3 changes: 1 addition & 2 deletions flair/embeddings/image.py
@@ -31,7 +31,7 @@ def embedding_type(self) -> str:

def to_params(self) -> Dict[str, Any]:
# legacy pickle-like saving for image embeddings, as implementation details are not obvious
return self.__getstate__() # type: ignore[operator]
return self.__getstate__()

@classmethod
def from_params(cls, params: Dict[str, Any]) -> "Embeddings":
@@ -104,7 +104,6 @@ def __init__(self, name, pretrained=True, transforms=None) -> None:
log.warning('ATTENTION! The library "torchvision" is not installed!')
log.warning('To use convnets pretraned on ImageNet, please first install with "pip install torchvision"')
log.warning("-" * 100)
pass

model_info = {
"resnet50": (torchvision.models.resnet50, lambda x: list(x)[:-1], 2048),
1 change: 0 additions & 1 deletion flair/embeddings/legacy.py
@@ -40,7 +40,6 @@ def __init__(
log.warning('ATTENTION! The library "allennlp" is not installed!')
log.warning('To use ELMoEmbeddings, please first install with "pip install allennlp==0.9.0"')
log.warning("-" * 100)
pass

assert embedding_mode in ["all", "top", "average"]

2 changes: 1 addition & 1 deletion flair/optim.py
@@ -3,7 +3,7 @@
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau, _LRScheduler
from torch.optim.optimizer import required # type: ignore[attr-defined]
from torch.optim.optimizer import required

log = logging.getLogger("flair")

2 changes: 0 additions & 2 deletions flair/trainers/plugins/__init__.py
@@ -1,5 +1,4 @@
from .base import BasePlugin, Pluggable, TrainerPlugin, TrainingInterrupt
from .functional.amp import AmpPlugin
from .functional.anneal_on_plateau import AnnealingPlugin
from .functional.checkpoints import CheckpointPlugin
from .functional.linear_scheduler import LinearSchedulerPlugin
@@ -11,7 +10,6 @@
from .metric_records import MetricName, MetricRecord

__all__ = [
"AmpPlugin",
"AnnealingPlugin",
"CheckpointPlugin",
"LinearSchedulerPlugin",
4 changes: 4 additions & 0 deletions flair/trainers/plugins/base.py
@@ -4,6 +4,7 @@
from itertools import count
from queue import Queue
from typing import (
Any,
Callable,
Dict,
Iterator,
@@ -259,6 +260,9 @@ def pluggable(self) -> Optional[Pluggable]:
def __str__(self) -> str:
return self.__class__.__name__

def get_state(self) -> Dict[str, Any]:
return {"__cls__": f"{self.__module__}.{self.__class__.__name__}"}


class TrainerPlugin(BasePlugin):
@property
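
The new get_state hook on BasePlugin records the plugin's fully qualified class name; each subclass merges its own configuration on top of the base dict, as the plugins changed below do. A minimal sketch of a custom plugin following this pattern (the plugin name and its option are hypothetical, not part of this commit):

from typing import Any, Dict

from flair.trainers.plugins.base import TrainerPlugin


class ExamplePlugin(TrainerPlugin):  # hypothetical plugin, for illustration only
    def __init__(self, every_k_epochs: int = 1) -> None:
        super().__init__()
        self.every_k_epochs = every_k_epochs

    def get_state(self) -> Dict[str, Any]:
        # base state holds {"__cls__": "<module>.<ClassName>"}; merge own config on top
        return {
            **super().get_state(),
            "every_k_epochs": self.every_k_epochs,
        }
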
52 changes: 0 additions & 52 deletions flair/trainers/plugins/functional/amp.py

This file was deleted.

13 changes: 13 additions & 0 deletions flair/trainers/plugins/functional/anneal_on_plateau.py
@@ -1,5 +1,6 @@
import logging
import os
from typing import Any, Dict

from flair.trainers.plugins.base import TrainerPlugin, TrainingInterrupt
from flair.trainers.plugins.metric_records import MetricRecord
@@ -34,6 +35,7 @@ def __init__(
self.anneal_factor = anneal_factor
self.patience = patience
self.initial_extra_patience = initial_extra_patience
self.scheduler: AnnealOnPlateau

def store_learning_rate(self):
optimizer = self.trainer.optimizer
@@ -106,3 +108,14 @@ def __str__(self) -> str:
f"anneal_factor: '{self.anneal_factor}', "
f"min_learning_rate: '{self.min_learning_rate}'"
)

def get_state(self) -> Dict[str, Any]:
return {
**super().get_state(),
"base_path": str(self.base_path),
"min_learning_rate": self.min_learning_rate,
"anneal_factor": self.anneal_factor,
"patience": self.patience,
"initial_extra_patience": self.initial_extra_patience,
"anneal_with_restarts": self.anneal_with_restarts,
}
9 changes: 9 additions & 0 deletions flair/trainers/plugins/functional/checkpoints.py
@@ -1,4 +1,5 @@
import logging
from typing import Any, Dict

from flair.trainers.plugins.base import TrainerPlugin

@@ -27,3 +28,11 @@ def after_training_epoch(self, epoch, **kw):
)
model_name = "model_epoch_" + str(epoch) + ".pt"
self.model.save(self.base_path / model_name, checkpoint=self.save_optimizer_state)

def get_state(self) -> Dict[str, Any]:
return {
**super().get_state(),
"base_path": str(self.base_path),
"save_model_each_k_epochs": self.save_model_each_k_epochs,
"save_optimizer_state": self.save_optimizer_state,
}
15 changes: 11 additions & 4 deletions flair/trainers/plugins/functional/linear_scheduler.py
@@ -1,4 +1,5 @@
import logging
from typing import Any, Dict

from flair.optim import LinearSchedulerWithWarmup
from flair.trainers.plugins.base import TrainerPlugin
@@ -9,7 +10,7 @@
class LinearSchedulerPlugin(TrainerPlugin):
"""Plugin for LinearSchedulerWithWarmup."""

def __init__(self, warmup_fraction: float, **kwargs) -> None:
def __init__(self, warmup_fraction: float) -> None:
super().__init__()

self.warmup_fraction = warmup_fraction
@@ -29,7 +30,7 @@ def after_setup(
dataset_size,
mini_batch_size,
max_epochs,
**kw,
**kwargs,
):
"""Initialize different schedulers, including anneal target for AnnealOnPlateau, batch_growth_annealing, loading schedulers."""
# calculate warmup steps
@@ -44,13 +45,13 @@
self.store_learning_rate()

@TrainerPlugin.hook
def before_training_epoch(self, **kw):
def before_training_epoch(self, **kwargs):
"""Load state for anneal_with_restarts, batch_growth_annealing, logic for early stopping."""
self.store_learning_rate()
self.previous_learning_rate = self.current_learning_rate

@TrainerPlugin.hook
def after_training_batch(self, optimizer_was_run: bool, **kw):
def after_training_batch(self, optimizer_was_run: bool, **kwargs):
"""Do the scheduler step if one-cycle or linear decay."""
# skip if no optimization has happened.
if not optimizer_was_run:
@@ -60,3 +61,9 @@ def after_training_batch(self, optimizer_was_run: bool, **kw):

def __str__(self) -> str:
return f"LinearScheduler | warmup_fraction: '{self.warmup_fraction}'"

def get_state(self) -> Dict[str, Any]:
return {
**super().get_state(),
"warmup_fraction": self.warmup_fraction,
}
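
For orientation, after_setup sizes the scheduler from the training geometry before constructing LinearSchedulerWithWarmup. A hedged sketch of the warmup-step arithmetic (function name and rounding are assumptions, not copied from the implementation):

def estimated_warmup_steps(dataset_size: int, mini_batch_size: int, max_epochs: int, warmup_fraction: float) -> int:
    # assumed reconstruction, for illustration only
    steps_per_epoch = dataset_size // mini_batch_size  # optimizer steps per epoch
    num_train_steps = steps_per_epoch * max_epochs
    return int(num_train_steps * warmup_fraction)

print(estimated_warmup_steps(dataset_size=10_000, mini_batch_size=32, max_epochs=10, warmup_fraction=0.1))  # 312
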
9 changes: 9 additions & 0 deletions flair/trainers/plugins/functional/weight_extractor.py
@@ -1,3 +1,5 @@
from typing import Any, Dict

from flair.trainers.plugins.base import TrainerPlugin
from flair.training_utils import WeightExtractor

@@ -7,6 +9,7 @@ class WeightExtractorPlugin(TrainerPlugin):

def __init__(self, base_path) -> None:
super().__init__()
self.base_path = base_path
self.weight_extractor = WeightExtractor(base_path)

@TrainerPlugin.hook
@@ -17,3 +20,9 @@ def after_training_batch(self, batch_no, epoch, total_number_of_batches, **kw):

if (iteration + 1) % modulo == 0:
self.weight_extractor.extract_weights(self.model.state_dict(), iteration)

def get_state(self) -> Dict[str, Any]:
return {
**super().get_state(),
"base_path": str(self.base_path),
}
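
With this hook in place, a plugin's configuration can be read back as a plain dict. A usage sketch for the WeightExtractorPlugin above (the path is illustrative, the expected keys follow from the diff, and instantiating the plugin also initializes the extractor's output file under base_path):

from pathlib import Path

from flair.trainers.plugins.functional.weight_extractor import WeightExtractorPlugin

plugin = WeightExtractorPlugin(base_path=Path("resources/example"))  # illustrative path
state = plugin.get_state()
# expected, per the diff:
# {"__cls__": "flair.trainers.plugins.functional.weight_extractor.WeightExtractorPlugin",
#  "base_path": "resources/example"}
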
6 changes: 5 additions & 1 deletion flair/trainers/plugins/loggers/log_file.py
@@ -1,5 +1,6 @@
import logging
from pathlib import Path
from typing import Any, Dict

from flair.trainers.plugins.base import TrainerPlugin
from flair.training_utils import add_file_handler
@@ -12,10 +13,13 @@ class LogFilePlugin(TrainerPlugin):

def __init__(self, base_path) -> None:
super().__init__()

self.base_path = base_path
self.log_handler = add_file_handler(log, Path(base_path) / "training.log")

@TrainerPlugin.hook("_training_exception", "after_training")
def close_file_handler(self, **kw):
self.log_handler.close()
log.removeHandler(self.log_handler)

def get_state(self) -> Dict[str, Any]:
return {**super().get_state(), "base_path": str(self.base_path)}
11 changes: 9 additions & 2 deletions flair/trainers/plugins/loggers/loss_file.py
@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union

from flair.trainers.plugins.base import TrainerPlugin
from flair.trainers.plugins.metric_records import MetricName
@@ -15,9 +15,9 @@ def __init__(
super().__init__()

self.first_epoch = epoch + 1

# prepare loss logging file and set up header
self.loss_txt = init_output_file(base_path, "loss.tsv")
self.base_path = base_path

# set up all metrics to collect
self.metrics_to_collect = metrics_to_collect
@@ -58,6 +58,13 @@ def __init__(
# initialize the first log line
self.current_row: Optional[Dict[MetricName, str]] = None

def get_state(self) -> Dict[str, Any]:
return {
**super().get_state(),
"base_path": str(self.base_path),
"metrics_to_collect": self.metrics_to_collect,
}

@TrainerPlugin.hook
def before_training_epoch(self, epoch, **kw):
"""Get the current epoch for loss file logging."""
8 changes: 7 additions & 1 deletion flair/trainers/plugins/loggers/metric_history.py
@@ -1,5 +1,5 @@
import logging
from typing import Dict, Mapping
from typing import Any, Dict, Mapping

from flair.trainers.plugins.base import TrainerPlugin

@@ -32,3 +32,9 @@ def metric_recorded(self, record):
def after_training(self, **kw):
"""Returns metric history."""
self.trainer.return_values.update(self.metric_history)

def get_state(self) -> Dict[str, Any]:
return {
**super().get_state(),
"metrics_to_collect": dict(self.metrics_to_collect),
}
10 changes: 10 additions & 0 deletions flair/trainers/plugins/loggers/tensorboard.py
@@ -1,5 +1,6 @@
import logging
import os
from typing import Any, Dict

from flair.trainers.plugins.base import TrainerPlugin
from flair.training_utils import log_line
@@ -22,6 +23,7 @@ def __init__(self, log_dir=None, comment="", tracked_metrics=()) -> None:
super().__init__()
self.comment = comment
self.tracked_metrics = tracked_metrics
self.log_dir = log_dir

try:
from torch.utils.tensorboard import SummaryWriter
@@ -56,3 +58,11 @@ def _training_finally(self, **kw):
"""Closes the writer."""
assert self.writer is not None
self.writer.close()

def get_state(self) -> Dict[str, Any]:
return {
**super().get_state(),
"log_dir": str(self.log_dir) if self.log_dir is not None else None,
"comment": self.comment,
"tracked_metrics": self.tracked_metrics,
}