Merge pull request #3325 from flairNLP/3312-bug-loading-model-with-own-trainerplugin

don't pickle classes & plugins in modelcard
alanakbik authored Oct 10, 2023
2 parents 08aae3c + 7679c1e commit 41b2ad4
Showing 12 changed files with 93 additions and 63 deletions.
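
For context: the underlying issue (#3312) was that a model trained with a user-defined TrainerPlugin could not be loaded afterwards, because the plugin objects (and classes) were pickled into the model card. A rough sketch of the scenario this fixes, with a hypothetical plugin name and paths:

```python
# Hypothetical illustration of the bug this PR addresses (names and paths invented).
from flair.trainers.plugins.base import TrainerPlugin


class MyOwnPlugin(TrainerPlugin):  # defined only inside the training script
    pass


# trainer.train("resources/taggers/demo", plugins=[MyOwnPlugin()])
#
# Loading "resources/taggers/demo/final-model.pt" from another script or environment
# previously required MyOwnPlugin to be importable, because the plugin instance was
# pickled into the model card. After this change the card stores only plain data
# (the class path plus get_state() contents), so loading no longer depends on it.
```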
4 changes: 4 additions & 0 deletions flair/trainers/plugins/base.py
@@ -4,6 +4,7 @@
from itertools import count
from queue import Queue
from typing import (
Any,
Callable,
Dict,
Iterator,
@@ -259,6 +260,9 @@ def pluggable(self) -> Optional[Pluggable]:
def __str__(self) -> str:
return self.__class__.__name__

def get_state(self) -> Dict[str, Any]:
return {"__cls__": f"{self.__module__}.{self.__class__.__name__}"}


class TrainerPlugin(BasePlugin):
@property
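
The new get_state() on the plugin base class gives every plugin a serializable identity (its fully qualified class name); subclasses are expected to merge in their own plain-Python settings, as the files below do. A minimal sketch of the pattern with a hypothetical plugin:

```python
from typing import Any, Dict

from flair.trainers.plugins.base import TrainerPlugin


class ExamplePlugin(TrainerPlugin):  # hypothetical, for illustration only
    def __init__(self, some_setting: int = 3) -> None:
        super().__init__()
        self.some_setting = some_setting

    def get_state(self) -> Dict[str, Any]:
        # merge the base state (the "__cls__" entry) with this plugin's own fields
        return {**super().get_state(), "some_setting": self.some_setting}


print(ExamplePlugin(some_setting=5).get_state())
# {'__cls__': '__main__.ExamplePlugin', 'some_setting': 5}
```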
52 changes: 0 additions & 52 deletions flair/trainers/plugins/functional/amp.py

This file was deleted.

14 changes: 14 additions & 0 deletions flair/trainers/plugins/functional/anneal_on_plateau.py
@@ -1,5 +1,6 @@
import logging
import os
from typing import Any, Dict

from flair.trainers.plugins.base import TrainerPlugin, TrainingInterrupt
from flair.trainers.plugins.metric_records import MetricRecord
@@ -106,3 +107,16 @@ def __str__(self) -> str:
f"anneal_factor: '{self.anneal_factor}', "
f"min_learning_rate: '{self.min_learning_rate}'"
)

def get_state(self) -> Dict[str, Any]:
return {
**super().get_state(),
"base_path": str(self.base_path),
"min_learning_rate": self.min_learning_rate,
"anneal_factor": self.anneal_factor,
"patience": self.patience,
"initial_extra_patience": self.initial_extra_patience,
"anneal_with_restarts": self.anneal_with_restarts,
"bad_epochs": self.scheduler.num_bad_epochs,
"current_best": self.scheduler.best,
}
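
Note that the scheduler-related entries (bad_epochs, current_best) are read from the live AnnealOnPlateau scheduler when the card is built, so the card records a snapshot rather than the scheduler object itself. The stored entry looks roughly like this (values invented for illustration):

```python
# Illustrative only; actual values depend on the run and constructor arguments.
{
    "__cls__": "flair.trainers.plugins.functional.anneal_on_plateau.AnnealingPlugin",
    "base_path": "resources/taggers/demo",
    "min_learning_rate": 0.0001,
    "anneal_factor": 0.5,
    "patience": 3,
    "initial_extra_patience": 0,
    "anneal_with_restarts": False,
    "bad_epochs": 1,       # self.scheduler.num_bad_epochs at save time
    "current_best": 0.87,  # self.scheduler.best at save time
}
```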
9 changes: 9 additions & 0 deletions flair/trainers/plugins/functional/checkpoints.py
@@ -1,4 +1,5 @@
import logging
from typing import Any, Dict

from flair.trainers.plugins.base import TrainerPlugin

@@ -27,3 +28,11 @@ def after_training_epoch(self, epoch, **kw):
)
model_name = "model_epoch_" + str(epoch) + ".pt"
self.model.save(self.base_path / model_name, checkpoint=self.save_optimizer_state)

def get_state(self) -> Dict[str, Any]:
return {
**super().get_state(),
"base_path": str(self.base_path),
"save_model_each_k_epochs": self.save_model_each_k_epochs,
"save_optimizer_state": self.save_optimizer_state,
}
15 changes: 11 additions & 4 deletions flair/trainers/plugins/functional/linear_scheduler.py
@@ -1,4 +1,5 @@
import logging
from typing import Any, Dict

from flair.optim import LinearSchedulerWithWarmup
from flair.trainers.plugins.base import TrainerPlugin
@@ -9,7 +10,7 @@
class LinearSchedulerPlugin(TrainerPlugin):
"""Plugin for LinearSchedulerWithWarmup."""

def __init__(self, warmup_fraction: float, **kwargs) -> None:
def __init__(self, warmup_fraction: float) -> None:
super().__init__()

self.warmup_fraction = warmup_fraction
@@ -29,7 +30,7 @@ def after_setup(
dataset_size,
mini_batch_size,
max_epochs,
**kw,
**kwargs,
):
"""Initialize different schedulers, including anneal target for AnnealOnPlateau, batch_growth_annealing, loading schedulers."""
# calculate warmup steps
@@ -44,13 +45,13 @@
self.store_learning_rate()

@TrainerPlugin.hook
def before_training_epoch(self, **kw):
def before_training_epoch(self, **kwargs):
"""Load state for anneal_with_restarts, batch_growth_annealing, logic for early stopping."""
self.store_learning_rate()
self.previous_learning_rate = self.current_learning_rate

@TrainerPlugin.hook
def after_training_batch(self, optimizer_was_run: bool, **kw):
def after_training_batch(self, optimizer_was_run: bool, **kwargs):
"""Do the scheduler step if one-cycle or linear decay."""
# skip if no optimization has happened.
if not optimizer_was_run:
@@ -60,3 +61,9 @@ def after_training_batch(self, optimizer_was_run: bool, **kw):

def __str__(self) -> str:
return f"LinearScheduler | warmup_fraction: '{self.warmup_fraction}'"

def get_state(self) -> Dict[str, Any]:
return {
**super().get_state(),
"warmup_fraction": self.warmup_fraction,
}
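
Besides adding get_state(), this diff also removes the catch-all **kwargs from the constructor, so unexpected keyword arguments now raise immediately rather than being silently swallowed. A quick sketch:

```python
from flair.trainers.plugins.functional.linear_scheduler import LinearSchedulerPlugin

plugin = LinearSchedulerPlugin(warmup_fraction=0.1)
print(plugin.get_state())
# {'__cls__': 'flair.trainers.plugins.functional.linear_scheduler.LinearSchedulerPlugin',
#  'warmup_fraction': 0.1}

# LinearSchedulerPlugin(warmup_fraction=0.1, unexpected=True)
# -> TypeError after this change (the argument was silently ignored before)
```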
9 changes: 9 additions & 0 deletions flair/trainers/plugins/functional/weight_extractor.py
@@ -1,3 +1,5 @@
from typing import Any, Dict

from flair.trainers.plugins.base import TrainerPlugin
from flair.training_utils import WeightExtractor

@@ -7,6 +9,7 @@ class WeightExtractorPlugin(TrainerPlugin):

def __init__(self, base_path) -> None:
super().__init__()
self.base_path = base_path
self.weight_extractor = WeightExtractor(base_path)

@TrainerPlugin.hook
@@ -17,3 +20,9 @@ def after_training_batch(self, batch_no, epoch, total_number_of_batches, **kw):

if (iteration + 1) % modulo == 0:
self.weight_extractor.extract_weights(self.model.state_dict(), iteration)

def get_state(self) -> Dict[str, Any]:
return {
**super().get_state(),
"base_path": str(self.base_path),
}
6 changes: 5 additions & 1 deletion flair/trainers/plugins/loggers/log_file.py
@@ -1,5 +1,6 @@
import logging
from pathlib import Path
from typing import Any, Dict

from flair.trainers.plugins.base import TrainerPlugin
from flair.training_utils import add_file_handler
@@ -12,10 +13,13 @@ class LogFilePlugin(TrainerPlugin):

def __init__(self, base_path) -> None:
super().__init__()

self.base_path = base_path
self.log_handler = add_file_handler(log, Path(base_path) / "training.log")

@TrainerPlugin.hook("_training_exception", "after_training")
def close_file_handler(self, **kw):
self.log_handler.close()
log.removeHandler(self.log_handler)

def get_state(self) -> Dict[str, Any]:
return {**super().get_state(), "base_path": str(self.base_path)}
11 changes: 9 additions & 2 deletions flair/trainers/plugins/loggers/loss_file.py
@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union

from flair.trainers.plugins.base import TrainerPlugin
from flair.trainers.plugins.metric_records import MetricName
@@ -15,9 +15,9 @@ def __init__(
super().__init__()

self.first_epoch = epoch + 1

# prepare loss logging file and set up header
self.loss_txt = init_output_file(base_path, "loss.tsv")
self.base_path = base_path

# set up all metrics to collect
self.metrics_to_collect = metrics_to_collect
@@ -58,6 +58,13 @@ def __init__(
# initialize the first log line
self.current_row: Optional[Dict[MetricName, str]] = None

def get_state(self) -> Dict[str, Any]:
return {
**super().get_state(),
"base_path": str(self.base_path),
"metrics_to_collect": self.metrics_to_collect,
}

@TrainerPlugin.hook
def before_training_epoch(self, epoch, **kw):
"""Get the current epoch for loss file logging."""
8 changes: 7 additions & 1 deletion flair/trainers/plugins/loggers/metric_history.py
@@ -1,5 +1,5 @@
import logging
from typing import Dict, Mapping
from typing import Any, Dict, Mapping

from flair.trainers.plugins.base import TrainerPlugin

@@ -32,3 +32,9 @@ def metric_recorded(self, record):
def after_training(self, **kw):
"""Returns metric history."""
self.trainer.return_values.update(self.metric_history)

def get_state(self) -> Dict[str, Any]:
return {
**super().get_state(),
"metrics_to_collect": dict(self.metrics_to_collect),
}
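
metrics_to_collect is copied into a plain dict here because the plugin accepts any Mapping (including read-only ones), while the model card should hold only basic Python types. A small illustration, assuming a read-only mapping is passed:

```python
from types import MappingProxyType

# Hypothetical mapping of metric keys to column names, wrapped read-only.
metrics_to_collect = MappingProxyType({("train", "loss"): "train_loss"})
serializable = dict(metrics_to_collect)  # plain dict, safe to store in the model card
```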
10 changes: 10 additions & 0 deletions flair/trainers/plugins/loggers/tensorboard.py
@@ -1,5 +1,6 @@
import logging
import os
from typing import Any, Dict

from flair.trainers.plugins.base import TrainerPlugin
from flair.training_utils import log_line
@@ -22,6 +23,7 @@ def __init__(self, log_dir=None, comment="", tracked_metrics=()) -> None:
super().__init__()
self.comment = comment
self.tracked_metrics = tracked_metrics
self.log_dir = log_dir

try:
from torch.utils.tensorboard import SummaryWriter
@@ -56,3 +58,11 @@ def _training_finally(self, **kw):
"""Closes the writer."""
assert self.writer is not None
self.writer.close()

def get_state(self) -> Dict[str, Any]:
return {
**super().get_state(),
"log_dir": str(self.log_dir) if self.log_dir is not None else None,
"comment": self.comment,
"tracked_metrics": self.tracked_metrics,
}
12 changes: 10 additions & 2 deletions flair/trainers/plugins/loggers/wandb.py
@@ -1,4 +1,5 @@
import logging
from typing import Any, Dict

from flair.trainers.plugins.base import TrainerPlugin

@@ -32,8 +33,8 @@ def emit(self, record):


class WandbLogger(TrainerPlugin):
def __init__(self, wandb, emit_alerts=True, alert_level=logging.WARNING, **kwargs) -> None:
super().__init__(**kwargs)
def __init__(self, wandb, emit_alerts=True, alert_level=logging.WARNING) -> None:
super().__init__()

self.wandb = wandb
self.emit_alerts = emit_alerts
@@ -70,3 +71,10 @@ def metric_recorded(self, record):
@TrainerPlugin.hook
def _training_finally(self, **kw):
self.writer.close()

def get_state(self) -> Dict[str, Any]:
return {
**super().get_state(),
"emit_alerts": self.emit_alerts,
"alert_level": self.alert_level,
}
6 changes: 5 additions & 1 deletion flair/trainers/trainer.py
@@ -860,7 +860,11 @@ def _initialize_model_card(self, **training_parameters):
k: str(v) if isinstance(v, Path) else v for k, v in training_parameters.items()
}

plugins = [plugin.__class__ for plugin in model_card["training_parameters"]["plugins"]]
model_card["training_parameters"] = {
k: f"{v.__module__}.{v.__name__}" if inspect.isclass(v) else v for k, v in training_parameters.items()
}

plugins = [plugin.get_state() for plugin in model_card["training_parameters"]["plugins"]]
model_card["training_parameters"]["plugins"] = plugins

return model_card
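
Net effect on the model card: classes passed in training_parameters (e.g. the optimizer class) are stored as dotted-path strings, and the plugins entry becomes a list of plain state dictionaries instead of pickled plugin instances. Roughly (illustrative values, not from a real run):

```python
# Sketch of model_card["training_parameters"] after this change (values invented).
{
    "learning_rate": 0.1,
    "mini_batch_size": 32,
    "optimizer": "torch.optim.sgd.SGD",  # was a class object, now a dotted-path string
    "plugins": [
        {
            "__cls__": "flair.trainers.plugins.functional.checkpoints.CheckpointPlugin",
            "base_path": "resources/taggers/demo",
            "save_model_each_k_epochs": 0,
            "save_optimizer_state": False,
        },
    ],
}
```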
