Skip to content

Commit

Permalink
Simplify persisting and loading models
Browse files Browse the repository at this point in the history
  • Loading branch information
Mario Sänger committed Jan 15, 2024
1 parent 96c1eaf commit 897b3c2
Showing 1 changed file with 2 additions and 20 deletions.
22 changes: 2 additions & 20 deletions flair/models/sequence_tagger_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1238,31 +1238,13 @@ def __init__(self, *args, augmentation_strategy: SentenceAugmentationStrategy, *

def _get_state_dict(self):
state = super()._get_state_dict()
class_name = ".".join([self.augmentation_strategy.__module__, self.augmentation_strategy.__class__.__name__])

state["augmentation_strategy_cls"] = class_name
state["augmentation_strategy_state"] = self.augmentation_strategy._get_state_dict()
state["augmentation_strategy"] = self.augmentation_strategy

return state

@classmethod
def _init_model_with_state_dict(cls, state, **kwargs):
subclasses = [
(".".join([subclass.__module__, subclass.__name__]), subclass)
for subclass in get_non_abstract_subclasses(SentenceAugmentationStrategy)
]

aug_strategy_cls_name = state.get("augmentation_strategy_cls")
strategy = None

for subclass_name, subclass in subclasses:
if aug_strategy_cls_name == subclass_name:
strategy = subclass._init_strategy_with_state_dict(state.get("augmentation_strategy_state"))
break

if strategy is None:
raise AssertionError(f"Can't reload augmentation strategy {aug_strategy_cls_name}")

strategy = state["augmentation_strategy"]
return super()._init_model_with_state_dict(state, augmentation_strategy=strategy, **kwargs)

@classmethod
Expand Down

0 comments on commit 897b3c2

Please sign in to comment.