Fix truncation for hallu, coherence, cont rele
morganmcg1 committed Dec 11, 2024
1 parent b37a385 commit 60bf3e0
Showing 3 changed files with 36 additions and 4 deletions.
7 changes: 6 additions & 1 deletion weave/scorers/coherence_scorer.py

@@ -27,6 +27,7 @@ class CoherenceScorer(Scorer):
 
     device: str = None
    model_name_or_path: str = ""
+    model_max_length: int = 1024
     base_url: Optional[str] = None
     _classifier: Any = PrivateAttr()
     _label2id: dict[str, int] = PrivateAttr()
@@ -46,7 +47,11 @@ def model_post_init(self, __context: Any) -> None:
         )
 
         self._classifier = pipeline(
-            task="sentiment-analysis", model=self._local_model_path, device=self.device
+            task="sentiment-analysis",
+            model=self._local_model_path,
+            device=self.device,
+            max_length=self.model_max_length,
+            truncation=True
         )
         self._label2id = {
             "Completely Incoherent": 0,
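Note: the fix here passes max_length and truncation through to the pipeline, which forwards them to the tokenizer at preprocessing time. A minimal sketch of the behavior (not from the commit; "some-classifier-model" is a placeholder checkpoint):

# Sketch only: "some-classifier-model" stands in for the scorer's real weights.
from transformers import pipeline

classifier = pipeline(
    task="sentiment-analysis",
    model="some-classifier-model",
    max_length=1024,    # cap tokenized inputs at 1024 tokens
    truncation=True,    # clip overlong inputs instead of raising
)

# Without truncation=True, inputs that tokenize past the model's positional
# capacity fail inside the forward pass; with it, text is silently clipped
# to the first 1024 tokens before classification.
print(classifier("very long document " * 2000))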
9 changes: 7 additions & 2 deletions weave/scorers/context_relevance_scorer.py

@@ -261,6 +261,7 @@ class ContextRelevanceScorer(Scorer):
     base_url: Optional[str] = None
     device: str = "cpu"
     threshold: float = 0.7
+    model_max_length: int = 1280
     _model: Any = PrivateAttr()
     _tokenizer: Any = PrivateAttr()
 
@@ -283,7 +284,10 @@ def model_post_init(self, __context: Any) -> None:
         self._model = AutoModelForTokenClassification.from_pretrained(
             self._local_model_path, device_map=self.device
         )
-        self._tokenizer = AutoTokenizer.from_pretrained(self._local_model_path)
+        self._tokenizer = AutoTokenizer.from_pretrained(
+            self._local_model_path,
+            model_max_length=self.model_max_length,
+        )
         self._model.eval()
         self.device = set_device(self.device)
 
@@ -298,7 +302,8 @@ def _score_document(
         input_text = query + f" {self._tokenizer.sep_token} " + document
         model_inputs = self._tokenizer(
             input_text,
-            truncation=True,
+            truncation=True,
+            max_length=self.model_max_length,
             padding=False,
             return_tensors="pt",
             return_special_tokens_mask=True
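Note: this scorer gets the same treatment in two places: the tokenizer is constructed with model_max_length, and _score_document passes an explicit max_length alongside truncation=True. A hedged sketch outside the class (bert-base-uncased is a stand-in checkpoint; 1280 mirrors the new default, and a BERT-style tokenizer is assumed so that sep_token exists):

# Sketch only: the real scorer loads a local token-classification model.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "bert-base-uncased",        # placeholder checkpoint
    model_max_length=1280,
)

query = "What does Weave do?"
document = "Weave is a toolkit for tracing and evaluating LLM apps. " * 200

input_text = query + f" {tokenizer.sep_token} " + document
model_inputs = tokenizer(
    input_text,
    truncation=True,            # clip to max_length
    max_length=1280,
    padding=False,
    return_tensors="pt",
    return_special_tokens_mask=True,
)
print(model_inputs["input_ids"].shape)  # at most (1, 1280)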
24 changes: 23 additions & 1 deletion weave/scorers/hallucination_scorer.py

@@ -308,6 +308,7 @@ def model_post_init(self, __context) -> None:
                 trust_remote_code=True,
             ).to(self.device)
             self.tokenizer = self.llm_model.tokenzier
+            self.tokenizer.model_max_length = self.model_max_length
         else:
             self.llm_model = AutoModelForCausalLM.from_pretrained(
                 self._local_model_path, torch_dtype="bfloat16"
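Note: tokenzier is spelled that way upstream; the misspelled attribute appears to come from the HHEM model's own remote code, so the context line above is faithful. Pinning tokenizer.model_max_length matters because that attribute is the fallback cap whenever truncation is requested without an explicit max_length. A quick illustration (bert-base-uncased is purely a stand-in for the HHEM tokenizer):

# Sketch only: demonstrates the model_max_length fallback, not HHEM itself.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
print(tok.model_max_length)   # 512 for this checkpoint

tok.model_max_length = 128    # same kind of assignment as the added line
ids = tok("word " * 500, truncation=True).input_ids
print(len(ids))               # 128: truncation=True falls back to model_max_length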
@@ -347,7 +348,28 @@ def score(self, query: str, context: str, output: str) -> dict:
             res = res["data"]
         else:
             if self.use_hhem:
-                pairs = [(query + "\n\n" + context, output)]
+                inps = query + "\n\n" + context
+                outs = output
+
+                inps_toks = self.tokenizer(inps, truncation=False)
+                outs_toks = self.tokenizer(outs, truncation=False)
+
+                len_inps = len(inps_toks.input_ids)
+                len_outs = len(outs_toks.input_ids)
+                if len_inps + len_outs > self.model_max_length:
+                    print(f"inps and outs > max_length: {len_inps + len_outs}")
+                    if len_outs < self.model_max_length - 1000:
+                        inp_remaining = self.model_max_length - (len_outs + 975)
+                        inps_input_ids = inps_toks.input_ids[:inp_remaining]
+                        out_input_ids = outs_toks.input_ids
+                    else:
+                        inps_input_ids = inps_toks.input_ids[:975]
+                        out_input_ids = outs_toks.input_ids[:self.model_max_length - 1025]
+
+                    inps = self.tokenizer.decode(inps_input_ids)
+                    outs = self.tokenizer.decode(out_input_ids)
+
+                pairs = [(inps, outs)]
                 pred = self.llm_model.predict(pairs)
                 score = pred.item()
                 return {
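Note: the budget arithmetic in the HHEM branch is easier to see as a standalone function. A sketch (not the commit's code verbatim; tokenizer is any Hugging Face tokenizer, and the 975/1000/1025 headroom constants mirror the ones hard-coded above):

# Sketch of the truncation budget: keep the output intact when it fits,
# otherwise split the window between input and output with fixed headroom.
def truncate_pair(tokenizer, inps: str, outs: str, model_max_length: int):
    inps_ids = tokenizer(inps, truncation=False).input_ids
    outs_ids = tokenizer(outs, truncation=False).input_ids

    if len(inps_ids) + len(outs_ids) <= model_max_length:
        return inps, outs  # already fits; nothing to do

    if len(outs_ids) < model_max_length - 1000:
        # Output fits comfortably: keep it whole, give the input the
        # remaining budget minus 975 tokens of headroom.
        inps_ids = inps_ids[: model_max_length - (len(outs_ids) + 975)]
    else:
        # Both sides are long: pin the input to 975 tokens and clip the
        # output to what is left of the window.
        inps_ids = inps_ids[:975]
        outs_ids = outs_ids[: model_max_length - 1025]

    return tokenizer.decode(inps_ids), tokenizer.decode(outs_ids)

The decode-then-rescore round trip re-adds special tokens when HHEM tokenizes the pair again, which is presumably why the constants leave slack well below model_max_length.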
