Add return type hints to util methods (#1754)
* Add type hints to util methods

* fixup! Add type hints to util methods

---------

Co-authored-by: Tom Aarsen <[email protected]>
zachschillaci27 and tomaarsen authored Dec 18, 2023
1 parent 1c396f8 commit aae5d40
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions sentence_transformers/util.py
@@ -22,15 +22,15 @@

logger = logging.getLogger(__name__)

-def pytorch_cos_sim(a: Tensor, b: Tensor):
+def pytorch_cos_sim(a: Tensor, b: Tensor) -> Tensor:
"""
Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j.
:return: Matrix with res[i][j] = cos_sim(a[i], b[j])
"""
return cos_sim(a, b)

-def cos_sim(a: Tensor, b: Tensor):
+def cos_sim(a: Tensor, b: Tensor) -> Tensor:
"""
Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j.
@@ -53,7 +53,7 @@ def cos_sim(a: Tensor, b: Tensor):
return torch.mm(a_norm, b_norm.transpose(0, 1))
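
A minimal usage sketch of cos_sim, with illustrative tensor shapes, showing the Tensor return that the new hint describes:

import torch
from sentence_transformers.util import cos_sim

a = torch.randn(4, 384)  # 4 embeddings of dimension 384 (shapes chosen for illustration)
b = torch.randn(6, 384)  # 6 embeddings of the same dimension
scores = cos_sim(a, b)   # Tensor of shape [4, 6]; scores[i][j] = cos_sim(a[i], b[j])
print(scores.shape)      # torch.Size([4, 6])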


-def dot_score(a: Tensor, b: Tensor):
+def dot_score(a: Tensor, b: Tensor) -> Tensor:
"""
Computes the dot-product dot_prod(a[i], b[j]) for all i and j.
@@ -74,7 +74,7 @@ def dot_score(a: Tensor, b: Tensor):
return torch.mm(a, b.transpose(0, 1))


-def pairwise_dot_score(a: Tensor, b: Tensor):
+def pairwise_dot_score(a: Tensor, b: Tensor) -> Tensor:
"""
Computes the pairwise dot-product dot_prod(a[i], b[i])
@@ -89,7 +89,7 @@ def pairwise_dot_score(a: Tensor, b: Tensor):
return (a * b).sum(dim=-1)


-def pairwise_cos_sim(a: Tensor, b: Tensor):
+def pairwise_cos_sim(a: Tensor, b: Tensor) -> Tensor:
"""
Computes the pairwise cossim cos_sim(a[i], b[i])
@@ -104,7 +104,7 @@ def pairwise_cos_sim(a: Tensor, b: Tensor):
return pairwise_dot_score(normalize_embeddings(a), normalize_embeddings(b))
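
By contrast, the pairwise variants score aligned rows only rather than all combinations; a sketch with assumed shapes:

import torch
from sentence_transformers.util import pairwise_cos_sim

a = torch.randn(5, 384)          # illustrative: two batches of 5 embeddings each
b = torch.randn(5, 384)
scores = pairwise_cos_sim(a, b)  # Tensor of shape [5]; scores[i] = cos_sim(a[i], b[i])
print(scores.shape)              # torch.Size([5])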


-def normalize_embeddings(embeddings: Tensor):
+def normalize_embeddings(embeddings: Tensor) -> Tensor:
"""
Normalizes the embeddings matrix, so that each sentence embedding has unit length
"""
@@ -116,7 +116,7 @@ def paraphrase_mining(model,
show_progress_bar: bool = False,
batch_size:int = 32,
*args,
-**kwargs):
+**kwargs) -> List[List[Union[float, int]]]:
"""
Given a list of sentences / texts, this function performs paraphrase mining. It compares all sentences against all
other sentences and returns a list with the pairs that have the highest cosine similarity score.
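
A sketch of the structure behind the new List[List[Union[float, int]]] annotation; the model name, sentences, and score value are placeholders:

from sentence_transformers import SentenceTransformer
from sentence_transformers.util import paraphrase_mining

model = SentenceTransformer("all-MiniLM-L6-v2")  # placeholder model
sentences = ["The cat sits outside", "A cat is sitting outdoors", "The weather is nice"]
pairs = paraphrase_mining(model, sentences)
# each entry is [score, id_1, id_2], e.g. [0.92, 0, 1] (score illustrative)
for score, i, j in pairs:
    print(round(score, 2), sentences[i], "<->", sentences[j])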
@@ -144,7 +144,7 @@ def paraphrase_mining_embeddings(embeddings: Tensor,
corpus_chunk_size: int = 100000,
max_pairs: int = 500000,
top_k: int = 100,
-score_function: Callable[[Tensor, Tensor], Tensor] = cos_sim):
+score_function: Callable[[Tensor, Tensor], Tensor] = cos_sim) -> List[List[Union[float, int]]]:
"""
Given a list of sentences / texts, this function performs paraphrase mining. It compares all sentences against all
other sentences and returns a list with the pairs that have the highest cosine similarity score.
@@ -202,7 +202,7 @@ def paraphrase_mining_embeddings(embeddings: Tensor,
return pairs_list


-def information_retrieval(*args, **kwargs):
+def information_retrieval(*args, **kwargs) -> List[List[Dict[str, Union[int, float]]]]:
"""This function is deprecated. Use semantic_search instead"""
return semantic_search(*args, **kwargs)

@@ -212,7 +212,7 @@ def semantic_search(query_embeddings: Tensor,
query_chunk_size: int = 100,
corpus_chunk_size: int = 500000,
top_k: int = 10,
-score_function: Callable[[Tensor, Tensor], Tensor] = cos_sim):
+score_function: Callable[[Tensor, Tensor], Tensor] = cos_sim) -> List[List[Dict[str, Union[int, float]]]]:
"""
This function performs a cosine similarity search between a list of query embeddings and a list of corpus embeddings.
It can be used for Information Retrieval / Semantic Search for corpora up to about 1 Million entries.
@@ -276,7 +276,7 @@ def semantic_search(query_embeddings: Tensor,
return queries_result_list
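
A sketch of the List[List[Dict[str, Union[int, float]]]] structure behind the new annotation; random embeddings stand in for real ones:

import torch
from sentence_transformers.util import semantic_search

query_embeddings = torch.randn(2, 384)     # 2 queries (illustrative)
corpus_embeddings = torch.randn(100, 384)  # 100 corpus entries (illustrative)
hits = semantic_search(query_embeddings, corpus_embeddings, top_k=3)
# hits[q] is a list of dicts for query q, e.g. {'corpus_id': 17, 'score': 0.31}
for q, query_hits in enumerate(hits):
    for hit in query_hits:
        print(q, hit["corpus_id"], hit["score"])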


-def http_get(url, path):
+def http_get(url, path) -> None:
"""
Downloads a URL to a given path on disc
"""
@@ -314,7 +314,7 @@ def batch_to_device(batch, target_device: device):



-def fullname(o):
+def fullname(o) -> str:
"""
Gives a full name (package_name.class_name) for a class / object in Python. Will
be used to load the correct classes from JSON files
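
Illustratively, the str this returns is the dotted package_name.class_name described in the docstring; the example object is an assumption:

import torch.nn as nn
from sentence_transformers.util import fullname

layer = nn.Linear(10, 10)
print(fullname(layer))  # e.g. 'torch.nn.modules.linear.Linear'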
@@ -349,7 +349,7 @@ def import_from_string(dotted_path):
raise ImportError(msg)


-def community_detection(embeddings, threshold=0.75, min_community_size=10, batch_size=1024, show_progress_bar=False):
+def community_detection(embeddings, threshold=0.75, min_community_size=10, batch_size=1024, show_progress_bar=False) -> List[List[int]]:
"""
Function for Fast Community Detection
Finds in the embeddings all communities, i.e. embeddings that are close (closer than threshold).
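
A sketch of the List[List[int]] structure this returns; the random embeddings are placeholders and may well yield no communities, unlike real sentence embeddings:

import torch
from sentence_transformers.util import community_detection

embeddings = torch.randn(1000, 384)  # placeholder embeddings
communities = community_detection(embeddings, threshold=0.75, min_community_size=10)
# each inner list holds the indices of one community, largest communities first
for community in communities:
    print(len(community), community[:5])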
