diff --git a/flair/datasets/entity_linking.py b/flair/datasets/entity_linking.py
index 801144ca8..a515e0f3a 100644
--- a/flair/datasets/entity_linking.py
+++ b/flair/datasets/entity_linking.py
@@ -946,7 +946,7 @@ def _text_to_cols(self, sentence: Sentence, links: list, outfile):
             links: array containing information about the starting and ending position of an entity mention, as well as its corresponding wiki tag
             outfile: file, to which the output is written
         """
-        for i in range(0, len(sentence)):
+        for i in range(len(sentence)):
             # If there are annotated entity mentions for given post title or a comment thread
             if links:
                 # Keep track which is the correct corresponding entity link, in cases where there is >1 link in a sentence
diff --git a/flair/datasets/sequence_labeling.py b/flair/datasets/sequence_labeling.py
index 849c7c899..d2e4215c8 100644
--- a/flair/datasets/sequence_labeling.py
+++ b/flair/datasets/sequence_labeling.py
@@ -1178,8 +1178,7 @@ def dataset_document_iterator(cls, file_path: Union[Path, str]) -> Iterator[List
     def sentence_iterator(cls, file_path: Union[Path, str]) -> Iterator:
         """An iterator over the sentences in an individual CONLL formatted file."""
         for document in cls.dataset_document_iterator(file_path):
-            for sentence in document:
-                yield sentence
+            yield from document
 
 
 class CONLL_03(ColumnCorpus):
@@ -2490,7 +2489,7 @@ def _add_IOB_tags(self, data_file: Union[str, Path], encoding: str = "utf8", ner
         """
 
         def add_I_prefix(current_line: List[str], ner: int, tag: str):
-            for i in range(0, len(current_line)):
+            for i in range(len(current_line)):
                 if i == 0:
                     f.write(line_list[i])
                 elif i == ner:
@@ -2508,7 +2507,7 @@ def add_I_prefix(current_line: List[str], ner: int, tag: str):
                 if len(line_list) > 2:  # word with tags
                     ner_tag = line_list[ner_column]
                     if ner_tag in ["0", "O"]:  # no chunk
-                        for i in range(0, len(line_list)):
+                        for i in range(len(line_list)):
                             if i == 0:
                                 f.write(line_list[i])
                             elif i == ner_column:
diff --git a/flair/models/sequence_tagger_utils/viterbi.py b/flair/models/sequence_tagger_utils/viterbi.py
index 050697b18..73c10fb67 100644
--- a/flair/models/sequence_tagger_utils/viterbi.py
+++ b/flair/models/sequence_tagger_utils/viterbi.py
@@ -119,7 +119,7 @@ def _format_targets(self, targets: torch.Tensor, lengths: torch.IntTensor):
 
         matrix_indices = [
             [self.tag_dictionary.get_idx_for_item(START_TAG) + (s[0] * self.tagset_size)]
-            + [s[i] + (s[i + 1] * self.tagset_size) for i in range(0, len(s) - 1)]
+            + [s[i] + (s[i + 1] * self.tagset_size) for i in range(len(s) - 1)]
             for s in targets_per_sentence
         ]
 
diff --git a/flair/trainers/plugins/functional/linear_scheduler.py b/flair/trainers/plugins/functional/linear_scheduler.py
index ad1cb8612..51b295c7d 100644
--- a/flair/trainers/plugins/functional/linear_scheduler.py
+++ b/flair/trainers/plugins/functional/linear_scheduler.py
@@ -52,7 +52,6 @@ def before_training_epoch(self, **kw):
     @TrainerPlugin.hook
     def after_training_batch(self, optimizer_was_run: bool, **kw):
         """Do the scheduler step if one-cycle or linear decay."""
-        # skip if no optimization has happened.
         if not optimizer_was_run:
             return
         self.scheduler.step()
diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py
index 43e99d6ca..7db6a7d17 100644
--- a/flair/trainers/trainer.py
+++ b/flair/trainers/trainer.py
@@ -348,6 +348,7 @@ def train_custom(
             monitor_train_sample: Set this to evaluate on a sample of the train data at the end of each epoch.
                 If you set an int, it will sample this many sentences to evaluate on.
                 If you set a float, it will sample a percentage of data points from train.
+            max_grad_norm: If not None, gradients are clipped to this value before an optimizer.step is called.
             use_final_model_for_eval: If True, the final model is used for the final evaluation. If False, the
                 model from the best epoch as determined by main_evaluation_metric is used for the final evaluation.
             gold_label_dictionary_for_eval: Set to force evaluation to use a particular label dictionary
@@ -364,6 +365,7 @@ def train_custom(
                 be saved each 5 epochs. Default is 0 which means no model saving.
             create_file_logs: If True, logging output is written to a file
             create_loss_file: If True, a loss file logging output is created
+            use_amp: If True, uses the torch automatic mixed precision
             write_weights: If True, write weights to weights.txt on each batch logging event.
             plugins: Any additional plugins you want to pass to the trainer
             **kwargs: Additional arguments, for instance for the optimizer
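
For context, a minimal usage sketch of the two newly documented train_custom arguments. The corpus, embeddings, model, and output path below are illustrative assumptions and not part of this diff; only max_grad_norm and use_amp are taken from the docstring changes above.

    # Hypothetical example: passing the newly documented arguments to ModelTrainer.train_custom.
    from flair.datasets import UD_ENGLISH
    from flair.embeddings import TransformerWordEmbeddings
    from flair.models import SequenceTagger
    from flair.trainers import ModelTrainer

    # Assumed setup: a small POS tagger on the UD English corpus.
    corpus = UD_ENGLISH()
    label_dict = corpus.make_label_dictionary(label_type="upos")
    tagger = SequenceTagger(
        hidden_size=256,
        embeddings=TransformerWordEmbeddings("distilbert-base-uncased"),
        tag_dictionary=label_dict,
        tag_type="upos",
    )
    trainer = ModelTrainer(tagger, corpus)

    trainer.train_custom(
        "resources/taggers/example-upos",  # illustrative output path
        max_epochs=1,
        max_grad_norm=5.0,  # clip gradients to this norm before each optimizer.step
        use_amp=True,       # train with torch automatic mixed precision
    )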