Skip to content

Commit

Permalink
hide query
Browse files Browse the repository at this point in the history
  • Loading branch information
tcapelle committed Dec 10, 2024
1 parent 37c6b44 commit 3b0a133
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions weave/scorers/relevance_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,21 +303,26 @@ def _score_document(
return_tensors="pt",
return_special_tokens_mask=True
)

model_inputs = {k: v.to(self.device) for k, v in model_inputs.items()}

special_tokens_mask = model_inputs.pop("special_tokens_mask")
combined_mask = ~((model_inputs["input_ids"] == 2).bool() | special_tokens_mask.bool())

combined_mask = ~((model_inputs["input_ids"] == 2).bool() | special_tokens_mask.bool()).cpu().numpy().flatten()
# we should mask the query up to the sep token,
# on the combined mask we have to search for the first False
# TODO: Check that this is now wrong
false_indices = np.where(~combined_mask)[0]
start = false_indices[0]
end = false_indices[1]
combined_mask[start:end] = False
results = self._model(**model_inputs)

logits = results.logits[0].detach()
probabilities = torch.nn.functional.softmax(logits, dim=-1).detach()

pred_mask = (probabilities[:,1] > threshold).cpu().numpy().astype(int)
label_mask = (pred_mask & combined_mask.cpu().numpy()).flatten()

pred_mask = (probabilities[:,1] > threshold).cpu().numpy().astype(int).flatten()
label_mask = (pred_mask & combined_mask)
positive_probs = probabilities[:, 1].cpu().numpy()

transitions = np.diff(np.concatenate([[0], label_mask, [0]]))
starts = np.where(transitions == 1)[0]
ends = np.where(transitions == -1)[0]
Expand All @@ -332,7 +337,9 @@ def _score_document(
"text": span_text,
"score": float(span_prob)
})

print(span_text)
print("-"*100)
print("*"*100)
return spans_with_probs, int(label_mask.sum()), int(len(label_mask))

@weave.op
Expand Down

0 comments on commit 3b0a133

Please sign in to comment.