Optimize speed and gpu usage for embed
torch.cat needs roughly double the memory, since it keeps the original per-batch tensors alive alongside the newly created concatenated tensor. Pre-allocating the output tensor and writing each batch into its slice avoids the extra copy.
harshitgupta412 authored and sidjha1 committed Oct 13, 2024
1 parent e2c48f0 commit 37a7056
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions lotus/models/e5_model.py
@@ -58,18 +58,22 @@ def embed(self, docs: List[str], **kwargs: Dict[str, Any]) -> np.ndarray:
         kwargs = {**self.kwargs, **kwargs}

         batch_size = kwargs.get("batch_size", self.batch_size)
-        embeddings = []
-        for i, batch_start in enumerate(tqdm(range(0, len(docs), batch_size))):
-            batch = docs[batch_start : batch_start + batch_size]
-
-            with torch.no_grad():
+
+        # Calculating the embedding dimension
+        total_docs = len(docs)
+        first_batch = self.tokenizer(docs[:1], return_tensors="pt", padding=True, truncation=True)
+        embed_dim = self.model(**first_batch).last_hidden_state.size(-1)
+
+        # Pre-allocate a tensor for all embeddings
+        embeddings = torch.empty((total_docs, embed_dim), device=self.device)
+        # Processing batches
+        with torch.inference_mode(): # Slightly faster than torch.no_grad() for inference
+            for i, batch_start in enumerate(tqdm(range(0, total_docs, batch_size))):
+                batch = docs[batch_start : batch_start + batch_size]
                 batch_dict = self.tokenizer(batch, padding=True, truncation=True, return_tensors="pt").to(self.device)

                 outputs = self.model(**batch_dict)
                 batch_embeddings = self.average_pool(outputs.last_hidden_state, batch_dict["attention_mask"])
-                embeddings.append(batch_embeddings)
-
-        embeddings = torch.cat(embeddings, dim=0)
-
+                embeddings[batch_start : batch_start + batch_size] = batch_embeddings
         if kwargs["normalize"]:
             embeddings = F.normalize(embeddings, p=2, dim=1)

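The memory argument from the commit message can be illustrated with a small standalone sketch. Everything in it is hypothetical and for illustration only: the sizes, the CPU tensors, and the fake_batch helper standing in for the model forward pass and pooling. Accumulating per-batch results and calling torch.cat keeps the list of batch tensors and the concatenated copy alive at the same time, while writing into a pre-allocated tensor only ever holds the output plus one in-flight batch.

import torch

# Hypothetical sizes for illustration only.
total_docs, embed_dim, batch_size = 8192, 1024, 256

def fake_batch(n: int) -> torch.Tensor:
    # Stand-in for the model forward pass + pooling that yields one batch of embeddings.
    return torch.randn(n, embed_dim)

# Old pattern: accumulate, then concatenate.
# At the torch.cat call, `chunks` (about total_docs x embed_dim in total) and the
# newly created concatenated tensor (same size) coexist, so peak memory is roughly 2x.
chunks = [fake_batch(min(batch_size, total_docs - s)) for s in range(0, total_docs, batch_size)]
cat_embeddings = torch.cat(chunks, dim=0)

# New pattern: pre-allocate once and copy each batch into its slice.
# Peak memory is the output tensor plus a single in-flight batch.
embeddings = torch.empty((total_docs, embed_dim))
with torch.inference_mode():  # skips autograd bookkeeping, as in the commit's inner loop
    for start in range(0, total_docs, batch_size):
        end = min(start + batch_size, total_docs)
        embeddings[start:end] = fake_batch(end - start)

assert cat_embeddings.shape == embeddings.shape == (total_docs, embed_dim)

On a GPU the same pattern applies with device=self.device passed to torch.empty, which is what the committed code does: each batch is written straight into its slice of the device tensor instead of being staged in a Python list and re-copied by torch.cat.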
