From 9161c83eca4cbe3da88bde50f6a171eaef6309cd Mon Sep 17 00:00:00 2001
From: Benedikt Fuchs
Date: Fri, 2 Feb 2024 19:31:10 +0100
Subject: [PATCH] update code to newest versions of ruff and black

---
 flair/data.py                       |  8 +++-----
 flair/datasets/base.py              |  4 ++--
 flair/datasets/biomedical.py        |  7 ++-----
 flair/datasets/sequence_labeling.py |  8 +++-----
 flair/datasets/text_image.py        |  2 +-
 flair/file_utils.py                 |  1 +
 flair/models/entity_linker_model.py |  6 +++---
 flair/models/language_model.py      |  6 +++---
 flair/nn/distance/euclidean.py      |  1 -
 flair/nn/model.py                   | 14 +++++++-------
 flair/trainers/trainer.py           |  2 +-
 pyproject.toml                      |  5 +++--
 12 files changed, 29 insertions(+), 35 deletions(-)

diff --git a/flair/data.py b/flair/data.py
index 9d0646b94d..10e74e7823 100644
--- a/flair/data.py
+++ b/flair/data.py
@@ -361,7 +361,7 @@ def get_labels(self, typename: Optional[str] = None):
         if typename is None:
             return self.labels
 
-        return self.annotation_layers[typename] if typename in self.annotation_layers else []
+        return self.annotation_layers.get(typename, [])
 
     @property
     def labels(self) -> List[Label]:
@@ -987,12 +987,10 @@ def get_span(self, start: int, stop: int):
         return self[span_slice]
 
     @typing.overload
-    def __getitem__(self, idx: int) -> Token:
-        ...
+    def __getitem__(self, idx: int) -> Token: ...
 
     @typing.overload
-    def __getitem__(self, s: slice) -> Span:
-        ...
+    def __getitem__(self, s: slice) -> Span: ...
 
     def __getitem__(self, subscript):
         if isinstance(subscript, slice):
diff --git a/flair/datasets/base.py b/flair/datasets/base.py
index 2ba0aabab3..a38d0b1321 100644
--- a/flair/datasets/base.py
+++ b/flair/datasets/base.py
@@ -182,7 +182,7 @@ def __init__(
             for document in self.__cursor.find(filter=query, skip=start, limit=0):
                 sentence = self._parse_document_to_sentence(
                     document[self.text],
-                    [document[_] if _ in document else "" for _ in self.categories],
+                    [document.get(c, "") for c in self.categories],
                     tokenizer,
                 )
                 if sentence is not None and len(sentence.tokens) > 0:
@@ -225,7 +225,7 @@ def __getitem__(self, index: int = 0) -> Sentence:
         document = self.__cursor.find_one({"_id": index})
         sentence = self._parse_document_to_sentence(
             document[self.text],
-            [document[_] if _ in document else "" for _ in self.categories],
+            [document.get(c, "") for c in self.categories],
             self.tokenizer,
         )
         return sentence
diff --git a/flair/datasets/biomedical.py b/flair/datasets/biomedical.py
index bde405063c..71b3b7c55d 100644
--- a/flair/datasets/biomedical.py
+++ b/flair/datasets/biomedical.py
@@ -2191,11 +2191,8 @@ def patch_training_file(orig_train_file: Path, patched_file: Path):
             3249: '10923035\t711\t761\tgeneralized epilepsy and febrile seizures " plus "\tSpecificDisease\tD004829+D003294\n'
         }
         with orig_train_file.open(encoding="utf-8") as input, patched_file.open("w", encoding="utf-8") as output:
-            line_no = 1
-
-            for line in input:
-                output.write(patch_lines[line_no] if line_no in patch_lines else line)
-                line_no += 1
+            for line_no, line in enumerate(input, start=1):
+                output.write(patch_lines.get(line_no, line))
 
     @staticmethod
     def parse_input_file(input_file: Path):
diff --git a/flair/datasets/sequence_labeling.py b/flair/datasets/sequence_labeling.py
index b551b1e699..820f3ba570 100644
--- a/flair/datasets/sequence_labeling.py
+++ b/flair/datasets/sequence_labeling.py
@@ -2762,14 +2762,12 @@ def _create_datasets(self, data_file: Union[str, Path], data_folder: Path):
         with (data_folder / "train.txt").open("w", encoding="utf-8") as train, (data_folder / "test.txt").open(
             "w", encoding="utf-8"
         ) as test, (data_folder / "dev.txt").open("w", encoding="utf-8") as dev:
-            k = 0
-            for line in file.readlines():
-                k += 1
+            for k, line in enumerate(file.readlines(), start=1):
                 if k <= train_len:
                     train.write(line)
-                elif k > train_len and k <= (train_len + test_len):
+                elif train_len < k <= (train_len + test_len):
                     test.write(line)
-                elif k > (train_len + test_len) and k <= num_lines:
+                elif (train_len + test_len) < k <= num_lines:
                     dev.write(line)
 
 
diff --git a/flair/datasets/text_image.py b/flair/datasets/text_image.py
index 31c7e55156..f7baf72be9 100644
--- a/flair/datasets/text_image.py
+++ b/flair/datasets/text_image.py
@@ -63,7 +63,7 @@ def identity(x):
             return x
 
         preprocessor = identity
-        if "lowercase" in kwargs and kwargs["lowercase"]:
+        if kwargs.get("lowercase"):
            preprocessor = str.lower
 
        for image_info in dataset_info:
diff --git a/flair/file_utils.py b/flair/file_utils.py
index 7f0ba5f9e7..a3db6458d8 100644
--- a/flair/file_utils.py
+++ b/flair/file_utils.py
@@ -1,4 +1,5 @@
 """Utilities for working with the local dataset cache. Copied from AllenNLP."""
+
 import base64
 import functools
 import io
diff --git a/flair/models/entity_linker_model.py b/flair/models/entity_linker_model.py
index 3e46f3a2dd..838664e402 100644
--- a/flair/models/entity_linker_model.py
+++ b/flair/models/entity_linker_model.py
@@ -108,9 +108,9 @@ def __init__(
         super().__init__(
             embeddings=embeddings,
             label_dictionary=label_dictionary,
-            final_embedding_size=embeddings.embedding_length * 2
-            if pooling_operation == "first_last"
-            else embeddings.embedding_length,
+            final_embedding_size=(
+                embeddings.embedding_length * 2 if pooling_operation == "first_last" else embeddings.embedding_length
+            ),
             **classifierargs,
         )
 
diff --git a/flair/models/language_model.py b/flair/models/language_model.py
index 97af6f8eb9..29dedbc755 100644
--- a/flair/models/language_model.py
+++ b/flair/models/language_model.py
@@ -215,9 +215,9 @@ def load_language_model(cls, model_file: Union[Path, str], has_decoder=True):
     def load_checkpoint(cls, model_file: Union[Path, str]):
         state = torch.load(str(model_file), map_location=flair.device)
 
-        epoch = state["epoch"] if "epoch" in state else None
-        split = state["split"] if "split" in state else None
-        loss = state["loss"] if "loss" in state else None
+        epoch = state.get("epoch")
+        split = state.get("split")
+        loss = state.get("loss")
         document_delimiter = state.get("document_delimiter", "\n")
 
         optimizer_state_dict = state.get("optimizer_state_dict")
diff --git a/flair/nn/distance/euclidean.py b/flair/nn/distance/euclidean.py
index 311709ff3e..d3a06d4123 100644
--- a/flair/nn/distance/euclidean.py
+++ b/flair/nn/distance/euclidean.py
@@ -17,7 +17,6 @@
 Source: https://github.com/asappresearch/dynamic-classification/blob/55beb5a48406c187674bea40487c011e8fa45aab/distance/euclidean.py
 """
 
-
 import torch
 from torch import Tensor, nn
 
diff --git a/flair/nn/model.py b/flair/nn/model.py
index 5d553fd42f..b339f19947 100644
--- a/flair/nn/model.py
+++ b/flair/nn/model.py
@@ -339,7 +339,7 @@ def evaluate(
             true_values_span_aligned = []
             predicted_values_span_aligned = []
             for span in all_spans:
-                list_of_gold_values_for_span = all_true_values[span] if span in all_true_values else ["O"]
+                list_of_gold_values_for_span = all_true_values.get(span, ["O"])
                 # delete exluded labels if exclude_labels is given
                 for excluded_label in exclude_labels:
                     if excluded_label in list_of_gold_values_for_span:
@@ -348,9 +348,7 @@
                 if not list_of_gold_values_for_span:
                     continue
                 true_values_span_aligned.append(list_of_gold_values_for_span)
-                predicted_values_span_aligned.append(
-                    all_predicted_values[span] if span in all_predicted_values else ["O"]
-                )
+                predicted_values_span_aligned.append(all_predicted_values.get(span, ["O"]))
 
             # write all_predicted_values to out_file if set
             if out_path:
@@ -700,9 +698,11 @@ def _prepare_label_tensor(self, prediction_data_points: List[DT2]) -> torch.Tens
         else:
             return torch.tensor(
                 [
-                    self.label_dictionary.get_idx_for_item(label[0])
-                    if len(label) > 0
-                    else self.label_dictionary.get_idx_for_item("O")
+                    (
+                        self.label_dictionary.get_idx_for_item(label[0])
+                        if len(label) > 0
+                        else self.label_dictionary.get_idx_for_item("O")
+                    )
                     for label in labels
                 ],
                 dtype=torch.long,
diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py
index 65617043b1..2afb72bc8b 100644
--- a/flair/trainers/trainer.py
+++ b/flair/trainers/trainer.py
@@ -816,7 +816,7 @@ def train_custom(
 
     def _get_current_lr_and_momentum(self, batch_count):
         current_learning_rate = [group["lr"] for group in self.optimizer.param_groups]
-        momentum = [group["momentum"] if "momentum" in group else 0 for group in self.optimizer.param_groups]
+        momentum = [group.get("momentum", 0) for group in self.optimizer.param_groups]
         lr_info = " - lr: " + ",".join([f"{m:.6f}" for m in current_learning_rate])
         momentum_info = " - momentum: " + ",".join([f"{m:.6f}" for m in momentum])
         self._record(MetricRecord.scalar_list("learning_rate", current_learning_rate, batch_count))
diff --git a/pyproject.toml b/pyproject.toml
index 12e62fdd86..e4986413c2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -47,6 +47,7 @@ ignore_errors = true
 line-length = 120
 target-version = "py37"
 
+[tool.ruff.lint]
 #select = ["ALL"] # Uncommit to autofix all the things
 select = [
     "C4",
@@ -95,12 +96,12 @@ unfixable = [
     "F841", # Do not remove unused variables automatically
 ]
 
-[tool.ruff.per-file-ignores]
+[tool.ruff.lint.per-file-ignores]
 "flair/embeddings/legacy.py" = ["D205"]
 "scripts/*" = ["INP001"] # no need for __ini__ for scripts
 "flair/datasets/*" = ["D417"] # need to fix datasets in a unified way later.
 
-[tool.ruff.pydocstyle]
+[tool.ruff.lint.pydocstyle]
 convention = "google"
 
 [tool.pydocstyle]
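Reviewer note (not part of the patch): the Python edits above are mechanical rewrites into idioms that newer ruff rules prefer. The standalone sketch below, using illustrative variable names rather than anything from the flair codebase, checks that the three recurring rewrites preserve behavior; it runs as-is with no imports.

# Equivalence checks for the idioms applied throughout this commit.

state = {"epoch": 3}

# 1. dict.get with a default replaces the `d[k] if k in d else fallback` pattern.
assert (state["epoch"] if "epoch" in state else None) == state.get("epoch")
assert (state["loss"] if "loss" in state else None) == state.get("loss")

# 2. enumerate(..., start=1) replaces a manually incremented counter.
lines = ["a", "b", "c"]
numbered = []
k = 0
for line in lines:
    k += 1
    numbered.append((k, line))
assert numbered == list(enumerate(lines, start=1))

# 3. A chained comparison replaces two comparisons joined with `and`.
train_len, test_len, k = 10, 5, 12
assert (k > train_len and k <= train_len + test_len) == (train_len < k <= train_len + test_len)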