From 37a705689b2dba8748465a8867302bc2be60d0d5 Mon Sep 17 00:00:00 2001
From: Harshit Gupta <59705530+harshitgupta412@users.noreply.github.com>
Date: Sat, 12 Oct 2024 20:28:51 -0700
Subject: [PATCH] Optimize speed and GPU usage for embed

torch.cat needs roughly double the memory, since the original per-batch
tensors stay alive alongside the newly created concatenated tensor.
Pre-allocating the output tensor and writing each batch into its slice
avoids that peak.
---
 lotus/models/e5_model.py | 22 +++++++++++++---------
 1 file changed, 13 insertions(+), 9 deletions(-)

diff --git a/lotus/models/e5_model.py b/lotus/models/e5_model.py
index 3c93c9d2..c7038ad8 100644
--- a/lotus/models/e5_model.py
+++ b/lotus/models/e5_model.py
@@ -58,18 +58,22 @@ def embed(self, docs: List[str], **kwargs: Dict[str, Any]) -> np.ndarray:
 
         kwargs = {**self.kwargs, **kwargs}
         batch_size = kwargs.get("batch_size", self.batch_size)
-        embeddings = []
-        for i, batch_start in enumerate(tqdm(range(0, len(docs), batch_size))):
-            batch = docs[batch_start : batch_start + batch_size]
-
-            with torch.no_grad():
+
+        # Calculate the embedding dimension from the first document
+        total_docs = len(docs)
+        first_batch = self.tokenizer(docs[:1], return_tensors="pt", padding=True, truncation=True).to(self.device)
+        embed_dim = self.model(**first_batch).last_hidden_state.size(-1)
+
+        # Pre-allocate a tensor for all embeddings
+        embeddings = torch.empty((total_docs, embed_dim), device=self.device)
+        # Process batches
+        with torch.inference_mode():  # Slightly faster than torch.no_grad() for inference
+            for i, batch_start in enumerate(tqdm(range(0, total_docs, batch_size))):
+                batch = docs[batch_start : batch_start + batch_size]
                 batch_dict = self.tokenizer(batch, padding=True, truncation=True, return_tensors="pt").to(self.device)
-
                 outputs = self.model(**batch_dict)
                 batch_embeddings = self.average_pool(outputs.last_hidden_state, batch_dict["attention_mask"])
-            embeddings.append(batch_embeddings)
-
-        embeddings = torch.cat(embeddings, dim=0)
+                embeddings[batch_start : batch_start + batch_size] = batch_embeddings
 
         if kwargs["normalize"]:
             embeddings = F.normalize(embeddings, p=2, dim=1)
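
For context on the memory behavior the commit message describes, here is a minimal, self-contained sketch. It is not part of the patch: the helper name make_batch_embedding and the sizes are hypothetical stand-ins for the model's forward pass. It contrasts the append-then-torch.cat pattern the patch removes with the pre-allocate-and-assign pattern it introduces.

import torch

num_docs, batch_size, embed_dim = 10_000, 64, 1024

def make_batch_embedding(start: int) -> torch.Tensor:
    # Hypothetical stand-in for tokenization, the model forward pass, and pooling.
    n = min(batch_size, num_docs - start)
    return torch.randn(n, embed_dim)

# Pattern the patch removes: collect per-batch tensors, then concatenate.
# At the moment torch.cat runs, the per-batch tensors and the freshly
# allocated concatenated tensor are alive simultaneously, so peak memory
# is roughly twice the size of the final result.
chunks = [make_batch_embedding(s) for s in range(0, num_docs, batch_size)]
cat_embeddings = torch.cat(chunks, dim=0)

# Pattern the patch introduces: allocate the output once and write each
# batch into its slice, so peak memory stays close to the final tensor's size.
embeddings = torch.empty(num_docs, embed_dim)
for start in range(0, num_docs, batch_size):
    batch = make_batch_embedding(start)
    embeddings[start : start + batch.size(0)] = batch

The same reasoning applies on the GPU: with device="cuda" the pre-allocated tensor replaces the transient doubled allocation that torch.cat would otherwise require.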