From 897b3c2c8ff48ba62c16a3c7625514abcf2c274a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mario=20S=C3=A4nger?= Date: Mon, 15 Jan 2024 17:03:13 +0100 Subject: [PATCH] Simplify persisting and loading models --- flair/models/sequence_tagger_model.py | 22 ++-------------------- 1 file changed, 2 insertions(+), 20 deletions(-) diff --git a/flair/models/sequence_tagger_model.py b/flair/models/sequence_tagger_model.py index c61c1721b..edd1f5eff 100644 --- a/flair/models/sequence_tagger_model.py +++ b/flair/models/sequence_tagger_model.py @@ -1238,31 +1238,13 @@ def __init__(self, *args, augmentation_strategy: SentenceAugmentationStrategy, * def _get_state_dict(self): state = super()._get_state_dict() - class_name = ".".join([self.augmentation_strategy.__module__, self.augmentation_strategy.__class__.__name__]) - - state["augmentation_strategy_cls"] = class_name - state["augmentation_strategy_state"] = self.augmentation_strategy._get_state_dict() + state["augmentation_strategy"] = self.augmentation_strategy return state @classmethod def _init_model_with_state_dict(cls, state, **kwargs): - subclasses = [ - (".".join([subclass.__module__, subclass.__name__]), subclass) - for subclass in get_non_abstract_subclasses(SentenceAugmentationStrategy) - ] - - aug_strategy_cls_name = state.get("augmentation_strategy_cls") - strategy = None - - for subclass_name, subclass in subclasses: - if aug_strategy_cls_name == subclass_name: - strategy = subclass._init_strategy_with_state_dict(state.get("augmentation_strategy_state")) - break - - if strategy is None: - raise AssertionError(f"Can't reload augmentation strategy {aug_strategy_cls_name}") - + strategy = state["augmentation_strategy"] return super()._init_model_with_state_dict(state, augmentation_strategy=strategy, **kwargs) @classmethod