diff --git a/docs/conf.py b/docs/conf.py index c1be9c0079..22ffbd0194 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -31,7 +31,7 @@ "navbar_end": ["darkmode-toggle", "version-switcher", "navbar-icon-links"], "show_prev_next": False, "footer_end": ["footer-links/legal-notice.html", "footer-links/x.html", "footer-links/linkedin.html"], - "secondary_sidebar_items": [] + "secondary_sidebar_items": [], } @@ -82,18 +82,18 @@ def linkcode_resolve(*args): # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ["_static"] -html_title = 'Flair Documentation' +html_title = "Flair Documentation" html_css_files = [ - 'css/main.css', - 'css/header.css', - 'css/footer.css', - 'css/version-switcher.css', - 'css/sidebar.css', - 'css/tutorial.css', - 'css/api.css', - 'css/legal-notice.css', - 'css/search.css', + "css/main.css", + "css/header.css", + "css/footer.css", + "css/version-switcher.css", + "css/sidebar.css", + "css/tutorial.css", + "css/api.css", + "css/legal-notice.css", + "css/search.css", ] html_logo = "_static/flair_logo_white.svg" @@ -101,7 +101,7 @@ def linkcode_resolve(*args): # Napoleon settings napoleon_include_init_with_doc = True -napoleon_include_private_with_doc = True +napoleon_include_private_with_doc = False autodoc_default_options = { "member-order": "bysource", @@ -118,9 +118,7 @@ def linkcode_resolve(*args): } html_sidebars = { - "**": [ - "globaltoc.html" - ], + "**": ["globaltoc.html"], "index": [], } diff --git a/docs/tutorial/tutorial-basics/basic-types.md b/docs/tutorial/tutorial-basics/basic-types.md index 703a5d7cd5..5ddf247166 100644 --- a/docs/tutorial/tutorial-basics/basic-types.md +++ b/docs/tutorial/tutorial-basics/basic-types.md @@ -242,7 +242,7 @@ for label in sentence.get_labels('ner'): ### Information for each label -Each label is of class `Label` which next to the value has a score indicating confidence. 
It also has a pointer back to the data point to which it attaches. +Each label is of class [`Label`](#flair.data.Label) which next to the value has a score indicating confidence. It also has a pointer back to the data point to which it attaches. This means that you can print the value, the confidence and the labeled text of each label: @@ -267,3 +267,9 @@ This should print: Our color tag has a score of 1.0 since we manually added it. If a tag is predicted by our sequence labeler, the score value will indicate classifier confidence. + +### Next + +Congrats, you now understand Flair's basic types. + +Next, learn how to use [Flair models to make predictions](how-predictions-work.md). \ No newline at end of file diff --git a/docs/tutorial/tutorial-basics/entity-linking.md b/docs/tutorial/tutorial-basics/entity-linking.md index 8137c2dc5f..808d5c91ad 100644 --- a/docs/tutorial/tutorial-basics/entity-linking.md +++ b/docs/tutorial/tutorial-basics/entity-linking.md @@ -83,3 +83,15 @@ As we can see, the linker can resolve that: - the first mention of "Barcelona" refers to the soccer club "[FC Barcelona](https://en.wikipedia.org/wiki/FC_Barcelona)" - the second mention of "Barcelona" refers to the city of "[Barcelona](https://en.wikipedia.org/wiki/Barcelona)" + +### Linking biomedical entities + +If you are working with biomedical data, we have a special entity linker capable of linking +biomedical entities to specific knowledge bases. In this case, check out this [advanced tutorial on +linking biomedical entities](entity-mention-linking.md). + +### Next + +Congrats, you learned how to link entities with Flair! + +Next, let's discuss how to [predict part-of-speech tags with Flair](part-of-speech-tagging.md). 
\ No newline at end of file diff --git a/docs/tutorial/tutorial-basics/entity-mention-linking.md b/docs/tutorial/tutorial-basics/entity-mention-linking.md index e9d442c9e1..56803f1b02 100644 --- a/docs/tutorial/tutorial-basics/entity-mention-linking.md +++ b/docs/tutorial/tutorial-basics/entity-mention-linking.md @@ -1,6 +1,6 @@ # Using and creating entity mention linker -As of Flair 0.14 we ship the [entity mention linker](#flair.models.EntityMentionLinker) - the core framework behind the [Hunflair BioNEN approach](https://huggingface.co/hunflair)]. +As of Flair 0.14 we ship the [entity mention linker](#flair.models.EntityMentionLinker) - the core framework behind the [Hunflair BioNEN approach](https://huggingface.co/hunflair). You can read more at the [Hunflair2 tutorials](project:../tutorial-hunflair2/overview.md) ## Example 1: Printing Entity linking outputs to console @@ -124,5 +124,11 @@ print(result_mentions) ```{note} If you need more than the extracted ids, you can use `nen_tagger.dictionary[span_data["nen_id"]]` - to look up the [`flair.data.EntityCandidate`](#flair.data.EntityCandidate) which contains further information. -``` \ No newline at end of file + to look up the [`EntityCandidate`](#flair.data.EntityCandidate) which contains further information. +``` + +### Next + +Congrats, you learned how to link biomedical entities with Flair! + +Next, let's discuss how to [predict part-of-speech tags with Flair](part-of-speech-tagging.md). \ No newline at end of file diff --git a/docs/tutorial/tutorial-basics/how-predictions-work.md b/docs/tutorial/tutorial-basics/how-predictions-work.md index 9911f6efa5..a1dffa8913 100644 --- a/docs/tutorial/tutorial-basics/how-predictions-work.md +++ b/docs/tutorial/tutorial-basics/how-predictions-work.md @@ -76,3 +76,8 @@ the text of label.data_point is: "Washington" ``` +### Next + +Congrats, you've made your first predictions with Flair and accessed value and confidence scores of each prediction. 
+ +Next, let's discuss specifically how to [predict named entities with Flair](tagging-entities.md). diff --git a/docs/tutorial/tutorial-basics/how-to-tag-corpus.md b/docs/tutorial/tutorial-basics/how-to-tag-corpus.md index 8aa75a4027..1dee865e34 100644 --- a/docs/tutorial/tutorial-basics/how-to-tag-corpus.md +++ b/docs/tutorial/tutorial-basics/how-to-tag-corpus.md @@ -30,3 +30,10 @@ for sentence in sentences: Using the `mini_batch_size` parameter of the [`Classifier.predict()`](#flair.nn.Classifier.predict) method, you can set the size of mini batches passed to the tagger. Depending on your resources, you might want to play around with this parameter to optimize speed. +### Next + +That's it - you completed tutorial 1! Congrats! + +You've learned how basic classes work and how to use Flair to make various predictions. + +Next, you can check out our tutorial on how to [train your own model](../tutorial-training/how-model-training-works.md). diff --git a/docs/tutorial/tutorial-basics/other-models.md b/docs/tutorial/tutorial-basics/other-models.md index dbab4f40d1..9bd02bda7e 100644 --- a/docs/tutorial/tutorial-basics/other-models.md +++ b/docs/tutorial/tutorial-basics/other-models.md @@ -150,3 +150,10 @@ We end this section with a list of all other models we currently ship with Flair | 'de-historic-reported' | historical reported speech | German | @redewiedergabe project | **87.94** (F1) | [redewiedergabe](https://github.com/redewiedergabe/tagger) | | | 'de-historic-free-indirect' | historical free-indirect speech | German | @redewiedergabe project | **87.94** (F1) | [redewiedergabe](https://github.com/redewiedergabe/tagger) | | + +### Next + +Congrats, you learned about some other models we have in Flair! + +So far, we only focused on predicting for single sentences. Next, let's discuss how +to create [predictions for a whole corpus of documents](how-to-tag-corpus.md). 
\ No newline at end of file diff --git a/docs/tutorial/tutorial-basics/part-of-speech-tagging.md b/docs/tutorial/tutorial-basics/part-of-speech-tagging.md index 3da1774bf1..9a9dc54f55 100644 --- a/docs/tutorial/tutorial-basics/part-of-speech-tagging.md +++ b/docs/tutorial/tutorial-basics/part-of-speech-tagging.md @@ -167,4 +167,9 @@ You choose which pre-trained model you load by passing the appropriate string to A full list of our current and community-contributed models can be browsed on the [__model hub__](https://huggingface.co/models?library=flair&sort=downloads). +### Next + +Congrats, you learned how to predict part-of-speech tags with Flair! + +Next, we'll present some [other models in Flair](other-models.md) you might find useful. diff --git a/docs/tutorial/tutorial-basics/tagging-entities.md b/docs/tutorial/tutorial-basics/tagging-entities.md index a3b41ff80c..39ccdd8e1a 100644 --- a/docs/tutorial/tutorial-basics/tagging-entities.md +++ b/docs/tutorial/tutorial-basics/tagging-entities.md @@ -200,3 +200,10 @@ You choose which pre-trained model you load by passing the appropriate string to A full list of our current and community-contributed models can be browsed on the [__model hub__](https://huggingface.co/models?library=flair&sort=downloads). + +### Next + +Congrats, you learned how to predict entities with Flair and got an overview of different models! + +Next, let's discuss how to [predict sentiment with Flair](tagging-sentiment.md). 
+ diff --git a/docs/tutorial/tutorial-basics/tagging-sentiment.md b/docs/tutorial/tutorial-basics/tagging-sentiment.md index 0c6c9f5789..6bbebfb178 100644 --- a/docs/tutorial/tutorial-basics/tagging-sentiment.md +++ b/docs/tutorial/tutorial-basics/tagging-sentiment.md @@ -75,5 +75,9 @@ We end this section with a list of all models we currently ship with Flair: | 'de-offensive-language' | German | detecting offensive language | [GermEval 2018 Task 1](https://projects.fzai.h-da.de/iggsa/projekt/) | **75.71** (Macro F1) | +### Next +Congrats, you learned how to predict sentiment with Flair! + +Next, let's discuss how to [link entities to Wikipedia with Flair](entity-linking.md). diff --git a/flair/data.py b/flair/data.py index 8a76a82c1a..f5d902f0e6 100644 --- a/flair/data.py +++ b/flair/data.py @@ -74,7 +74,8 @@ def add_item(self, item: str) -> int: Args: item: a string for which to assign an id. - Returns: ID of string + Returns: + ID of string """ bytes_item = item.encode("utf-8") if bytes_item not in self.item2idx: @@ -88,7 +89,8 @@ def get_idx_for_item(self, item: str) -> int: Args: item: string for which ID is requested - Returns: ID of string, otherwise 0 + Returns: + ID of string, otherwise 0 """ item_encoded = item.encode("utf-8") if item_encoded in self.item2idx: @@ -108,7 +110,8 @@ def get_idx_for_items(self, items: list[str]) -> list[int]: Args: items: List of string for which IDs are requested - Returns: List of ID of strings + Returns: + List of ID of strings """ if not hasattr(self, "item2idx_not_encoded"): d = {key.decode("UTF-8"): value for key, value in self.item2idx.items()} @@ -347,6 +350,17 @@ def has_metadata(self, key: str) -> bool: return key in self._metadata def add_label(self, typename: str, value: str, score: float = 1.0, **metadata) -> "DataPoint": + """Adds a label to the :class:`DataPoint` by internally creating a :class:`Label` object. 
+ + Args: + typename: A string that identifies the layer of annotation, such as "ner" for named entity labels or "sentiment" for sentiment labels + value: A string that sets the value of the label. + score: Optional value setting the confidence level of the label (between 0 and 1). If not set, a default confidence of 1 is used. + **metadata: Additional metadata information. + + Returns: + A pointer to itself (DataPoint object, now with an added label). + """ label = Label(self, value, score, **metadata) if typename not in self.annotation_layers: @@ -370,6 +384,17 @@ def get_label(self, label_type: Optional[str] = None, zero_tag_value: str = "O") return self.get_labels(label_type)[0] def get_labels(self, typename: Optional[str] = None) -> list[Label]: + """Returns all labels of this datapoint belonging to a specific annotation layer. + + For instance, if a data point has been labeled with `"sentiment"`-labels, you can call this function as + `get_labels("sentiment")` to return a list of all sentiment labels. + + Args: + typename: The string identifier of the annotation layer, like "sentiment" or "ner". + + Returns: + A list of :class:`Label` objects belonging to this annotation layer for this data point. + """ if typename is None: return self.labels @@ -766,7 +791,11 @@ def to_dict(self, tag_type: Optional[str] = None): class Sentence(DataPoint): - """A Sentence is a list of tokens and is used to represent a sentence or text fragment.""" + """A Sentence is a central object in Flair that represents either a single sentence or a whole text. + + Internally, it consists of a list of Token objects that represent each word in the text. Additionally, + this object stores all metadata related to a text such as labels, language code, etc. + """ def __init__( self, @@ -775,14 +804,12 @@ def __init__( language_code: Optional[str] = None, start_position: int = 0, ) -> None: - """Class to hold all metadata related to a text. 
- - Metadata can be tokens, labels, predictions, language code, etc. + """Create a sentence object by passing either a text or a list of tokens. Args: - text: original string (sentence), or a pre tokenized list of tokens. - use_tokenizer: Specify a custom tokenizer to split the text into tokens. The Default is - :class:`flair.tokenization.SegTokTokenizer`. If `use_tokenizer` is set to False, + text: Either pass the text as a string, or provide an already tokenized text as either a list of strings or a list of :class:`Token` objects. + use_tokenizer: You can optionally specify a custom tokenizer to split the text into tokens. By default we use + :class:`flair.tokenization.SegtokTokenizer`. If `use_tokenizer` is set to False, :class:`flair.tokenization.SpaceTokenizer` will be used instead. The tokenizer will be ignored, if `text` refers to pretokenized tokens. language_code: Language of the sentence. If not provided, `langdetect `_ @@ -1410,7 +1437,23 @@ def downsample( downsample_test: bool = True, random_seed: Optional[int] = None, ) -> "Corpus": - """Reduce all datasets in corpus proportionally to the given percentage.""" + """Randomly downsample the corpus to the given percentage (by removing data points). + + This method is an in-place operation, meaning that the Corpus object itself is modified by removing + data points. It additionally returns a pointer to itself for use in method chaining. + + Args: + percentage (float): A float value between 0. and 1. that indicates to which percentage the corpus + should be downsampled. Default value is 0.1, meaning it gets downsampled to 10%. + downsample_train (bool): Whether or not to include the training split in downsampling. Default is True. + downsample_dev (bool): Whether or not to include the dev split in downsampling. Default is True. + downsample_test (bool): Whether or not to include the test split in downsampling. Default is True. + random_seed (int): An optional random seed to make downsampling reproducible. 
+ + Returns: + A pointer to itself for optional use in method chaining. + """ + if downsample_train and self._train is not None: self._train = self._downsample_to_proportion(self._train, percentage, random_seed) @@ -1423,6 +1466,10 @@ def downsample( return self def filter_empty_sentences(self): + """A method that filters all sentences consisting of 0 tokens. + + This is an in-place operation that directly modifies the Corpus object itself by removing these sentences. + """ log.info("Filtering empty sentences") if self._train is not None: self._train = Corpus._filter_empty_sentences(self._train) @@ -1433,6 +1480,15 @@ def filter_empty_sentences(self): log.info(self) def filter_long_sentences(self, max_charlength: int): + """ + A method that filters all sentences for which the plain text is longer than a specified number of characters. + + This is an in-place operation that directly modifies the Corpus object itself by removing these sentences. + + Args: + max_charlength: The maximum permissible character length of a sentence. + + """ log.info("Filtering long sentences") if self._train is not None: self._train = Corpus._filter_long_sentences(self._train, max_charlength) @@ -1477,7 +1533,7 @@ def _filter_empty_sentences(dataset) -> Dataset: return subset def make_vocab_dictionary(self, max_tokens: int = -1, min_freq: int = 1) -> Dictionary: - """Creates a dictionary of all tokens contained in the corpus. + """Creates a :class:`Dictionary` of all tokens contained in the corpus. By defining `max_tokens` you can set the maximum number of tokens that should be contained in the dictionary. If there are more than `max_tokens` tokens in the corpus, the most frequent tokens are added first. @@ -1485,10 +1541,13 @@ def make_vocab_dictionary(self, max_tokens: int = -1, min_freq: int = 1) -> Dict to be added to the dictionary. 
Args: - max_tokens: the maximum number of tokens that should be added to the dictionary (-1 = take all tokens) - min_freq: a token needs to occur at least `min_freq` times to be added to the dictionary (-1 = there is no limitation) + max_tokens: The maximum number of tokens that should be added to the dictionary (providing a value of "-1" + means that there is no maximum in this regard). + min_freq: A token needs to occur at least `min_freq` times to be added to the dictionary (providing a value + of "-1" means that there is no limitation in this regard). - Returns: dictionary of tokens + Returns: + A :class:`Dictionary` of all unique tokens in the corpus. """ tokens = self._get_most_common_tokens(max_tokens, min_freq) @@ -1797,7 +1856,8 @@ def make_tag_dictionary(self, tag_type: str) -> Dictionary: Args: tag_type: the label type to gather the tag labels - Returns: A Dictionary containing the labeled tags, including "O" and "" and "" + Returns: + A Dictionary containing the labeled tags, including "O" and "" and "" """ tag_dictionary: Dictionary = Dictionary(add_unk=False) diff --git a/flair/nn/model.py b/flair/nn/model.py index 8f9d0aa09f..03834afc76 100644 --- a/flair/nn/model.py +++ b/flair/nn/model.py @@ -69,12 +69,12 @@ def evaluate( Args: data_points: The labeled data_points to evaluate. gold_label_type: The label type indicating the gold labels - out_path: Optional output path to store predictions + out_path: Optional output path to store predictions. embedding_storage_mode: One of 'none', 'cpu' or 'gpu'. 'none' means all embeddings are deleted and freshly recomputed, 'cpu' means all embeddings are stored on CPU, or 'gpu' means all embeddings are stored on GPU - mini_batch_size: The batch_size to use for predictions - main_evaluation_metric: Specify which metric to highlight as main_score - exclude_labels: Specify classes that won't be considered in evaluation + mini_batch_size: The batch_size to use for predictions. 
+ main_evaluation_metric: Specify which metric to highlight as main_score. + exclude_labels: Specify classes that won't be considered in evaluation. gold_label_dictionary: Specify which classes should be considered, all other classes will be taken as . return_loss: Weather to additionally compute the loss on the data-points. **kwargs: Arguments that will be ignored. @@ -116,8 +116,8 @@ def save(self, model_file: Union[str, Path], checkpoint: bool = False) -> None: """Saves the current model to the provided file. Args: - model_file: the model file - checkpoint: currently unused. + model_file: The model file. + checkpoint: This parameter is currently unused. """ model_state = self._get_state_dict() @@ -130,12 +130,13 @@ def save(self, model_file: Union[str, Path], checkpoint: bool = False) -> None: @classmethod def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "Model": - """Loads the model from the given file. + """Loads a Flair model from the given file or state dictionary. Args: - model_path: the model file or the already loaded state dict + model_path: Either the path to the model (as string or `Path` variable) or the already loaded state dict. - Returns: the loaded text classifier model + Returns: + The loaded Flair model. """ # if this class is abstract, go through all inheriting classes and try to fetch and load the model if inspect.isabstract(cls): @@ -207,6 +208,14 @@ def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "Model": return model def print_model_card(self): + """ + This method produces a log message that includes all recorded parameters the model was trained with. + + The model card includes information such as the Flair, PyTorch and Transformers versions used during training, + and the training parameters. + + Only available for models trained with Flair >= 0.9.1. 
+ """ if hasattr(self, "model_card"): param_out = "\n------------------------------------\n" param_out += "--------- Flair Model Card ---------\n" @@ -523,18 +532,22 @@ def predict( return_loss: bool = False, embedding_storage_mode: EmbeddingStorageMode = "none", ): - """Predicts the class labels for the given sentences. + """Uses the model to predict labels for a given set of data points. - The labels are directly added to the sentences. + The method does not directly return the predicted labels. Rather, labels are added as :class:`flair.data.Label` objects to + the respective data points. You can then access these predictions by calling :func:`flair.data.DataPoint.get_labels` + on each data point that you passed through this method. Args: - sentences: list of sentences - mini_batch_size: mini batch size to use - return_probabilities_for_all_classes: return probabilities for all classes instead of only best predicted - verbose: set to True to display a progress bar - return_loss: set to True to return loss - label_name: set this to change the name of the label type that is predicted # noqa: E501 - embedding_storage_mode: default is 'none' which is always best. Only set to 'cpu' or 'gpu' if you wish to not only predict, but also keep the generated embeddings in CPU or GPU memory respectively. 'gpu' to store embeddings in GPU memory. # noqa: E501 + sentences: The data points for which the model should predict labels, most commonly Sentence objects. + mini_batch_size: The mini batch size to use. Setting this value higher typically makes predictions faster, + but also costs more memory. + return_probabilities_for_all_classes: If set to True, the model will store probabilities for all classes + instead of only the predicted class. + verbose: If set to True, will display a progress bar while predicting. By default, this parameter is set to False. + return_loss: Set this to True to return loss (only possible if gold labels are set for the sentences). 
+ label_name: Optional parameter that if set, changes the identifier of the label type that is predicted. # noqa: E501 + embedding_storage_mode: Default is 'none' which is always best. Only set to 'cpu' or 'gpu' if you wish to not only predict, but also keep the generated embeddings in CPU or GPU memory respectively. 'gpu' to store embeddings in GPU memory. # noqa: E501 """ raise NotImplementedError diff --git a/flair/splitter.py b/flair/splitter.py index 2b6c90cd7f..6246969f28 100644 --- a/flair/splitter.py +++ b/flair/splitter.py @@ -16,16 +16,34 @@ class SentenceSplitter(ABC): r"""An abstract class representing a :class:`SentenceSplitter`. Sentence splitters are used to represent algorithms and models to split plain text into - sentences and individual tokens / words. All subclasses should overwrite :meth:`splits`, - which splits the given plain text into a sequence of sentences (:class:`Sentence`). The - individual sentences are in turn subdivided into tokens / words. In most cases, this can - be controlled by passing custom implementation of :class:`Tokenizer`. + sentences and individual tokens / words. All subclasses should overwrite :func:`split`, + which splits the given plain text into a list of :class:`flair.data.Sentence` objects. The + individual sentences are in turn subdivided into tokens. In most cases, this can + be controlled by passing custom implementation of :class:`flair.tokenization.Tokenizer`. Moreover, subclasses may overwrite :meth:`name`, returning a unique identifier representing the sentence splitter's configuration. + + The most common class in Flair that implements this base class is :class:`SegtokSentenceSplitter`. """ - def split(self, text: str, link_sentences: Optional[bool] = True) -> list[Sentence]: + def split(self, text: str, link_sentences: bool = True) -> list[Sentence]: + """ + Takes as input a text as a plain string and outputs a list of :class:`flair.data.Sentence` objects. + + If link_sentences is set (by default, it is),
the :class:`flair.data.Sentence` objects will include pointers + to the preceding and following sentences in the original text. This way, the original sequence information will + always be preserved. + + Args: + text (str): The plain text to split. + link_sentences (bool): If set to True, :class:`flair.data.Sentence` objects will include pointers + to the preceding and following sentences in the original text. + + Returns: + A list of :class:`flair.data.Sentence` objects that each represent one sentence in the given text. + + """ sentences = self._perform_split(text) if not link_sentences: return sentences @@ -39,10 +57,12 @@ def _perform_split(self, text: str) -> list[Sentence]: @property def name(self) -> str: + """A string identifier of the sentence splitter.""" return self.__class__.__name__ @property def tokenizer(self) -> Tokenizer: + """The :class:`flair.tokenization.Tokenizer` class used to tokenize sentences after they are split.""" raise NotImplementedError @tokenizer.setter diff --git a/flair/trainers/plugins/base.py b/flair/trainers/plugins/base.py index dcf7240a83..33e787a063 100644 --- a/flair/trainers/plugins/base.py +++ b/flair/trainers/plugins/base.py @@ -59,6 +59,11 @@ def __init__(self, *, plugins: Sequence[PluginArgument] = []) -> None: @property def plugins(self): + """Returns all plugins attached to this instance as a list of :class:`BasePlugin`. + + Returns: + List of :class:`BasePlugin` instances attached to this `Pluggable`. + """ return self._plugins def append_plugin(self, plugin): diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py index cf3bb80eb6..03879a2b13 100644 --- a/flair/trainers/trainer.py +++ b/flair/trainers/trainer.py @@ -42,6 +42,22 @@ class ModelTrainer(Pluggable): + """Use this class to train a Flair model. + + The ModelTrainer is initialized using a :class:`flair.nn.Model` (the architecture you want to train) and a + :class:`flair.data.Corpus` (the labeled data you use to train and evaluate the model). 
It offers two main training + functions for the two main modes of training a model: (1) :func:`train`, which is used to train a model from scratch or + to fit a classification head on a frozen transformer language model. (2) :func:`fine_tune`, which is used if you + do not freeze the transformer language model and rather fine-tune it for a specific task. + + Additionally, there is also a `train_custom` method that allows you to fully customize the training run. + + ModelTrainer inherits from :class:`flair.trainers.plugins.base.Pluggable` and thus uses a plugin system to inject + specific functionality into the training process. You can add any number of plugins to the above-mentioned training + modes. For instance, if you want to use an annealing scheduler during training, you can add the + :class:`flair.trainers.plugins.functional.AnnealingPlugin` plugin to the train command. + """ + valid_events = { "after_setup", "before_training_epoch", @@ -59,11 +75,14 @@ class ModelTrainer(Pluggable): } def __init__(self, model: flair.nn.Model, corpus: Corpus) -> None: - """Initialize a model trainer. + """Initialize a model trainer by passing a :class:`flair.nn.Model` (the architecture you want to train) and a + :class:`flair.data.Corpus` (the labeled data you use to train and evaluate the model). Args: - model: The model that you want to train. The model should inherit from flair.nn.Model # noqa: E501 - corpus: The dataset used to train the model, should be of type Corpus + model: The model that you want to train. The model should inherit from :class:`flair.nn.Model`. So for + instance you should pass a :class:`flair.models.TextClassifier` if you want to train a text classifier, + or :class:`flair.models.SequenceTagger` if you want to train an RNN-based sequence labeler. + corpus: The dataset (of type :class:`flair.data.Corpus`) used to train the model. 
""" super().__init__() self.model: flair.nn.Model = model @@ -346,7 +365,7 @@ def train_custom( plugins: Optional[list[TrainerPlugin]] = None, **kwargs, ) -> dict: - """Trains any class that implements the flair.nn.Model interface. + """Trains any class that implements the :class:`flair.nn.Model` interface. Args: base_path: Main path to which all output during training is logged and models are saved