From 4e159925f961d2475f0694cefaf1840c52f383fd Mon Sep 17 00:00:00 2001
From: elenamer
Date: Fri, 13 Dec 2024 16:35:16 +0100
Subject: [PATCH] formatting and fix some typing

---
 flair/datasets/sequence_labeling.py | 41 ++++++++++++++---------------
 1 file changed, 20 insertions(+), 21 deletions(-)

diff --git a/flair/datasets/sequence_labeling.py b/flair/datasets/sequence_labeling.py
index 76c842a68..b2ab2f45d 100644
--- a/flair/datasets/sequence_labeling.py
+++ b/flair/datasets/sequence_labeling.py
@@ -1,9 +1,9 @@
 import copy
+import gzip
 import json
 import logging
 import os
 import re
-import gzip
 import shutil
 import tarfile
 import tempfile
@@ -5251,12 +5251,13 @@ def __init__(
             in_memory (bool): If True the dataset is kept in memory achieving speedups in training.
             **corpusargs: The arguments propagated to :meth:'flair.datasets.ColumnCorpus.__init__'.
         """
-
         VALUE_NOISE_VALUES = ["clean", "crowd", "crowdbest", "expert", "distant", "weak", "llm"]
-
+
         if noise not in VALUE_NOISE_VALUES:
-            raise ValueError(f"Unsupported value for noise type argument. Got {noise}, expected one of {VALUE_NOISE_VALUES}!")
-
+            raise ValueError(
+                f"Unsupported value for noise type argument. Got {noise}, expected one of {VALUE_NOISE_VALUES}!"
+            )
+
         self.base_path = flair.cache_root / "datasets" / "noisebench" if not base_path else Path(base_path)
 
         filename = "clean" if noise == "clean" else f"noise_{noise}"
@@ -5270,16 +5271,13 @@ def __init__(
         if not all(files_exist):
             cached_path(f"{self.label_url}/{filename}.traindev", self.base_path / "annotations_only")
             cached_path(f"{self.label_url}/index.txt", self.base_path / "annotations_only")
-
+
         cleanconll_corpus = CLEANCONLL()
 
         self.cleanconll_base_path = flair.cache_root / "datasets" / cleanconll_corpus.__class__.__name__.lower()
 
         # create dataset files from index and train/test splits
-        self._generate_data_files(
-            filename,
-            cleanconll_corpus.__class__.__name__.lower()
-        )
+        self._generate_data_files(filename, cleanconll_corpus.__class__.__name__.lower())
 
         super().__init__(
             data_folder=self.base_path,
@@ -5294,16 +5292,13 @@ def __init__(
         )
 
     @staticmethod
-    def _read_column_file(filename: Union[str, Path]) -> list[list[str]]:
-        with open(filename, "r", errors="replace", encoding="utf-8") as file:
+    def _read_column_file(filename: Union[str, Path]) -> list[list[list[str]]]:
+        with open(filename, errors="replace", encoding="utf-8") as file:
             lines = file.readlines()
         all_sentences = []
         sentence = []
         for line in lines:
-            if "\t" in line.strip():
-                stripped_line = line.strip().split("\t")
-            else:
-                stripped_line = line.strip().split(" ")
+            stripped_line = line.strip().split("\t") if "\t" in line.strip() else line.strip().split(" ")
 
             sentence.append(stripped_line)
             if line.strip() == "":
@@ -5318,7 +5313,7 @@ def _read_column_file(filename: Union[str, Path]) -> list[list[str]]:
         return all_sentences
 
     @staticmethod
-    def _save_to_column_file(filename: Union[str, Path], sentences: list[list[str]]) -> None:
+    def _save_to_column_file(filename: Union[str, Path], sentences: list[list[list[str]]]) -> None:
         with open(filename, "w", encoding="utf-8") as f:
             for sentence in sentences:
                 for token in sentence:
@@ -5326,7 +5321,9 @@ def _save_to_column_file(filename: Union[str, Path], sentences: list[list[str]])
                     f.write("\n")
                 f.write("\n")
 
-    def _create_train_dev_splits(self, filename: Path, all_sentences: list = None, datestring: str ="1996-08-24") -> None:
+    def _create_train_dev_splits(
+        self, filename: Path, all_sentences: Optional[list] = None, datestring: str = "1996-08-24"
+    ) -> None:
         if not all_sentences:
             all_sentences = self._read_column_file(filename)
 
@@ -5356,7 +5353,9 @@ def _create_train_dev_splits(self, filename: Path, all_sentences: list = None, d
             train_sentences,
         )
 
-    def _merge_tokens_labels(self, corpus: str, all_clean_sentences: list, token_indices: list) -> list[list[str]]:
+    def _merge_tokens_labels(
+        self, corpus: str, all_clean_sentences: list, token_indices: list
+    ) -> list[list[list[str]]]:
 
         # generate NoiseBench dataset variants, given CleanCoNLL, noisy label files and index file
         noisy_labels = self._read_column_file(self.base_path / "annotations_only" / f"{corpus}.traindev")
@@ -5376,9 +5375,9 @@ def _merge_tokens_labels(self, corpus: str, all_clean_sentences: list, token_ind
         self._save_to_column_file(self.base_path / f"{corpus}.traindev", noisy_labels)
         return noisy_labels
 
-    def _generate_data_files(self, filename: Union[str, Path], origin_dataset_name: str) -> None:
+    def _generate_data_files(self, filename: str, origin_dataset_name: str) -> None:
 
-        with open(self.base_path / "annotations_only" / "index.txt", "r", encoding="utf-8") as index_file:
+        with open(self.base_path / "annotations_only" / "index.txt", encoding="utf-8") as index_file:
             token_indices = index_file.readlines()
         all_clean_sentences = self._read_column_file(self.cleanconll_base_path / f"{origin_dataset_name}.train")
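
Note (not part of the patch): the corrected annotation list[list[list[str]]] encodes a
three-level nesting (sentences -> tokens -> per-token columns) that the reader/writer pair
must round-trip. Below is a minimal standalone sketch of that round-trip, using the same
column-splitting rule as the patched one-liner (tabs if present, else single spaces). The
function names, file path, and sample data are illustrative stand-ins, not the patched
module itself, and the reader is simplified (blank lines close a sentence directly instead
of being buffered and trimmed as in _read_column_file).

    from pathlib import Path

    # The shape behind list[list[list[str]]]: sentences -> tokens -> columns.
    Sentences = list[list[list[str]]]

    def save_to_column_file(filename: Path, sentences: Sentences) -> None:
        # One token per line, columns joined by tabs; a blank line closes a sentence.
        with open(filename, "w", encoding="utf-8") as f:
            for sentence in sentences:
                for token in sentence:
                    f.write("\t".join(token))
                    f.write("\n")
                f.write("\n")

    def read_column_file(filename: Path) -> Sentences:
        # Same split rule as the patched reader: prefer tabs, fall back to spaces.
        all_sentences: Sentences = []
        sentence: list[list[str]] = []
        with open(filename, errors="replace", encoding="utf-8") as file:
            for line in file:
                if line.strip() == "":
                    if sentence:
                        all_sentences.append(sentence)
                    sentence = []
                    continue
                sentence.append(line.strip().split("\t") if "\t" in line.strip() else line.strip().split(" "))
        if sentence:  # file without a trailing blank line
            all_sentences.append(sentence)
        return all_sentences

    if __name__ == "__main__":
        data: Sentences = [[["U.N.", "B-ORG"], ["official", "O"]], [["Peter", "B-PER"]]]
        path = Path("demo.traindev")
        save_to_column_file(path, data)
        assert read_column_file(path) == data

The trailing blank line after each sentence is what delimits sentences in CoNLL-style
column files, so the writer must always emit it for any reader of the generated
.traindev files to recover sentence boundaries.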