feat: add sentence based context relevance scorer

parambharat committed Dec 16, 2024
1 parent 4e2f152 commit 98df0a3

Showing 2 changed files with 61 additions and 97 deletions.
149 changes: 55 additions & 94 deletions weave/scorers/context_relevance_scorer.py

@@ -96,9 +96,7 @@ def model_post_init(self, __context: Any) -> None:
         if os.path.isdir(self.model_name_or_path):
             self._local_model_path = self.model_name_or_path
         else:
-            self._local_model_path = download_model(
-                MODEL_PATHS["relevance_scorer"]
-            )
+            self._local_model_path = download_model(MODEL_PATHS["relevance_scorer"])

         self._classifier = pipeline(
             task="text-generation", model=self._local_model_path, device=self.device
@@ -217,6 +215,7 @@ def score(
         )
         return self.score_messages(messages)

+
 class ContextRelevanceScorer(Scorer):
     """
     A scorer that evaluates the relevance of model outputs relative to input queries and context.
@@ -237,7 +236,7 @@ class ContextRelevanceScorer(Scorer):
             - flagged (bool): Whether the output was flagged as irrelevant
             - extras (dict): Contains:
                 - score (float): Overall relevance score
-                - all_spans (list, optional): If verbose=True, includes list of relevant 
+                - all_spans (list, optional): If verbose=True, includes list of relevant
                   text spans and their scores
     Example:
@@ -257,126 +256,88 @@ class ContextRelevanceScorer(Scorer):
             }
         }
     """

     model_name_or_path: str = None
     base_url: Optional[str] = None
     device: str = "cpu"
-    threshold: float = 0.7
-    model_max_length: int = 1280
-    _model: Any = PrivateAttr()
-    _tokenizer: Any = PrivateAttr()
+    threshold: float = 0.5
+    model_max_length: int = 2048
+    _max_num_sentences: int = 20
+    _classifier: Any = PrivateAttr()

     def model_post_init(self, __context: Any) -> None:
         try:
             import torch
-            from transformers import AutoModelForTokenClassification, AutoTokenizer
+            from transformers import pipeline
+            import nltk
+
+            nltk.download("punkt_tab")
+            from nltk.tokenize import sent_tokenize
         except ImportError:
             print(
-                "The `transformers` and `torch` packages are required to use the ContextRelevanceScorer, please run `pip install transformers torch`"
+                "The `transformers`, `torch` and `nltk` packages are required to use the ContextRelevanceScorer, please run `pip install transformers torch nltk`"
             )
         """Initialize the model, tokenizer and device after pydantic initialization."""
         if os.path.isdir(self.model_name_or_path):
             self._local_model_path = self.model_name_or_path
         else:
-            self._local_model_path = download_model(
-                MODEL_PATHS["relevance_scorer"]
-            )
+            self._local_model_path = download_model(MODEL_PATHS["relevance_scorer"])
         assert self._local_model_path, "Model path not found"
-        self._model = AutoModelForTokenClassification.from_pretrained(
-            self._local_model_path, device_map=self.device
-        )
-        self._tokenizer = AutoTokenizer.from_pretrained(
-            self._local_model_path,
-            model_max_length=self.model_max_length,
-        )
-        self._model.eval()
+        self.device = set_device(self.device)
+        self._classifier = pipeline(
+            "context-relevance",
+            model=self._local_model_path,
+            trust_remote_code=True,
+            device=self.device,
+        )

     def _score_document(
-        self,
-        query: str,
-        document: str,
-        threshold: float) -> tuple[list[dict[str, Any]], int, int]:
+        self, query: str, document: str, response: str
+    ) -> list[dict[str, Any]]:
         """Score a single document."""
-        import torch
-        with torch.no_grad():
-            input_text = query + f" {self._tokenizer.sep_token} " + document
-            model_inputs = self._tokenizer(
-                input_text,
-                truncation=True,
-                max_length=self.model_max_length,
-                padding=False,
-                return_tensors="pt",
-                return_special_tokens_mask=True
-            )
-
-            model_inputs = {k: v.to(self.device) for k, v in model_inputs.items()}
-
-            special_tokens_mask = model_inputs.pop("special_tokens_mask")
-            combined_mask = ~((model_inputs["input_ids"] == 2).bool() | special_tokens_mask.bool()).cpu().numpy().flatten()
-            # we should mask the query up to the sep token,
-            # on the combined mask we have to search for the first False
-            # TODO: Check that this is not wrong
-            false_indices = np.where(~combined_mask)[0]
-            start = false_indices[0]
-            end = false_indices[1]
-            combined_mask[start:end] = False
-
-            results = self._model(**model_inputs)
-
-            logits = results.logits[0].detach()
-            probabilities = torch.nn.functional.softmax(logits, dim=-1).detach()
-
-            pred_mask = (probabilities[:, 1] > threshold).cpu().numpy().astype(int).flatten()
-            label_mask = pred_mask & combined_mask
-
-            positive_probs = probabilities[:, 1].cpu().numpy()
-            transitions = np.diff(np.concatenate([[0], label_mask, [0]]))
-            starts = np.where(transitions == 1)[0]
-            ends = np.where(transitions == -1)[0]
-
-            spans_with_probs = []
-            token_ids = model_inputs["input_ids"].cpu().numpy()[0]
-
-            for start, end in zip(starts, ends):
-                span_text = self._tokenizer.decode(token_ids[start:end])
-                span_prob = positive_probs[start:end].mean()
-                spans_with_probs.append({
-                    "text": span_text,
-                    "score": float(span_prob)
-                })
-
-            return spans_with_probs, int(label_mask.sum()), int(len(label_mask))
+        document_sentences = document.split("\n")
+        document_sentences = [sent_tokenize(doc) for doc in document_sentences]
+        document_sentences = [s for doc in document_sentences for s in doc]
+        context_scores = []
+        for batch in range(0, len(document_sentences), self._max_num_sentences):
+            inputs = {
+                "question": query,
+                "context": document_sentences[batch : batch + self._max_num_sentences],
+                "response": response,
+            }
+            res = self._classifier(inputs, threshold=self.threshold)[0]
+            context_scores.extend(res["sentences"])
+
+        return context_scores

     @weave.op
     def score(
         self,
         output: str,
         query: str,
         context: Union[str, List[str]],
-        verbose: bool = False
+        verbose: bool = False,
     ) -> Dict[str, Any]:
         """Score multiple documents and compute weighted average relevance."""
-        all_spans = []
-        total_weighted_score = 0.0
-        total_length = 0

         if isinstance(context, str):
             context = [context]
+        context_scores = []
         for doc in context:
-            spans, relevant_tokens, total_tokens = self._score_document(query, doc, self.threshold)
-
-            all_spans.extend(spans)
-
-            if total_tokens > 0:
-                doc_score = relevant_tokens / total_tokens
-                doc_weight = total_tokens
-                total_weighted_score += doc_score * doc_weight
-                total_length += total_tokens
-
-        final_score = total_weighted_score / total_length if total_length > 0 else 0.0
+            doc_scores = self._score_document(
+                query=query, document=doc, response=output
+            )
+            context_scores.extend(doc_scores)
+
+        relevant_sentences = [
+            score for score in context_scores if score["label"] == "relevant"
+        ]
+        for score in relevant_sentences[:]:
+            score.update({"span": score.pop("sentence")})
+        final_score = len(relevant_sentences) / len(context_scores)
         res = {"flagged": final_score > self.threshold}
-        res['extras'] = {'score': final_score}
+        res["extras"] = {"score": final_score}
         if verbose:
-            res['extras']['all_spans'] = all_spans
+            res["extras"]["all_spans"] = relevant_sentences

         return res
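
The rewritten scorer now works at the sentence level: each context document is split into sentences with NLTK, the sentences are scored in batches of up to 20 (`_max_num_sentences`) by a `context-relevance` pipeline, and the final score is the fraction of sentences labeled "relevant". A minimal usage sketch under the assumptions visible in this diff (a local checkpoint path, and per-sentence results carrying `label` and `sentence` keys); the path and texts are illustrative:

    from weave.scorers.context_relevance_scorer import ContextRelevanceScorer

    # Hypothetical local checkpoint directory; if it does not exist, the scorer
    # downloads MODEL_PATHS["relevance_scorer"] from W&B artifacts instead.
    scorer = ContextRelevanceScorer(model_name_or_path="weave_models/relevance_scorer")

    result = scorer.score(
        output="Paris is the capital of France.",
        query="What is the capital of France?",
        context=[
            "Paris is the capital and largest city of France.",
            "The Eiffel Tower was completed in 1889.",
        ],
        verbose=True,
    )
    # If, say, 1 of 2 context sentences is labeled "relevant",
    # final_score = 1 / 2 = 0.5, which is not > threshold (0.5),
    # so result["flagged"] is False.
    print(result["flagged"], result["extras"]["score"])

Note that flagging uses a strict inequality against `threshold`, and `extras["all_spans"]` (returned only with `verbose=True`) now holds the relevant sentences themselves, renamed under a `span` key, rather than token-level spans.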
9 changes: 6 additions & 3 deletions weave/scorers/llm_utils.py

@@ -106,7 +106,9 @@ def embed(
     if "openai" in client_type:
         response = client.embeddings.create(model=model_id, input=texts, **kwargs)
         if inspect.iscoroutine(response):
-            raise ValueError("Async client used with sync function. Use await with async clients.")
+            raise ValueError(
+                "Async client used with sync function. Use await with async clients."
+            )
         return [embedding.embedding for embedding in response.data]
     elif "mistral" in client_type:
         response = client.embeddings.create(model=model_id, inputs=texts, **kwargs)
@@ -129,6 +131,7 @@ def set_device(device: Optional[str] = None) -> str:

 def download_model(model_name_or_path: str, local_dir: str = "weave_models") -> str:
     from wandb import Api
+
     api = Api()
     art = api.artifact(
         type="model",
@@ -148,8 +151,8 @@ def download_model(model_name_or_path: str, local_dir: str = "weave_models") ->
     "coherence_scorer": "c-metrics/weave-scorers/coherence_scorer:v0",
     "toxicity_scorer": "c-metrics/weave-scorers/toxicity_scorer:v0",
     "bias_scorer": "c-metrics/weave-scorers/bias_scorer:v0",
-    "relevance_scorer": "c-metrics/context-relevance-scorer/relevance_scorer:v0",
-    "llamaguard": "c-metrics/weave-scorers/llamaguard:v0"
+    "relevance_scorer": "c-metrics/context-relevance-scorer/relevance_scorer:v2",
+    "llamaguard": "c-metrics/weave-scorers/llamaguard:v0",
 }
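
The artifact reference for the relevance scorer is bumped from v0 to v2 to pick up the new sentence-level model; `model_post_init` resolves it through `download_model`, which fetches the W&B artifact and returns a local path. A short sketch, assuming W&B credentials are configured (the returned path depends on `download_model`'s `local_dir` layout):

    from weave.scorers.llm_utils import MODEL_PATHS, download_model

    # Resolves "c-metrics/context-relevance-scorer/relevance_scorer:v2"
    # via the wandb Api and downloads it under local_dir ("weave_models").
    local_path = download_model(MODEL_PATHS["relevance_scorer"])
    print(local_path)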

