diff --git a/weave/scorers/relevance_scorer.py b/weave/scorers/relevance_scorer.py
index 6c05f070d19..5b950076076 100644
--- a/weave/scorers/relevance_scorer.py
+++ b/weave/scorers/relevance_scorer.py
@@ -303,21 +303,26 @@ def _score_document(
             return_tensors="pt",
             return_special_tokens_mask=True
         )
-
         model_inputs = {k: v.to(self.device) for k, v in model_inputs.items()}
         special_tokens_mask = model_inputs.pop("special_tokens_mask")
-        combined_mask = ~((model_inputs["input_ids"] == 2).bool() | special_tokens_mask.bool())
-
+        combined_mask = ~((model_inputs["input_ids"] == 2).bool() | special_tokens_mask.bool()).cpu().numpy().flatten()
+        # we should mask the query up to the sep token,
+        # on the combined mask we have to search for the first False
+        # TODO: check that this is not wrong
+        false_indices = np.where(~combined_mask)[0]
+        start = false_indices[0]
+        end = false_indices[1]
+        combined_mask[start:end] = False
         results = self._model(**model_inputs)
         logits = results.logits[0].detach()
         probabilities = torch.nn.functional.softmax(logits, dim=-1).detach()
-        pred_mask = (probabilities[:,1] > threshold).cpu().numpy().astype(int)
-        label_mask = (pred_mask & combined_mask.cpu().numpy()).flatten()
-
+        pred_mask = (probabilities[:,1] > threshold).cpu().numpy().astype(int).flatten()
+        label_mask = (pred_mask & combined_mask)
         positive_probs = probabilities[:, 1].cpu().numpy()
+
         transitions = np.diff(np.concatenate([[0], label_mask, [0]]))
         starts = np.where(transitions == 1)[0]
         ends = np.where(transitions == -1)[0]
@@ -332,7 +337,9 @@ def _score_document(
                 "text": span_text,
                 "score": float(span_prob)
             })
-
+            print(span_text)
+            print("-"*100)
+            print("*"*100)
     return spans_with_probs, int(label_mask.sum()), int(len(label_mask))
 
 @weave.op
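
Two steps in this hunk are easy to misread, so here is a small, self-contained sketch (not part of the PR) of what they do. The `label_mask` and `combined_mask` values below are made-up toy inputs; only `numpy` is assumed. The first block mirrors the pad-with-zeros-and-`np.diff` trick that turns the thresholded token mask into contiguous `[start, end)` spans; the second mirrors the new query-masking step, which blanks everything between the first two `False` positions of `combined_mask` (the query portion of a query+document pair, per the comment and TODO in the diff).

```python
import numpy as np

# Toy inputs -- purely illustrative, not taken from the scorer.
label_mask = np.array([0, 1, 1, 0, 0, 1, 1, 1, 0])

# Span extraction as in the diff: pad with zeros so runs of 1s that touch
# either edge still produce a transition, then diff to find run boundaries.
transitions = np.diff(np.concatenate([[0], label_mask, [0]]))
starts = np.where(transitions == 1)[0]   # index where each run of 1s begins
ends = np.where(transitions == -1)[0]    # index one past where each run ends
print([(int(s), int(e)) for s, e in zip(starts, ends)])  # [(1, 3), (5, 8)]

# Query masking as in the diff: combined_mask is False at special tokens
# (and at token id 2); everything from the first False (e.g. the <s>/CLS
# position) up to the second False (the separator) is set to False, which
# drops the query tokens of a query+document pair from span scoring.
combined_mask = np.array([False, True, True, False, True, True, True, False])
false_indices = np.where(~combined_mask)[0]    # -> array([0, 3, 7])
start, end = false_indices[0], false_indices[1]
combined_mask[start:end] = False
print(combined_mask)  # tokens 0..2 (the query) are now masked out as well
```

Note that the slice `combined_mask[start:end]` excludes `end` itself, so only the tokens strictly before the separator are affected; the separator position stays masked simply because it was already `False`.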