remove spacy dependency:
- refactor the dataset transformations for NER
- adjust the corresponding tests
- remove spacy from requirements.txt
- remove the now-unnecessary function replace_token_labels; the same result can be achieved with the dataset transformation function
- adjust the tutorials
whoisjones committed Oct 27, 2023
1 parent eae0669 commit 5c300a2
Showing 5 changed files with 167 additions and 117 deletions.
1 change: 0 additions & 1 deletion requirements.txt
@@ -1,4 +1,3 @@
 datasets
 farm-haystack>=1.18.0
-spacy
 loguru
3 changes: 1 addition & 2 deletions src/fabricator/dataset_transformations/__init__.py
@@ -7,9 +7,8 @@
     "replace_class_labels",
     "convert_token_labels_to_spans",
     "convert_spans_to_token_labels",
-    "replace_token_labels",
 ]

 from .question_answering import preprocess_squad_format, postprocess_squad_format, calculate_answer_start
 from .text_classification import convert_label_ids_to_texts, get_labels_from_dataset, replace_class_labels
-from .token_classification import convert_token_labels_to_spans, convert_spans_to_token_labels, replace_token_labels
+from .token_classification import convert_token_labels_to_spans, convert_spans_to_token_labels
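For orientation, a short sketch (not part of the commit) of what the package exports after this change, assuming the src/fabricator layout installs as the fabricator package:

    # replace_token_labels is gone from the public API; only the two span converters remain
    from fabricator.dataset_transformations import (
        convert_token_labels_to_spans,
        convert_spans_to_token_labels,
    )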
222 changes: 133 additions & 89 deletions src/fabricator/dataset_transformations/token_classification.py
@@ -1,139 +1,183 @@
 import re
-from typing import Dict, List, Tuple
-from collections import defaultdict
+from typing import Dict, List, Tuple, Union
 from datasets import Dataset, Sequence

-from tqdm import tqdm
-import spacy
-from spacy.vocab import Vocab
-from spacy.tokens import Doc
-from spacy.training import iob_to_biluo, biluo_tags_to_offsets, offsets_to_biluo_tags, biluo_to_iob
+from loguru import logger

 # These are fixed for encoding the prompt and decoding the output of the LLM
-LABEL_SEPARATOR = "\n"
-LABEL2ENTITY_SEPARATOR = "->"
-ENTITY_SEPARATOR = ", "
+SPAN_ANNOTATION_TEMPLATE = "{entity} is {label} entity."
+SPAN_ANNOTATION_REGEX = r'(.+) is (.+) entity\.'
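As a minimal round-trip sketch (illustrative only, not part of the diff), these two constants encode an entity for the prompt and decode it back from the LLM output:

    import re

    SPAN_ANNOTATION_TEMPLATE = "{entity} is {label} entity."
    SPAN_ANNOTATION_REGEX = r'(.+) is (.+) entity\.'

    encoded = SPAN_ANNOTATION_TEMPLATE.format(entity="Peter Blackburn", label="PER")
    # encoded == "Peter Blackburn is PER entity."

    match = re.match(SPAN_ANNOTATION_REGEX, encoded)
    assert match.group(1) == "Peter Blackburn"  # entity surface form
    assert match.group(2) == "PER"              # natural-language label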


 def convert_token_labels_to_spans(
-    dataset: Dataset, token_column: str, label_column: str, expanded_label_mapping: Dict = None
-) -> Tuple[Dataset, List[str]]:
+    dataset: Dataset,
+    token_column: str,
+    label_column: str,
+    expanded_label_mapping: Dict = None,
+    return_label_options: bool = False
+) -> Union[Dataset, Tuple[Dataset, List[str]]]:
     """Converts token level labels to spans. Useful for NER tasks to prompt the LLM with natural language labels.

     Args:
         dataset (Dataset): huggingface Dataset with token level labels
         token_column (str): name of the column with the tokens
         label_column (str): name of the column with the token level labels
         expanded_label_mapping (Dict): mapping from label ids to label names. Defaults to None.
+        return_label_options (bool): whether to return a list of all possible annotations of the provided dataset

     Returns:
         Tuple[Dataset, List[str]]: huggingface Dataset with span labels and list of possible labels for the prompt
     """
     if expanded_label_mapping:
+        if not len(expanded_label_mapping) == len(dataset.features[label_column].feature.names):
+            raise ValueError(
+                f"Length of expanded label mapping and original number of labels in dataset do not match.\n"
+                f"Original labels: {dataset.features[label_column].feature.names}\n"
+                f"Expanded labels: {list(expanded_label_mapping.values())}"
+            )
         id2label = expanded_label_mapping
     elif isinstance(dataset.features[label_column], Sequence):
         id2label = dict(enumerate(dataset.features[label_column].feature.names))
     else:
         raise ValueError("Labels must be a Sequence feature or expanded_label_mapping must be provided.")

-    new_label_column = f"{label_column}_natural_language"
-    label_options = list({label.replace("B-", "").replace("I-", "") for label in id2label.values()})
-    if "O" in label_options:
-        label_options.remove("O")
+    span_column = "span_annotations"

-    def labels_to_spans(examples):
-        bio_tags = [id2label[label] for label in examples[label_column]]
-        bilou_tags = iob_to_biluo(bio_tags)
-        doc = Doc(Vocab(), words=examples[token_column])
-        offsets = biluo_tags_to_offsets(doc, bilou_tags)
-
-        span_labels = defaultdict(list)
-        for start, end, label in offsets:
-            span_labels[label].append(doc.text[start:end])
-
-        examples[token_column] = doc.text
-        span_labels = {k: ENTITY_SEPARATOR.join(v) for k, v in span_labels.items()}
-        examples[new_label_column] = LABEL_SEPARATOR.join(
-            [f"{k} {LABEL2ENTITY_SEPARATOR} {v}" for k, v in span_labels.items()]
-        )
-        return examples
+    def labels_to_spans(example):
+        span_annotations = [id2label.get(label).replace("B-", "").replace("I-", "") for label in example[label_column]]
+
+        annotations_for_prompt = ""
+
+        current_entity = None
+        current_entity_type = None
+        for idx, span_annotation in enumerate(span_annotations):
+            if span_annotation == "O":
+                if current_entity is not None:
+                    annotations_for_prompt += SPAN_ANNOTATION_TEMPLATE.format(entity=current_entity,
+                                                                              label=current_entity_type) + "\n"
+                    current_entity = None
+                    current_entity_type = None
+                continue
+            if current_entity is None:
+                current_entity = example[token_column][idx]
+                current_entity_type = span_annotation
+                continue
+            if current_entity_type == span_annotation:
+                current_entity += " " + example[token_column][idx]
+            else:
+                annotations_for_prompt += SPAN_ANNOTATION_TEMPLATE.format(entity=current_entity,
+                                                                          label=current_entity_type) + "\n"
+                current_entity = example[token_column][idx]
+                current_entity_type = span_annotation
+
+        # flush the entity still open at the end of the sentence
+        if current_entity is not None:
+            annotations_for_prompt += SPAN_ANNOTATION_TEMPLATE.format(entity=current_entity,
+                                                                      label=current_entity_type) + "\n"
+
+        example[token_column] = " ".join(example[token_column])
+        example[span_column] = annotations_for_prompt.rstrip("\n")
+        return example
+
+    dataset = dataset.map(labels_to_spans).remove_columns(label_column).rename_column(span_column, label_column)
+
+    if return_label_options:
+        # spans carry the BIO information implicitly, so the B-/I- prefixes can be dropped here
+        label_options = list({label.replace("B-", "").replace("I-", "") for label in id2label.values()})
+
+        # ignore "outside" tokens
+        if "O" in label_options:
+            label_options.remove("O")
+
-    dataset = dataset.map(labels_to_spans).remove_columns(label_column).rename_column(new_label_column, label_column)
-    return dataset, label_options
+        return dataset, label_options
+
+    return dataset
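End to end, the new encoding direction behaves roughly as follows; a hedged sketch whose import path and expected output are assumptions based on the updated tests below:

    from datasets import load_dataset
    from fabricator.dataset_transformations import convert_token_labels_to_spans

    dataset = load_dataset("conll2003", split="train").select(range(10))
    dataset, label_options = convert_token_labels_to_spans(
        dataset, token_column="tokens", label_column="ner_tags", return_label_options=True
    )
    print(dataset[0]["ner_tags"])
    # EU is ORG entity.
    # German is MISC entity.
    # British is MISC entity.
    print(sorted(label_options))  # ['LOC', 'MISC', 'ORG', 'PER']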


-def convert_spans_to_token_labels(dataset, token_column, label_column, id2label: Dict) -> Dataset:
-    """Converts span level labels to token level labels. This is useful for NER tasks to decode the output of the LLM.
+def convert_spans_to_token_labels(
+    dataset: Dataset,
+    token_column: str,
+    label_column: str,
+    id2label: Dict,
+    annotate_identical_words: bool = False
+) -> Dataset:
+    """Converts span level labels to token level labels.
+    First, the function extracts all entities with their annotated types.
+    Second, if annotations are present, the function converts them into a tag sequence in BIO format.
+    If none are present, it simply returns a tag sequence of O-tokens.
+    This is useful for NER tasks to decode the output of the LLM.

     Args:
         dataset (Dataset): huggingface Dataset with span level labels
         token_column (str): name of the column with the tokens
         label_column (str): name of the column with the span level labels
         id2label (Dict): mapping from label ids to label names
+        annotate_identical_words (bool): whether to annotate all identical words in a sentence with a found
+            entity type

     Returns:
         Dataset: huggingface Dataset with token level labels in BIO format
     """
-    new_label_column = f"{label_column}_tags"
-    label2id = {v: k for k, v in id2label.items()}
-    labels_no_bio = set([label.replace("B-", "").replace("I-", "") for label in id2label.values()])
-    nlp = spacy.blank("en")
-
-    def labels_to_spans(examples):
-        texts = examples[token_column]
-        str_labels = examples[label_column]
-        # goal: a list of lists of tuples (start, end, label)
-
-        tokens = []
-        bio_tags = []
-        for text, str_label in tqdm(zip(texts, str_labels), desc="Converting spans to token labels"):
-            spans = []
-
-            if not str_label:
-                bio_tags.append([])
-                tokens.append([])
-                continue
-
-            try:
-                for label_and_entities in str_label.split(LABEL_SEPARATOR):
-                    label, entities = label_and_entities.split(LABEL2ENTITY_SEPARATOR)
-                    label = label.strip()
-                    if label not in labels_no_bio:
-                        continue
-                    entities = [entity.strip().lower() for entity in entities.split(ENTITY_SEPARATOR)]
-                    for entity in set(entities):
-                        pattern = re.compile(r'\b' + re.escape(entity) + r'\b')
-                        matches = pattern.finditer(text.lower())
-                        for start, end in [(match.start(), match.end()) for match in matches]:
-                            spans.append((start, end, label))
-            except ValueError:
-                bio_tags.append([])
-                tokens.append([])
-                continue
-
-            doc = nlp(text)
-
-            try:
-                tags = [tag if tag != "-" else "O" for tag in biluo_to_iob(offsets_to_biluo_tags(doc, spans))]
-                words = [word.text for word in doc]
-                if not len(tags) == len(words) or len(tags) == 0 or len(words) == 0:
-                    tags = []
-                    words = []
-                bio_tags.append(tags)
-                tokens.append(words)
-            except ValueError:
-                bio_tags.append([])
-                tokens.append([])
-                continue
+    new_label_column = "sequence_tags"
+    lower_label2id = {label.lower(): idx for idx, label in id2label.items()}

+    def labels_to_spans(example):
+        span_annotations = example[label_column].split("\n")
+
+        ner_tag_tuples = []
+
+        for span_annotation in span_annotations:
+            matches = re.match(SPAN_ANNOTATION_REGEX, span_annotation)
+            if matches:
+                matched_entity = matches.group(1)
+                matched_label = matches.group(2)
+
+                span_tokens = matched_entity.split(" ")
+                # the first token of a span gets a B- tag, all following tokens an I- tag
+                span_labels = ["B-" + matched_label if idx == 0 else "I-" + matched_label
+                               for idx, _ in enumerate(span_tokens)]
+
+                for token, label in zip(span_tokens, span_labels):
+                    label_id = lower_label2id.get(label.lower())
+                    if label_id is None:
+                        logger.info(f"Entity {token} with label {label} is not in id2label: {id2label}.")
+                    else:
+                        ner_tag_tuples.append((token, label_id))
+
+        if ner_tag_tuples:
+            lower_tokens = example[token_column].lower().split(" ")
+            # initialize all tokens with the O type
+            ner_tags = [0] * len(lower_tokens)
+            for reference_token, entity_type_id in ner_tag_tuples:
+                if lower_tokens.count(reference_token.lower()) == 0:
+                    logger.info(
+                        f"Entity {reference_token} is not found in tokens: {lower_tokens}. "
+                        f"Thus, setting label to O."
+                    )
+                elif lower_tokens.count(reference_token.lower()) > 1:
+                    if annotate_identical_words:
+                        insert_at_idxs = [index for index, value in enumerate(lower_tokens)
+                                          if value == reference_token.lower()]
+                        for insert_at_idx in insert_at_idxs:
+                            ner_tags[insert_at_idx] = entity_type_id
+                    else:
+                        logger.info(
+                            f"Entity {reference_token} occurs more than once: {lower_tokens}. "
+                            f"Thus, setting label to O."
+                        )
+                else:
+                    insert_at_idx = lower_tokens.index(reference_token.lower())
+                    ner_tags[insert_at_idx] = entity_type_id
+        else:
+            ner_tags = [0] * len(example[token_column].split(" "))
+
-        examples[token_column] = tokens
-        examples[new_label_column] = [[label2id[tag] for tag in tags] for tags in bio_tags]
+        example[token_column] = example[token_column].split(" ")
+        example[new_label_column] = ner_tags

-        return examples
+        return example

     dataset = (
-        dataset.map(labels_to_spans, batched=True)
+        dataset.map(labels_to_spans)
         .remove_columns(label_column)
         .rename_column(new_label_column, label_column)
     )
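And the decoding direction, as a hedged sketch (toy data; the import path is an assumption, identifiers come from this diff): a span-annotation string is parsed back into token-level label ids, with unmatched or ambiguous tokens falling back to O unless annotate_identical_words is set:

    from datasets import Dataset
    from fabricator.dataset_transformations import convert_spans_to_token_labels

    id2label = {0: "O", 1: "B-ORG", 2: "I-ORG", 3: "B-MISC", 4: "I-MISC"}
    dataset = Dataset.from_dict({
        "tokens": ["EU rejects German call"],  # one whitespace-joined sentence per row
        "ner_tags": ["EU is ORG entity.\nGerman is MISC entity."],
    })

    dataset = convert_spans_to_token_labels(
        dataset=dataset,
        token_column="tokens",
        label_column="ner_tags",
        id2label=id2label,
    )
    print(dataset[0]["tokens"])    # ['EU', 'rejects', 'German', 'call']
    print(dataset[0]["ner_tags"])  # [1, 0, 3, 0]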
43 changes: 24 additions & 19 deletions tests/test_dataset_transformations.py
@@ -83,29 +83,27 @@ class TestTransformationsTokenClassification(unittest.TestCase):
     """Testcase for TokenLabelTransformations"""

     def setUp(self) -> None:
-        self.dataset = load_dataset("conll2003", split="train")
+        self.dataset = load_dataset("conll2003", split="train").select(range(150))

     def test_bio_tokens_to_spans(self):
         """Test transformation output only (BIO to spans)"""
         dataset, label_options = convert_token_labels_to_spans(
-            self.dataset, "tokens", "ner_tags"
+            self.dataset, "tokens", "ner_tags", return_label_options=True
         )
         self.assertEqual(len(label_options), 4)
         self.assertEqual(type(dataset[0]["ner_tags"]), str)
         self.assertNotEqual(type(dataset[0]["ner_tags"]), int)
-        labels = [
-            spans.split(LABEL2ENTITY_SEPARATOR, 1)[0].strip()
-            for spans in dataset[0]["ner_tags"].split(LABEL_SEPARATOR)
-        ]
-        for label in labels:
-            self.assertIn(label, label_options)
+        spans = [span for span in dataset[0]["ner_tags"].split("\n")]
+        for span in spans:
+            self.assertTrue(any([label in span for label in label_options]))

     def test_formatting_with_span_labels(self):
         """Test formatting with span labels"""
         dataset, label_options = convert_token_labels_to_spans(
             dataset=self.dataset,
             token_column="tokens",
             label_column="ner_tags",
+            return_label_options=True
         )
         fewshot_examples = dataset.select([1, 2, 3])
         prompt = BasePrompt(
@@ -115,25 +113,31 @@ def test_formatting_with_span_labels(self):
             label_options=label_options,
         )
         raw_prompt = prompt.get_prompt_text(label_options, fewshot_examples)
-        self.assertIn("PER -> Peter Blackburn", raw_prompt)
-        self.assertIn("LOC -> BRUSSELS", raw_prompt)
+        self.assertIn("Peter Blackburn is PER entity.", raw_prompt)
+        self.assertIn("BRUSSELS is LOC entity.", raw_prompt)
         for label in label_options:
             self.assertIn(label, raw_prompt)

     def test_expanded_textual_labels(self):
         """Test formatting with expanded textual labels"""
-        extended_mapping = {"PER": "person", "LOC": "location", "ORG": "organization", "MISC": "misceallaneous"}
-        id2label = replace_token_labels(dict(enumerate(self.dataset.features["ner_tags"].feature.names)), extended_mapping)
-        self.assertIn("B-location", id2label.values())
-        self.assertIn("I-person", id2label.values())
-        self.assertNotIn("B-LOC", id2label.values())
-        self.assertNotIn("I-MISC", id2label.values())
+        expanded_label_mapping = {
+            0: "O",
+            1: "B-person",
+            2: "I-person",
+            3: "B-location",
+            4: "I-location",
+            5: "B-organization",
+            6: "I-organization",
+            7: "B-miscellaneous",
+            8: "I-miscellaneous",
+        }

         dataset, label_options = convert_token_labels_to_spans(
             dataset=self.dataset,
             token_column="tokens",
             label_column="ner_tags",
-            expanded_label_mapping=id2label
+            expanded_label_mapping=expanded_label_mapping,
+            return_label_options=True
         )
         fewshot_examples = dataset.select([1, 2, 3])
         prompt = BasePrompt(
@@ -143,7 +147,7 @@ def test_expanded_textual_labels(self):
             label_options=label_options,
         )
         raw_prompt = prompt.get_prompt_text(label_options, fewshot_examples)
-        self.assertIn("person -> Peter Blackburn", raw_prompt)
+        self.assertIn("Peter Blackburn is person entity.", raw_prompt)
         self.assertNotIn("PER", raw_prompt)
         for label in label_options:
             self.assertIn(label, raw_prompt)
@@ -154,9 +158,10 @@ def test_textual_labels_to_label_ids(self):
             dataset=self.dataset,
             token_column="tokens",
             label_column="ner_tags",
+            return_label_options=True
         )
         id2label = dict(enumerate(self.dataset.features["ner_tags"].feature.names))
-        self.assertEqual(dataset[0]["ner_tags"], "ORG -> EU\nMISC -> German, British")
+        self.assertEqual(dataset[0]["ner_tags"], "EU is ORG entity.\nGerman is MISC entity.\nBritish is MISC entity.")
         dataset = dataset.select(range(10))
         dataset = convert_spans_to_token_labels(
             dataset=dataset,
