Skip to content

Commit

Permalink
Rename local variable `output` to `res` in `HallucinationScorer.score` so it no longer shadows the `output` parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
morganmcg1 committed Dec 10, 2024
1 parent 5457b3c commit 95f266a
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 381 deletions.
12 changes: 5 additions & 7 deletions weave/scorers/hallucination_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,14 +343,12 @@ def score(self, query: str, context: str, output: str) -> dict:
output=output,
)
if self.base_url:
output = self._score_via_api(messages)
output = output["data"]

res = self._score_via_api(messages)
res = res["data"]
else:
if self.use_hhem:
pairs = [(query + "\n\n" + context, output)]
pred = self.llm_model.predict(pairs)

score = pred.item()
return {
"flagged": score <= self.hhem_score_threshold,
Expand All @@ -372,7 +370,7 @@ def score(self, query: str, context: str, output: str) -> dict:
with torch.no_grad():
self.llm_model.eval()

output = self.llm_model.generate(
res = self.llm_model.generate(
inp_tokenized["input_ids"],
max_new_tokens=self.max_new_tokens,
attention_mask=inp_tokenized["attention_mask"],
Expand All @@ -388,7 +386,7 @@ def score(self, query: str, context: str, output: str) -> dict:
false_token = 4245

input_length = inp_tokenized["input_ids"].shape[1]
completion_tokens = output[0][input_length:].tolist()
completion_tokens = res[0][input_length:].tolist()

is_hallucination = true_token in completion_tokens
result = {
Expand All @@ -411,7 +409,7 @@ def score(self, query: str, context: str, output: str) -> dict:
{
"completion": completion,
"completion_tokens": completion_tokens,
"total_tokens": len(output[0]),
"total_tokens": len(res[0]),
"total_completion_tokens": len(completion_tokens),
"scorer_worked": scorer_worked,
}
Expand Down
Loading

0 comments on commit 95f266a

Please sign in to comment.