update code to newest versions of ruff and black
helpmefindaname committed Feb 2, 2024
1 parent d4332d2 commit 9161c83
Showing 12 changed files with 29 additions and 35 deletions.
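
The changes fall into a handful of mechanical patterns: dict.get() replaces "x[k] if k in x else default" conditionals, enumerate(..., start=1) replaces hand-rolled counters, black's newer stable style reformats "..." stubs, multi-line conditional expressions, and module-docstring spacing, and pyproject.toml adopts ruff's [tool.ruff.lint] configuration namespace (presumably black 24.x and ruff 0.2.x, given the February 2024 date).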
flair/data.py (3 additions, 5 deletions)

@@ -361,7 +361,7 @@ def get_labels(self, typename: Optional[str] = None):
         if typename is None:
             return self.labels
 
-        return self.annotation_layers[typename] if typename in self.annotation_layers else []
+        return self.annotation_layers.get(typename, [])
 
     @property
     def labels(self) -> List[Label]:
@@ -987,12 +987,10 @@ def get_span(self, start: int, stop: int):
         return self[span_slice]
 
     @typing.overload
-    def __getitem__(self, idx: int) -> Token:
-        ...
+    def __getitem__(self, idx: int) -> Token: ...
 
     @typing.overload
-    def __getitem__(self, s: slice) -> Span:
-        ...
+    def __getitem__(self, s: slice) -> Span: ...
 
     def __getitem__(self, subscript):
        if isinstance(subscript, slice):
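
These two idioms recur throughout the commit. Below is a minimal, self-contained sketch of both; the Bag class and its fields are illustrative only, not flair code. dict.get(key, default) is the simplification that ruff's flake8-simplify rules push for, and collapsing a body consisting only of "..." onto the def line is, as far as I can tell, black 24's "dummy implementations" style.

from typing import Dict, List, Union, overload


class Bag:
    def __init__(self) -> None:
        self.layers: Dict[str, List[str]] = {"ner": ["PER", "LOC"]}

    # dict.get(key, default) replaces the "x[k] if k in x else default" pattern.
    def get_layer(self, name: str) -> List[str]:
        return self.layers.get(name, [])

    # black 24 puts a body that is only "..." on the def line itself,
    # which is why the overload stubs above became one-liners.
    @overload
    def __getitem__(self, idx: int) -> str: ...

    @overload
    def __getitem__(self, s: slice) -> List[str]: ...

    def __getitem__(self, subscript: Union[int, slice]):
        # Returns a layer name (int index) or a list of names (slice).
        return list(self.layers)[subscript]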
flair/datasets/base.py (2 additions, 2 deletions)

@@ -182,7 +182,7 @@ def __init__(
         for document in self.__cursor.find(filter=query, skip=start, limit=0):
             sentence = self._parse_document_to_sentence(
                 document[self.text],
-                [document[_] if _ in document else "" for _ in self.categories],
+                [document.get(c, "") for c in self.categories],
                 tokenizer,
             )
             if sentence is not None and len(sentence.tokens) > 0:
@@ -225,7 +225,7 @@ def __getitem__(self, index: int = 0) -> Sentence:
         document = self.__cursor.find_one({"_id": index})
         sentence = self._parse_document_to_sentence(
             document[self.text],
-            [document[_] if _ in document else "" for _ in self.categories],
+            [document.get(c, "") for c in self.categories],
             self.tokenizer,
         )
         return sentence
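
Besides the .get() simplification, the comprehension variable is renamed from _ to c: by convention _ marks a value that is deliberately ignored, so using it as a live loop variable is misleading and is something linters commonly flag.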
flair/datasets/biomedical.py (2 additions, 5 deletions)

@@ -2191,11 +2191,8 @@ def patch_training_file(orig_train_file: Path, patched_file: Path):
             3249: '10923035\t711\t761\tgeneralized epilepsy and febrile seizures " plus "\tSpecificDisease\tD004829+D003294\n'
         }
         with orig_train_file.open(encoding="utf-8") as input, patched_file.open("w", encoding="utf-8") as output:
-            line_no = 1
-
-            for line in input:
-                output.write(patch_lines[line_no] if line_no in patch_lines else line)
-                line_no += 1
+            for line_no, line in enumerate(input, start=1):
+                output.write(patch_lines.get(line_no, line))
 
     @staticmethod
     def parse_input_file(input_file: Path):
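
Replacing a hand-maintained counter with enumerate(iterable, start=1) is the idiomatic rewrite here; the counter can no longer drift out of sync with the loop body. A runnable sketch of the same patch-by-line-number technique, with made-up file names and contents:

from pathlib import Path

Path("in.txt").write_text("one\ntwo\nthree\n", encoding="utf-8")
patches = {2: "TWO (patched)\n"}  # 1-based line number -> replacement line

with open("in.txt", encoding="utf-8") as src, open("out.txt", "w", encoding="utf-8") as dst:
    # enumerate(..., start=1) yields (1, "one\n"), (2, "two\n"), ...
    for line_no, line in enumerate(src, start=1):
        dst.write(patches.get(line_no, line))

print(Path("out.txt").read_text(encoding="utf-8"))
# one
# TWO (patched)
# three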
flair/datasets/sequence_labeling.py (3 additions, 5 deletions)

@@ -2762,14 +2762,12 @@ def _create_datasets(self, data_file: Union[str, Path], data_folder: Path):
         with (data_folder / "train.txt").open("w", encoding="utf-8") as train, (data_folder / "test.txt").open(
             "w", encoding="utf-8"
         ) as test, (data_folder / "dev.txt").open("w", encoding="utf-8") as dev:
-            k = 0
-            for line in file.readlines():
-                k += 1
+            for k, line in enumerate(file.readlines(), start=1):
                 if k <= train_len:
                     train.write(line)
-                elif k > train_len and k <= (train_len + test_len):
+                elif train_len < k <= (train_len + test_len):
                     test.write(line)
-                elif k > (train_len + test_len) and k <= num_lines:
+                elif (train_len + test_len) < k <= num_lines:
                     dev.write(line)
 
 
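
The second rewrite uses Python's comparison chaining: train_len < k <= train_len + test_len evaluates k once and is equivalent to the two-clause "and" expression, while reading like the interval it describes. A small self-contained illustration (the split sizes are invented):

train_len, test_len = 3, 2


def bucket(k: int) -> str:
    if k <= train_len:
        return "train"
    # Chained comparison: same as `k > train_len and k <= train_len + test_len`.
    elif train_len < k <= train_len + test_len:
        return "test"
    return "dev"


print([bucket(k) for k in range(1, 7)])
# ['train', 'train', 'train', 'test', 'test', 'dev']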
flair/datasets/text_image.py (1 addition, 1 deletion)

@@ -63,7 +63,7 @@ def identity(x):
             return x
 
         preprocessor = identity
-        if "lowercase" in kwargs and kwargs["lowercase"]:
+        if kwargs.get("lowercase"):
            preprocessor = str.lower
 
        for image_info in dataset_info:
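
kwargs.get("lowercase") returns None when the key is absent, so the truthiness test behaves exactly like the original membership-plus-value check, including for an explicit lowercase=False.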
flair/file_utils.py (1 addition, 0 deletions)

@@ -1,4 +1,5 @@
 """Utilities for working with the local dataset cache. Copied from AllenNLP."""
+
 import base64
 import functools
 import io
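
If I read black 24's changelog correctly, its stable style now enforces exactly one blank line after a module docstring. That single rule would explain both this added blank line and the blank line deleted from flair/nn/distance/euclidean.py further down.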
flair/models/entity_linker_model.py (3 additions, 3 deletions)

@@ -108,9 +108,9 @@ def __init__(
         super().__init__(
             embeddings=embeddings,
             label_dictionary=label_dictionary,
-            final_embedding_size=embeddings.embedding_length * 2
-            if pooling_operation == "first_last"
-            else embeddings.embedding_length,
+            final_embedding_size=(
+                embeddings.embedding_length * 2 if pooling_operation == "first_last" else embeddings.embedding_length
+            ),
             **classifierargs,
         )
 
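
This appears to be another black 24 stable-style change: conditional expressions that span multiple lines are now wrapped in parentheses rather than left hanging, which keeps the expression visually grouped inside the surrounding call. The same reformatting appears in flair/nn/model.py below.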
flair/models/language_model.py (3 additions, 3 deletions)

@@ -215,9 +215,9 @@ def load_language_model(cls, model_file: Union[Path, str], has_decoder=True):
     def load_checkpoint(cls, model_file: Union[Path, str]):
         state = torch.load(str(model_file), map_location=flair.device)
 
-        epoch = state["epoch"] if "epoch" in state else None
-        split = state["split"] if "split" in state else None
-        loss = state["loss"] if "loss" in state else None
+        epoch = state.get("epoch")
+        split = state.get("split")
+        loss = state.get("loss")
         document_delimiter = state.get("document_delimiter", "\n")
 
         optimizer_state_dict = state.get("optimizer_state_dict")
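
dict.get already defaults to None when no fallback is supplied, so the three rewritten lines are drop-in replacements for the old conditionals:

state = {"epoch": 3}
assert state.get("epoch") == 3
# Same as: state["split"] if "split" in state else None
assert state.get("split") is None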
flair/nn/distance/euclidean.py (0 additions, 1 deletion)

@@ -17,7 +17,6 @@
 Source: https://github.com/asappresearch/dynamic-classification/blob/55beb5a48406c187674bea40487c011e8fa45aab/distance/euclidean.py
 """
 
-
 import torch
 from torch import Tensor, nn
flair/nn/model.py (7 additions, 7 deletions)

@@ -339,7 +339,7 @@ def evaluate(
             true_values_span_aligned = []
             predicted_values_span_aligned = []
             for span in all_spans:
-                list_of_gold_values_for_span = all_true_values[span] if span in all_true_values else ["O"]
+                list_of_gold_values_for_span = all_true_values.get(span, ["O"])
                 # delete excluded labels if exclude_labels is given
                 for excluded_label in exclude_labels:
                     if excluded_label in list_of_gold_values_for_span:
@@ -348,9 +348,7 @@ def evaluate(
                 if not list_of_gold_values_for_span:
                     continue
                 true_values_span_aligned.append(list_of_gold_values_for_span)
-                predicted_values_span_aligned.append(
-                    all_predicted_values[span] if span in all_predicted_values else ["O"]
-                )
+                predicted_values_span_aligned.append(all_predicted_values.get(span, ["O"]))
 
             # write all_predicted_values to out_file if set
             if out_path:
@@ -700,9 +698,11 @@ def _prepare_label_tensor(self, prediction_data_points: List[DT2]) -> torch.Tensor:
         else:
             return torch.tensor(
                 [
-                    self.label_dictionary.get_idx_for_item(label[0])
-                    if len(label) > 0
-                    else self.label_dictionary.get_idx_for_item("O")
+                    (
+                        self.label_dictionary.get_idx_for_item(label[0])
+                        if len(label) > 0
+                        else self.label_dictionary.get_idx_for_item("O")
+                    )
                     for label in labels
                 ],
                 dtype=torch.long,
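
One subtlety: the default in .get(span, ["O"]) is a fresh list on every evaluation, so the in-place deletion of excluded labels a few lines later never mutates shared state. The rewrite is semantically identical to the conditional it replaces.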
flair/trainers/trainer.py (1 addition, 1 deletion)

@@ -816,7 +816,7 @@ def train_custom(
 
     def _get_current_lr_and_momentum(self, batch_count):
         current_learning_rate = [group["lr"] for group in self.optimizer.param_groups]
-        momentum = [group["momentum"] if "momentum" in group else 0 for group in self.optimizer.param_groups]
+        momentum = [group.get("momentum", 0) for group in self.optimizer.param_groups]
         lr_info = " - lr: " + ",".join([f"{m:.6f}" for m in current_learning_rate])
         momentum_info = " - momentum: " + ",".join([f"{m:.6f}" for m in momentum])
         self._record(MetricRecord.scalar_list("learning_rate", current_learning_rate, batch_count))
pyproject.toml (3 additions, 2 deletions)

@@ -47,6 +47,7 @@ ignore_errors = true
 line-length = 120
 target-version = "py37"
 
+[tool.ruff.lint]
 #select = ["ALL"] # Uncomment to autofix all the things
 select = [
     "C4",
@@ -95,12 +96,12 @@ unfixable = [
     "F841", # Do not remove unused variables automatically
 ]
 
-[tool.ruff.per-file-ignores]
+[tool.ruff.lint.per-file-ignores]
 "flair/embeddings/legacy.py" = ["D205"]
 "scripts/*" = ["INP001"] # no need for __init__ for scripts
 "flair/datasets/*" = ["D417"] # need to fix datasets in a unified way later.
 
-[tool.ruff.pydocstyle]
+[tool.ruff.lint.pydocstyle]
 convention = "google"
 
 [tool.pydocstyle]
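
Context, to the best of my recollection: ruff 0.2.0 (released late January 2024) moved linter configuration into a dedicated [tool.ruff.lint] namespace and deprecated the top-level spellings, so select, unfixable, per-file-ignores, and pydocstyle now live under lint, while shared options such as line-length and target-version stay at the [tool.ruff] level. This hunk silences the resulting deprecation warnings.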
