Skip to content

Commit

Permalink
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions sentence_transformers/model_card_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def model_card_get_pooling_function(pooling_mode):
# Max Pooling - Take the max value over time for every dimension.
def max_pooling(model_output, attention_mask):
token_embeddings = model_output[0] #First element of model_output contains all token embeddings
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype)
token_embeddings[input_mask_expanded == 0] = -1e9 # Set padding tokens to large negative value
return torch.max(token_embeddings, 1)[0]
""",
Expand All @@ -142,7 +142,7 @@ def max_pooling(model_output, attention_mask):
#Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
token_embeddings = model_output[0] #First element of model_output contains all token embeddings
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype)
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
""",
)
Expand Down

0 comments on commit 07fac06

Please sign in to comment.