Skip to content

Commit

Permalink
Merge branch 'master' into gh-3462/load_best_model
Browse files Browse the repository at this point in the history
  • Loading branch information
helpmefindaname authored Jun 28, 2024
2 parents 60b9863 + ca1b90b commit 64955ac
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
2 changes: 2 additions & 0 deletions flair/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@
UD_ARABIC,
UD_ARMENIAN,
UD_BASQUE,
UD_BAVARIAN_MAIBAAM,
UD_BELARUSIAN,
UD_BULGARIAN,
UD_BURYAT,
Expand Down Expand Up @@ -536,6 +537,7 @@
"UD_ARABIC",
"UD_ARMENIAN",
"UD_BASQUE",
"UD_BAVARIAN_MAIBAAM",
"UD_BELARUSIAN",
"UD_BULGARIAN",
"UD_BURYAT",
Expand Down
31 changes: 30 additions & 1 deletion flair/datasets/treebanks.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ def __init__(
dev_file, test_file, train_file = find_train_dev_test_files(data_folder, dev_file, test_file, train_file)

# get train data
train = UniversalDependenciesDataset(train_file, in_memory=in_memory, split_multiwords=split_multiwords)
train = (
UniversalDependenciesDataset(train_file, in_memory=in_memory, split_multiwords=split_multiwords)
if train_file is not None
else None
)

# get test data
test = (
Expand Down Expand Up @@ -1509,6 +1513,7 @@ def __init__(

# download data if necessary
web_path = f"https://raw.githubusercontent.com/UniversalDependencies/UD_Old_French-SRCMF/{revision}"

cached_path(f"{web_path}/fro_profiterole-ud-dev.conllu", Path("datasets") / dataset_name)
cached_path(f"{web_path}/fro_profiterole-ud-test.conllu", Path("datasets") / dataset_name)
cached_path(f"{web_path}/fro_profiterole-ud-train.conllu", Path("datasets") / dataset_name)
Expand Down Expand Up @@ -1664,3 +1669,27 @@ def __init__(
cached_path(f"{web_path}/lt_alksnis-ud-train.conllu", Path("datasets") / dataset_name)

super().__init__(data_folder, in_memory=in_memory, split_multiwords=split_multiwords)


class UD_BAVARIAN_MAIBAAM(UniversalDependenciesCorpus):
def __init__(
self,
base_path: Optional[Union[str, Path]] = None,
in_memory: bool = True,
split_multiwords: bool = True,
revision: str = "dev",
) -> None:
base_path = Path(flair.cache_root) / "datasets" if not base_path else Path(base_path)

# this dataset name
dataset_name = self.__class__.__name__.lower()

# default dataset folder is the cache root

data_folder = base_path / dataset_name

# download data if necessary
web_path = f"https://raw.githubusercontent.com/UniversalDependencies/UD_Bavarian-MaiBaam/{revision}"
cached_path(f"{web_path}/bar_maibaam-ud-test.conllu", Path("datasets") / dataset_name)

super().__init__(data_folder, in_memory=in_memory, split_multiwords=split_multiwords)

0 comments on commit 64955ac

Please sign in to comment.