Entity Mention Linker #3388

Merged: 58 commits (Feb 8, 2024)
Commits
a01e060
Initial version (already adapted to recent Flair API changes)
Mar 14, 2023
22bc93a
Revise mention text pre-processing: define general interface and adap…
Mar 14, 2023
acf5fb6
Refactor entity linking model structure
Mar 15, 2023
56c89ba
Update documentation
Mar 22, 2023
51fe951
Introduce separate methods for pre-processing (1) entity mentions fro…
Mar 23, 2023
19f74fb
Fix formatting
alanakbik Apr 21, 2023
e5342d5
feat(test): biomedical entity linking
Apr 26, 2023
4ef5924
fix(test): hold on w/ automatic tests for now
Apr 26, 2023
abc42b5
fix(bionel): start major refactoring
Apr 26, 2023
bdc3e8a
fix(bionel): major refactor
Apr 27, 2023
48e8ae7
fix(bionel): assign entity type
May 2, 2023
e6b57eb
fix(biencoder): set sparse encoder and weight
May 2, 2023
0d3cec2
fix(bionel): address comments
May 11, 2023
99a109f
fix(candidate_generator): container for search result
May 12, 2023
301988e
fix(predict): default annotation layer iff not provided by user
May 19, 2023
c14d6ce
fix(label): scores can be >= or <=
May 19, 2023
c66789a
fix(candidate): parametrize database name
May 19, 2023
70c0c7d
feat(candidate_generator): cache sparse encoder
May 22, 2023
37a2458
fix(candidate_generator): minor improvements
May 23, 2023
bc801fc
feat(linking_candidate): pretty print
May 24, 2023
3677658
fix(candidate_generator): check sparse encoder for sparse search
May 24, 2023
414b5a8
feat(candidate_generator): add sparse index
Jun 1, 2023
1783eab
fix(candidate_generator): KISS: sparse search w/ scipy sparse matrices
Jun 2, 2023
8c908ba
Minor update to comments and documentation
Jul 12, 2023
c6273e6
Fix tests and type annotations
Jul 12, 2023
9a3fff6
code format and fix ruff & most mypy errors.
Aug 28, 2023
fd9507f
refine interface of BiomedicalEntityLinker
Sep 4, 2023
a374040
refactor knowledgebase datasets
Sep 4, 2023
2c12f14
refactor CandidateSearchIndex
Sep 11, 2023
e398213
fix rebase errors
Sep 18, 2023
51a1a57
WIP: add save functionality
Sep 18, 2023
ce85e3d
add load & save functionality
Oct 2, 2023
750ba7d
fix naming
helpmefindaname Dec 23, 2023
c5fa9c9
add hf model download
helpmefindaname Dec 24, 2023
94db5b3
fix ruff & mypy errors
helpmefindaname Dec 24, 2023
1d55fec
fix entity linking test
helpmefindaname Dec 24, 2023
a4a27dc
fixed selection of knowledge base identifiers for entity_mention_linking
Jan 2, 2024
6689297
fix(dictionary): corrections in file parsing
Jan 5, 2024
c2c53a9
fix: preprocessing, candidate generator, linker
Jan 11, 2024
1efe956
feat(tests): test preprocessing
Jan 11, 2024
5da7494
Minor fix in pre-processing pipeline
Jan 11, 2024
4cd86b0
Add support for label and entity type definition + fix tests
Jan 12, 2024
aaa5f39
fix: formatting and type checking
Jan 12, 2024
def9905
add batchsize to prediction instead of only embedding to reduce memor…
helpmefindaname Jan 21, 2024
03092c2
add evaluation function
helpmefindaname Jan 21, 2024
bf88d0a
fix(linker): extraction of entity mentions in predict
Jan 26, 2024
7b824a1
fix(predict): ensure mentions extraction works with legacy classifier
Jan 29, 2024
c8cc120
chore: update tests
Jan 29, 2024
94aaca1
fix(tests): normalized entity type name
Jan 29, 2024
75ea402
fix(logging): deprecated logger.warn
Jan 29, 2024
d9805be
improve typing and run black
Jan 12, 2024
8ab67d2
add datasets for nel-bioner evaluation
Feb 2, 2024
654ed8d
mark heavy test as integration test
Feb 2, 2024
efb1018
add metadata to labels for candidate names
Feb 2, 2024
44c1413
fix black ruff and mypy
helpmefindaname Feb 4, 2024
d7fa7dd
make test more memory efficient by only loading the smallest model
helpmefindaname Feb 6, 2024
ba833f0
chore(docs): add docstrings for datasets
Feb 7, 2024
44d73a2
fix(bigbio): better naming & fix ruff
Feb 8, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -108,3 +108,4 @@ venv.bak/

resources/taggers/
regression_train/
/doc_build/
19 changes: 19 additions & 0 deletions flair/class_utils.py
@@ -0,0 +1,19 @@
import inspect
from typing import Iterable, Optional, Type, TypeVar

T = TypeVar("T")


def get_non_abstract_subclasses(cls: Type[T]) -> Iterable[Type[T]]:
for subclass in cls.__subclasses__():
yield from get_non_abstract_subclasses(subclass)
if inspect.isabstract(subclass):
continue
yield subclass


def get_state_subclass_by_name(cls: Type[T], cls_name: Optional[str]) -> Type[T]:
for sub_cls in get_non_abstract_subclasses(cls):
if sub_cls.__name__ == cls_name:
return sub_cls
raise ValueError(f"Could not find any class with name '{cls_name}'")
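
These helpers are presumably used elsewhere in the PR to restore components from a saved state by class name. A self-contained illustration follows; ShapeBase and Circle are made up for the example, only the two imported functions come from this PR.

from abc import ABC, abstractmethod

from flair.class_utils import get_non_abstract_subclasses, get_state_subclass_by_name


class ShapeBase(ABC):  # hypothetical base class for the example
    @abstractmethod
    def area(self) -> float: ...


class Circle(ShapeBase):  # concrete subclass, so it is yielded
    def area(self) -> float:
        return 3.14


print(list(get_non_abstract_subclasses(ShapeBase)))    # [<class 'Circle'>]
print(get_state_subclass_by_name(ShapeBase, "Circle"))  # <class 'Circle'>
# get_state_subclass_by_name(ShapeBase, "Square") raises ValueError: no such subclass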
142 changes: 102 additions & 40 deletions flair/data.py
@@ -213,10 +213,11 @@ class Label:
Default value for the score is 1.0.
"""

def __init__(self, data_point: "DataPoint", value: str, score: float = 1.0) -> None:
def __init__(self, data_point: "DataPoint", value: str, score: float = 1.0, **metadata) -> None:
self._value = value
self._score = score
self.data_point: DataPoint = data_point
self.metadata = metadata
super().__init__()

def set_value(self, value: str, score: float = 1.0):
@@ -235,14 +236,14 @@ def to_dict(self):
return {"value": self.value, "confidence": self.score}

def __str__(self) -> str:
return f"{self.data_point.unlabeled_identifier}{flair._arrow}{self._value} ({round(self._score, 4)})"
return f"{self.data_point.unlabeled_identifier}{flair._arrow}{self._value}{self.metadata_str} ({round(self._score, 4)})"

@property
def shortstring(self):
return f'"{self.data_point.text}"/{self._value}'

def __repr__(self) -> str:
return f"'{self.data_point.unlabeled_identifier}'/'{self._value}' ({round(self._score, 4)})"
return f"'{self.data_point.unlabeled_identifier}'/'{self._value}'{self.metadata_str} ({round(self._score, 4)})"

def __eq__(self, other):
return self.value == other.value and self.score == other.score and self.data_point == other.data_point
@@ -253,6 +254,13 @@ def __hash__(self):
def __lt__(self, other):
return self.data_point < other.data_point

@property
def metadata_str(self) -> str:
if not self.metadata:
return ""
rep = "/".join(f"{k}={v}" for k, v in self.metadata.items())
return f"/{rep}"

@property
def labeled_identifier(self):
return f"{self.data_point.unlabeled_identifier}/{self.value}"
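
With this change a Label can carry arbitrary keyword metadata, which is appended to its string form via metadata_str. A small sketch of the new constructor; the concept identifier and database name are illustrative values.

from flair.data import Label, Sentence

sentence = Sentence("Fragile X Syndrome")
span = sentence[0:3]  # Span covering all three tokens

# Extra keyword arguments become Label.metadata and are rendered by metadata_str.
label = Label(span, value="MESH:D005600", score=0.87, database_name="ctd-diseases")
print(label.metadata)      # {'database_name': 'ctd-diseases'}
print(label.metadata_str)  # /database_name=ctd-diseases
print(label)               # roughly: Span[0:3]: "Fragile X Syndrome" → MESH:D005600/database_name=ctd-diseases (0.87)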
@@ -336,16 +344,18 @@ def get_metadata(self, key: str) -> typing.Any:
def has_metadata(self, key: str) -> bool:
return key in self._metadata

def add_label(self, typename: str, value: str, score: float = 1.0):
def add_label(self, typename: str, value: str, score: float = 1.0, **metadata):
label = Label(self, value, score, **metadata)

if typename not in self.annotation_layers:
self.annotation_layers[typename] = [Label(self, value, score)]
self.annotation_layers[typename] = [label]
else:
self.annotation_layers[typename].append(Label(self, value, score))
self.annotation_layers[typename].append(label)

return self

def set_label(self, typename: str, value: str, score: float = 1.0):
self.annotation_layers[typename] = [Label(self, value, score)]
def set_label(self, typename: str, value: str, score: float = 1.0, **metadata):
self.annotation_layers[typename] = [Label(self, value, score, **metadata)]
return self

def remove_labels(self, typename: str):
@@ -375,28 +385,25 @@ def labels(self) -> List[Label]:
def unlabeled_identifier(self):
raise NotImplementedError

def _printout_labels(self, main_label=None, add_score: bool = True):
def _printout_labels(self, main_label=None, add_score: bool = True, add_metadata: bool = True):
all_labels = []
keys = [main_label] if main_label is not None else self.annotation_layers.keys()
if add_score:
for key in keys:
all_labels.extend(
[
f"{label.value} ({round(label.score, 4)})"
for label in self.get_labels(key)
if label.data_point == self
]
)
labels = "; ".join(all_labels)
if labels != "":
labels = flair._arrow + labels
else:
for key in keys:
all_labels.extend([f"{label.value}" for label in self.get_labels(key) if label.data_point == self])
labels = "/".join(all_labels)
if labels != "":
labels = "/" + labels
return labels

sep = "; " if add_score else "/"
sent_sep = flair._arrow if add_score else "/"
for key in keys:
for label in self.get_labels(key):
if label.data_point is not self:
continue
value = label.value
if add_metadata:
value = f"{value}{label.metadata_str}"
if add_score:
value = f"{value} ({label.score:.04f})"
all_labels.append(value)
if not all_labels:
return ""
return sent_sep + sep.join(all_labels)

def __str__(self) -> str:
return self.unlabeled_identifier + self._printout_labels()
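
The metadata also flows through add_label and shows up when a data point is printed, since _printout_labels now appends metadata_str to each label value. A brief sketch; identifiers are illustrative and the exact printed format may differ slightly.

from flair.data import Sentence

sentence = Sentence("Aspirin inhibits cyclooxygenase")
span = sentence[0:1]

# Label the first token's span and attach provenance metadata to the label.
span.add_label("link", value="CHEBI:15365", score=0.93, database_name="chebi")

print(span)
# roughly: Span[0:1]: "Aspirin" → CHEBI:15365/database_name=chebi (0.9300)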
@@ -431,6 +438,61 @@ def __len__(self) -> int:
raise NotImplementedError


class EntityCandidate:
"""A Concept as part of a knowledgebase or ontology."""

def __init__(
self,
concept_id: str,
concept_name: str,
database_name: str,
additional_ids: Optional[List[str]] = None,
synonyms: Optional[List[str]] = None,
description: Optional[str] = None,
):
"""A Concept as part of a knowledgebase or ontology.

Args:
concept_id: Identifier of the concept from the knowledgebase / ontology
concept_name: (Canonical) name of the concept from the knowledgebase / ontology
additional_ids: List of additional identifiers for the concept / entity in the KB / ontology
database_name: Name of the knowledgebase / ontology
synonyms: A list of synonyms for this entry
description: A description about the Concept to describe
"""
self.concept_id = concept_id
self.concept_name = concept_name
self.database_name = database_name
self.description = description
if additional_ids is None:
self.additional_ids = []
else:
self.additional_ids = additional_ids
if synonyms is None:
self.synonyms = []
else:
self.synonyms = synonyms

def __str__(self) -> str:
string = f"EntityLinkingCandidate: {self.database_name}:{self.concept_id} - {self.concept_name}"
if self.additional_ids:
string += f" - {'|'.join(self.additional_ids)}"
return string

def __repr__(self) -> str:
return str(self)

def to_dict(self) -> Dict[str, typing.Any]:
return {
"concept_id": self.concept_id,
"concept_name": self.concept_name,
"database_name": self.database_name,
"additional_ids": self.additional_ids,
"synonyms": self.synonyms,
"description": self.description,
}


DT = typing.TypeVar("DT", bound=DataPoint)
DT2 = typing.TypeVar("DT2", bound=DataPoint)
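
EntityCandidate appears to be the container used for a single knowledgebase entry. A quick instantiation sketch with illustrative CTD/MeSH-style values:

from flair.data import EntityCandidate

candidate = EntityCandidate(
    concept_id="MESH:D005600",             # illustrative identifier
    concept_name="Fragile X Syndrome",
    database_name="CTD-DISEASES",
    additional_ids=["OMIM:300624"],        # illustrative cross-reference
    synonyms=["FXS", "Martin-Bell Syndrome"],
    description="Illustrative description text",
)

print(candidate)
# EntityLinkingCandidate: CTD-DISEASES:MESH:D005600 - Fragile X Syndrome - OMIM:300624
print(candidate.to_dict()["synonyms"])
# ['FXS', 'Martin-Bell Syndrome']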

@@ -440,18 +502,18 @@ def __init__(self, sentence) -> None:
super().__init__()
self.sentence: Sentence = sentence

def add_label(self, typename: str, value: str, score: float = 1.0):
super().add_label(typename, value, score)
self.sentence.annotation_layers.setdefault(typename, []).append(Label(self, value, score))
def add_label(self, typename: str, value: str, score: float = 1.0, **metadata):
super().add_label(typename, value, score, **metadata)
self.sentence.annotation_layers.setdefault(typename, []).append(Label(self, value, score, **metadata))

def set_label(self, typename: str, value: str, score: float = 1.0):
def set_label(self, typename: str, value: str, score: float = 1.0, **metadata):
if len(self.annotation_layers.get(typename, [])) > 0:
# First we remove any existing labels for this PartOfSentence in self.sentence
self.sentence.annotation_layers[typename] = [
label for label in self.sentence.annotation_layers.get(typename, []) if label.data_point != self
]
self.sentence.annotation_layers.setdefault(typename, []).append(Label(self, value, score))
super().set_label(typename, value, score)
self.sentence.annotation_layers.setdefault(typename, []).append(Label(self, value, score, **metadata))
super().set_label(typename, value, score, **metadata)
return self

def remove_labels(self, typename: str):
@@ -537,21 +599,21 @@ def __len__(self) -> int:
def __repr__(self) -> str:
return self.__str__()

def add_label(self, typename: str, value: str, score: float = 1.0):
def add_label(self, typename: str, value: str, score: float = 1.0, **metadata):
# The Token is a special _PartOfSentence in that it may be initialized without a Sentence.
# therefore, labels get added only to the Sentence if it exists
if self.sentence:
super().add_label(typename=typename, value=value, score=score)
super().add_label(typename=typename, value=value, score=score, **metadata)
else:
DataPoint.add_label(self, typename=typename, value=value, score=score)
DataPoint.add_label(self, typename=typename, value=value, score=score, **metadata)

def set_label(self, typename: str, value: str, score: float = 1.0):
def set_label(self, typename: str, value: str, score: float = 1.0, **metadata):
# The Token is a special _PartOfSentence in that it may be initialized without a Sentence.
# Therefore, labels get set only to the Sentence if it exists
if self.sentence:
super().set_label(typename=typename, value=value, score=score)
super().set_label(typename=typename, value=value, score=score, **metadata)
else:
DataPoint.set_label(self, typename=typename, value=value, score=score)
DataPoint.set_label(self, typename=typename, value=value, score=score, **metadata)

def to_dict(self, tag_type: Optional[str] = None):
return {
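
As before, labels added to a Token or Span are mirrored into the owning Sentence's annotation layers, now including the metadata, while a Token created without a Sentence keeps its labels locally. A short sketch:

from flair.data import Sentence, Token

# Standalone token: no sentence exists, so the label stays on the token itself.
token = Token("aspirin")
token.add_label("entity_type", "Chemical", 0.99, normalized_name="aspirin")
print(token.get_label("entity_type").metadata)  # {'normalized_name': 'aspirin'}

# Token inside a sentence: the same call also registers the label on the sentence.
sentence = Sentence("aspirin works")
sentence[0].add_label("entity_type", "Chemical", 0.99, normalized_name="aspirin")
print(len(sentence.get_labels("entity_type")))  # 1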
12 changes: 12 additions & 0 deletions flair/datasets/__init__.py
@@ -133,6 +133,10 @@
# word sense disambiguation
# Expose all entity linking datasets
from .entity_linking import (
CTD_CHEMICALS_DICTIONARY,
CTD_DISEASES_DICTIONARY,
NCBI_GENE_HUMAN_DICTIONARY,
NCBI_TAXONOMY_DICTIONARY,
NEL_ENGLISH_AIDA,
NEL_ENGLISH_AQUAINT,
NEL_ENGLISH_IITB,
@@ -147,6 +151,8 @@
WSD_UFSAC,
WSD_WORDNET_GLOSS_TAGGED,
ZELDA,
EntityLinkingDictionary,
HunerEntityLinkingDictionary,
)

# Expose all relation extraction datasets
@@ -315,6 +321,7 @@
"SentenceDataset",
"MongoDataset",
"StringDataset",
"EntityLinkingDictionary",
"AGNEWS",
"ANAT_EM",
"AZDZ",
@@ -342,6 +349,7 @@
"FSU",
"GELLUS",
"GPRO",
"HunerEntityLinkingDictionary",
"HUNER_CELL_LINE",
"HUNER_CELL_LINE_CELL_FINDER",
"HUNER_CELL_LINE_CLL",
@@ -390,6 +398,10 @@
"LINNEAUS",
"LOCTEXT",
"MIRNA",
"NCBI_GENE_HUMAN_DICTIONARY",
"NCBI_TAXONOMY_DICTIONARY",
"CTD_DISEASES_DICTIONARY",
"CTD_CHEMICALS_DICTIONARY",
"NCBI_DISEASE",
"ONTONOTES",
"OSIRIS",
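
The new dictionary datasets are exported at package level so they can be imported directly from flair.datasets. A hedged sketch; whether the constructor needs arguments and what it downloads on first use is assumed from the naming and from how other flair datasets behave, not from this diff.

from flair.datasets import CTD_DISEASES_DICTIONARY

# Assumption: instantiating the class fetches and parses the CTD diseases
# vocabulary, mirroring other flair dataset constructors that need no arguments.
dictionary = CTD_DISEASES_DICTIONARY()
print(dictionary)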