Skip to content

Commit

Permalink
fix black ruff and mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
helpmefindaname committed Feb 4, 2024
1 parent efb1018 commit 44c1413
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 18 deletions.
12 changes: 6 additions & 6 deletions flair/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,21 +599,21 @@ def __len__(self) -> int:
def __repr__(self) -> str:
return self.__str__()

def add_label(self, typename: str, value: str, score: float = 1.0):
def add_label(self, typename: str, value: str, score: float = 1.0, **metadata):
# The Token is a special _PartOfSentence in that it may be initialized without a Sentence.
# therefore, labels get added only to the Sentence if it exists
if self.sentence:
super().add_label(typename=typename, value=value, score=score)
super().add_label(typename=typename, value=value, score=score, **metadata)
else:
DataPoint.add_label(self, typename=typename, value=value, score=score)
DataPoint.add_label(self, typename=typename, value=value, score=score, **metadata)

def set_label(self, typename: str, value: str, score: float = 1.0):
def set_label(self, typename: str, value: str, score: float = 1.0, **metadata):
# The Token is a special _PartOfSentence in that it may be initialized without a Sentence.
# Therefore, labels get set only to the Sentence if it exists
if self.sentence:
super().set_label(typename=typename, value=value, score=score)
super().set_label(typename=typename, value=value, score=score, **metadata)
else:
DataPoint.set_label(self, typename=typename, value=value, score=score)
DataPoint.set_label(self, typename=typename, value=value, score=score, **metadata)

def to_dict(self, tag_type: Optional[str] = None):
return {
Expand Down
26 changes: 14 additions & 12 deletions flair/models/entity_mention_linking.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,9 +443,11 @@ def _get_state(self) -> Dict[str, Any]:
@classmethod
def _from_state(cls, state_dict: Dict[str, Any]) -> "EntityPreprocessor":
return cls(
preprocessor=None
if state_dict["preprocessor"] is None
else EntityPreprocessor._from_state(state_dict["preprocessor"]),
preprocessor=(
None
if state_dict["preprocessor"] is None
else EntityPreprocessor._from_state(state_dict["preprocessor"])
),
)


Expand Down Expand Up @@ -601,11 +603,7 @@ def bi_encoder(
ngram_range=(1, 2),
)

sparse_weight = (
sparse_weight
if model_name_or_path not in HYBRID_MODELS_SPARSE_WEIGHT
else HYBRID_MODELS_SPARSE_WEIGHT[model_name_or_path]
)
sparse_weight = HYBRID_MODELS_SPARSE_WEIGHT.get(model_name_or_path, sparse_weight)

return cls(
embeddings,
Expand Down Expand Up @@ -903,9 +901,11 @@ def predict(
for entity in entities_mentions:
data_points.append(entity.data_point)
mentions.append(
self.preprocessor.process_mention(entity.data_point.text, sentence)
if self.preprocessor is not None
else entity.data_point.text,
(
self.preprocessor.process_mention(entity.data_point.text, sentence)
if self.preprocessor is not None
else entity.data_point.text
),
)

# Retrieve top-k concept / entity candidates
Expand All @@ -915,7 +915,9 @@ def predict(
# Add a label annotation for each candidate
for data_point, mention_candidates in zip(data_points[i : i + batch_size], candidates):
for candidate_id, confidence in mention_candidates:
data_point.add_label(pred_label_type, candidate_id, confidence, name=self.dictionary[candidate_id].concept_name)
data_point.add_label(
pred_label_type, candidate_id, confidence, name=self.dictionary[candidate_id].concept_name
)

@staticmethod
def _fetch_model(model_name: str) -> str:
Expand Down

0 comments on commit 44c1413

Please sign in to comment.