dependencies maintenance updates #3402

Merged · 3 commits · Feb 3, 2024
8 changes: 3 additions & 5 deletions flair/data.py
@@ -361,7 +361,7 @@ def get_labels(self, typename: Optional[str] = None):
         if typename is None:
             return self.labels

-        return self.annotation_layers[typename] if typename in self.annotation_layers else []
+        return self.annotation_layers.get(typename, [])

     @property
     def labels(self) -> List[Label]:
@@ -987,12 +987,10 @@ def get_span(self, start: int, stop: int):
         return self[span_slice]

     @typing.overload
-    def __getitem__(self, idx: int) -> Token:
-        ...
+    def __getitem__(self, idx: int) -> Token: ...

     @typing.overload
-    def __getitem__(self, s: slice) -> Span:
-        ...
+    def __getitem__(self, s: slice) -> Span: ...

     def __getitem__(self, subscript):
         if isinstance(subscript, slice):
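
Most hunks in this PR apply the same refactor: replace the `x[k] if k in x else default` lookup with `dict.get`. A minimal sketch of the equivalence (the `layers` dict here is illustrative, not from the diff):

```python
layers = {"ner": ["PER", "LOC"]}

# Old pattern: a membership test followed by a second lookup.
pos_labels = layers["pos"] if "pos" in layers else []

# New pattern: dict.get does a single lookup and falls back to the default.
assert layers.get("pos", []) == pos_labels == []
assert layers.get("ner", []) == ["PER", "LOC"]
```
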
4 changes: 2 additions & 2 deletions flair/datasets/base.py
@@ -182,7 +182,7 @@ def __init__(
         for document in self.__cursor.find(filter=query, skip=start, limit=0):
             sentence = self._parse_document_to_sentence(
                 document[self.text],
-                [document[_] if _ in document else "" for _ in self.categories],
+                [document.get(c, "") for c in self.categories],
                 tokenizer,
             )
             if sentence is not None and len(sentence.tokens) > 0:
@@ -225,7 +225,7 @@ def __getitem__(self, index: int = 0) -> Sentence:
         document = self.__cursor.find_one({"_id": index})
         sentence = self._parse_document_to_sentence(
             document[self.text],
-            [document[_] if _ in document else "" for _ in self.categories],
+            [document.get(c, "") for c in self.categories],
             self.tokenizer,
         )
         return sentence
7 changes: 2 additions & 5 deletions flair/datasets/biomedical.py
@@ -2191,11 +2191,8 @@ def patch_training_file(orig_train_file: Path, patched_file: Path):
             3249: '10923035\t711\t761\tgeneralized epilepsy and febrile seizures " plus "\tSpecificDisease\tD004829+D003294\n'
         }
         with orig_train_file.open(encoding="utf-8") as input, patched_file.open("w", encoding="utf-8") as output:
-            line_no = 1
-
-            for line in input:
-                output.write(patch_lines[line_no] if line_no in patch_lines else line)
-                line_no += 1
+            for line_no, line in enumerate(input, start=1):
+                output.write(patch_lines.get(line_no, line))

     @staticmethod
     def parse_input_file(input_file: Path):
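
The counter cleanup above is the standard `enumerate` idiom; a small sketch of why the two loops are equivalent (toy `lines` input, not the dataset file):

```python
lines = ["a\n", "b\n", "c\n"]

# Old: a manual 1-based counter threaded through the loop.
numbered = []
line_no = 1
for line in lines:
    numbered.append((line_no, line))
    line_no += 1

# New: enumerate(..., start=1) yields the same 1-based pairs.
assert numbered == list(enumerate(lines, start=1))
```
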
8 changes: 3 additions & 5 deletions flair/datasets/sequence_labeling.py
@@ -2762,14 +2762,12 @@ def _create_datasets(self, data_file: Union[str, Path], data_folder: Path):
         with (data_folder / "train.txt").open("w", encoding="utf-8") as train, (data_folder / "test.txt").open(
             "w", encoding="utf-8"
         ) as test, (data_folder / "dev.txt").open("w", encoding="utf-8") as dev:
-            k = 0
-            for line in file.readlines():
-                k += 1
+            for k, line in enumerate(file.readlines(), start=1):
                 if k <= train_len:
                     train.write(line)
-                elif k > train_len and k <= (train_len + test_len):
+                elif train_len < k <= (train_len + test_len):
                     test.write(line)
-                elif k > (train_len + test_len) and k <= num_lines:
+                elif (train_len + test_len) < k <= num_lines:
                     dev.write(line)

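
The `elif` rewrites rely on Python's chained comparisons: `a < k <= b` evaluates `k` once and means the same as `k > a and k <= b`. A quick check with made-up split sizes:

```python
train_len, test_len = 3, 2

for k in range(1, 8):
    old = k > train_len and k <= (train_len + test_len)
    new = train_len < k <= (train_len + test_len)
    assert old == new  # identical bucketing for every line number k
```
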
2 changes: 1 addition & 1 deletion flair/datasets/text_image.py
@@ -63,7 +63,7 @@ def identity(x):
             return x

         preprocessor = identity
-        if "lowercase" in kwargs and kwargs["lowercase"]:
+        if kwargs.get("lowercase"):
             preprocessor = str.lower

         for image_info in dataset_info:
1 change: 1 addition & 0 deletions flair/file_utils.py
@@ -1,4 +1,5 @@
 """Utilities for working with the local dataset cache. Copied from AllenNLP."""
+
 import base64
 import functools
 import io
6 changes: 3 additions & 3 deletions flair/models/entity_linker_model.py
@@ -108,9 +108,9 @@ def __init__(
         super().__init__(
             embeddings=embeddings,
             label_dictionary=label_dictionary,
-            final_embedding_size=embeddings.embedding_length * 2
-            if pooling_operation == "first_last"
-            else embeddings.embedding_length,
+            final_embedding_size=(
+                embeddings.embedding_length * 2 if pooling_operation == "first_last" else embeddings.embedding_length
+            ),
             **classifierargs,
         )

6 changes: 3 additions & 3 deletions flair/models/language_model.py
@@ -215,9 +215,9 @@ def load_language_model(cls, model_file: Union[Path, str], has_decoder=True):
     def load_checkpoint(cls, model_file: Union[Path, str]):
         state = torch.load(str(model_file), map_location=flair.device)

-        epoch = state["epoch"] if "epoch" in state else None
-        split = state["split"] if "split" in state else None
-        loss = state["loss"] if "loss" in state else None
+        epoch = state.get("epoch")
+        split = state.get("split")
+        loss = state.get("loss")
         document_delimiter = state.get("document_delimiter", "\n")

         optimizer_state_dict = state.get("optimizer_state_dict")
1 change: 0 additions & 1 deletion flair/nn/distance/euclidean.py
@@ -17,7 +17,6 @@
 Source: https://github.com/asappresearch/dynamic-classification/blob/55beb5a48406c187674bea40487c011e8fa45aab/distance/euclidean.py
 """

-
 import torch
 from torch import Tensor, nn

14 changes: 7 additions & 7 deletions flair/nn/model.py
@@ -339,7 +339,7 @@ def evaluate(
             true_values_span_aligned = []
             predicted_values_span_aligned = []
             for span in all_spans:
-                list_of_gold_values_for_span = all_true_values[span] if span in all_true_values else ["O"]
+                list_of_gold_values_for_span = all_true_values.get(span, ["O"])
                 # delete excluded labels if exclude_labels is given
                 for excluded_label in exclude_labels:
                     if excluded_label in list_of_gold_values_for_span:
@@ -348,9 +348,7 @@
                 if not list_of_gold_values_for_span:
                     continue
                 true_values_span_aligned.append(list_of_gold_values_for_span)
-                predicted_values_span_aligned.append(
-                    all_predicted_values[span] if span in all_predicted_values else ["O"]
-                )
+                predicted_values_span_aligned.append(all_predicted_values.get(span, ["O"]))

             # write all_predicted_values to out_file if set
             if out_path:
@@ -700,9 +698,11 @@ def _prepare_label_tensor(self, prediction_data_points: List[DT2]) -> torch.Tensor:
         else:
             return torch.tensor(
                 [
-                    self.label_dictionary.get_idx_for_item(label[0])
-                    if len(label) > 0
-                    else self.label_dictionary.get_idx_for_item("O")
+                    (
+                        self.label_dictionary.get_idx_for_item(label[0])
+                        if len(label) > 0
+                        else self.label_dictionary.get_idx_for_item("O")
+                    )
                     for label in labels
                 ],
                 dtype=torch.long,
2 changes: 0 additions & 2 deletions flair/optim.py
@@ -217,8 +217,6 @@ class ReduceLRWDOnPlateau(ReduceLROnPlateau):
             with no improvement, and will only decrease the LR after the
             3rd epoch if the loss still hasn't improved then.
             Default: 10.
-        verbose (bool): If ``True``, prints a message to stdout for
-            each update. Default: ``False``.
         threshold (float): Threshold for measuring the new optimum,
             to only focus on significant changes. Default: 1e-4.
         threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
6 changes: 2 additions & 4 deletions flair/trainers/language_model_trainer.py
@@ -227,11 +227,9 @@ def train(
             optimizer.load_state_dict(self.optimizer_state)

         if isinstance(optimizer, (AdamW, SGDW)):
-            scheduler: ReduceLROnPlateau = ReduceLRWDOnPlateau(
-                optimizer, verbose=True, factor=anneal_factor, patience=patience
-            )
+            scheduler: ReduceLROnPlateau = ReduceLRWDOnPlateau(optimizer, factor=anneal_factor, patience=patience)
         else:
-            scheduler = ReduceLROnPlateau(optimizer, verbose=True, factor=anneal_factor, patience=patience)
+            scheduler = ReduceLROnPlateau(optimizer, factor=anneal_factor, patience=patience)

         training_generator = DataLoader(self.corpus.train, shuffle=False, num_workers=num_workers)

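
Both scheduler constructions drop the `verbose=True` argument, matching the removal of `verbose` from the `ReduceLRWDOnPlateau` docstring in flair/optim.py and from the anneal-on-plateau plugin below. The PR does not state the motivation, but recent PyTorch releases deprecate `verbose` on LR schedulers. A minimal sketch of the updated construction, using a toy model and optimizer:

```python
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# As in the updated trainer: no verbose flag, only annealing behavior.
scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=10)
scheduler.step(0.42)  # step with the monitored metric, e.g. validation loss
```
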
1 change: 0 additions & 1 deletion flair/trainers/plugins/functional/anneal_on_plateau.py
@@ -63,7 +63,6 @@ def after_setup(
             patience=self.patience,
             initial_extra_patience=self.initial_extra_patience,
             mode=anneal_mode,
-            verbose=False,
             optimizer=self.trainer.optimizer,
         )

2 changes: 1 addition & 1 deletion flair/trainers/trainer.py
@@ -816,7 +816,7 @@ def train_custom(

     def _get_current_lr_and_momentum(self, batch_count):
         current_learning_rate = [group["lr"] for group in self.optimizer.param_groups]
-        momentum = [group["momentum"] if "momentum" in group else 0 for group in self.optimizer.param_groups]
+        momentum = [group.get("momentum", 0) for group in self.optimizer.param_groups]
         lr_info = " - lr: " + ",".join([f"{m:.6f}" for m in current_learning_rate])
         momentum_info = " - momentum: " + ",".join([f"{m:.6f}" for m in momentum])
         self._record(MetricRecord.scalar_list("learning_rate", current_learning_rate, batch_count))
5 changes: 3 additions & 2 deletions pyproject.toml
@@ -47,6 +47,7 @@ ignore_errors = true
 line-length = 120
 target-version = "py37"

+[tool.ruff.lint]
 #select = ["ALL"] # Uncomment to autofix all the things
 select = [
     "C4",
@@ -95,12 +96,12 @@ unfixable = [
     "F841", # Do not remove unused variables automatically
 ]

-[tool.ruff.per-file-ignores]
+[tool.ruff.lint.per-file-ignores]
 "flair/embeddings/legacy.py" = ["D205"]
 "scripts/*" = ["INP001"] # no need for __init__ for scripts
 "flair/datasets/*" = ["D417"] # need to fix datasets in a unified way later.

-[tool.ruff.pydocstyle]
+[tool.ruff.lint.pydocstyle]
 convention = "google"

 [tool.pydocstyle]
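
The renamed tables track ruff's reorganization of lint configuration under `[tool.ruff.lint]`: `select`, `per-file-ignores`, and `pydocstyle` now live in `lint`-scoped sections, and in recent ruff releases the old top-level locations emit deprecation warnings ahead of removal.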