Skip to content

Commit

Permalink
more readable
Browse files Browse the repository at this point in the history
  • Loading branch information
tcapelle committed Dec 10, 2024
1 parent 85feae9 commit 39dd40e
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
7 changes: 7 additions & 0 deletions weave/flow/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,13 @@ async def eval_example(example: dict) -> dict:
n_complete += 1
if verbose:
print(f"Evaluated {n_complete} of {len(trial_rows)} examples")
else:
# Print progress at 25%, 50%, 75% and 100%
total_rows = len(trial_rows)
progress_milestones = [total_rows // 4, total_rows // 2, 3 * total_rows // 4, total_rows]
if n_complete in progress_milestones:
percent_complete = int((n_complete / total_rows) * 100)
print(f"Evaluated {percent_complete}% of examples")
# status.update(
# f"Evaluating... {duration:.2f}s [{n_complete} / {len(self.dataset.rows)} complete]" # type:ignore
# )
Expand Down
4 changes: 3 additions & 1 deletion weave/scorers/context_relevance_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,11 +310,13 @@ def _score_document(
combined_mask = ~((model_inputs["input_ids"] == 2).bool() | special_tokens_mask.bool()).cpu().numpy().flatten()
# we should mask the query up to the sep token,
# on the combined mask we have to search for the first False
# TODO: Check that this is now wrong
# TODO: Check that this is not wrong
false_indices = np.where(~combined_mask)[0]
start = false_indices[0]
end = false_indices[1]
combined_mask[start:end] = False


results = self._model(**model_inputs)

logits = results.logits[0].detach()
Expand Down

0 comments on commit 39dd40e

Please sign in to comment.