Skip to content

Commit

Permalink
Make ruff happy
Browse files: browse the repository at this point in the history
  • Loading branch information
alanakbik committed Dec 17, 2023
1 parent 6382c47 commit d0441d7
Showing 1 changed file with 19 additions and 16 deletions.
35 changes: 19 additions & 16 deletions flair/datasets/document_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,8 +908,9 @@ def __init__(


class AGNEWS(ClassificationCorpus):
"""The AG's News Topic Classification Corpus, classifying news into 4 coarse-grained topics
(World, Sports, Business, Sci/Tech).
"""The AG's News Topic Classification Corpus, classifying news into 4 coarse-grained topics.
Labels: World, Sports, Business, Sci/Tech.
"""

def __init__(
Expand All @@ -920,6 +921,7 @@ def __init__(
**corpusargs,
):
"""Instantiates AGNews Classification Corpus with 4 classes.
:param base_path: Provide this only if you store the AGNEWS corpus in a specific folder, otherwise use default.
:param tokenizer: Custom tokenizer to use (default is SpaceTokenizer)
:param memory_mode: Set to 'partial' by default. Can also be 'full' or 'none'.
Expand Down Expand Up @@ -953,22 +955,23 @@ def __init__(
original_filenames = original_filenames[:-1]
if not data_file.is_file():
for original_filename, new_filename in zip(original_filenames, new_filenames):
with open(data_folder / "original" / original_filename, encoding="utf-8") as open_fp:
with open(data_folder / new_filename, "w", encoding="utf-8") as write_fp:
csv_reader = csv.reader(
open_fp, quotechar='"', delimiter=",", quoting=csv.QUOTE_ALL, skipinitialspace=True
)
for id_, row in enumerate(csv_reader):
label, title, description = row
# Original labels are [1, 2, 3, 4] -> ['World', 'Sports', 'Business', 'Sci/Tech']
# Re-map to [0, 1, 2, 3].
label = int(label) - 1
text = " ".join((title, description))
with open(data_folder / "original" / original_filename, encoding="utf-8") as open_fp, open(
data_folder / new_filename, "w", encoding="utf-8"
) as write_fp:
csv_reader = csv.reader(
open_fp, quotechar='"', delimiter=",", quoting=csv.QUOTE_ALL, skipinitialspace=True
)
for id_, row in enumerate(csv_reader):
label, title, description = row
# Original labels are [1, 2, 3, 4] -> ['World', 'Sports', 'Business', 'Sci/Tech']
# Re-map to [0, 1, 2, 3].
label = int(label) - 1
text = " ".join((title, description))

new_label = "__label__"
new_label += label_dict[label]
new_label = "__label__"
new_label += label_dict[label]

write_fp.write(f"{new_label} {text}\n")
write_fp.write(f"{new_label} {text}\n")

super().__init__(data_folder, label_type="topic", tokenizer=tokenizer, memory_mode=memory_mode, **corpusargs)

Expand Down

0 comments on commit d0441d7

Please sign in to comment.