From 60bf3e0314d40c26201200a64afe4e3fdcc36961 Mon Sep 17 00:00:00 2001
From: Morgan McGuire
Date: Wed, 11 Dec 2024 19:41:02 +0000
Subject: [PATCH] Fix truncation for hallu, coherence, cont rele

---
 weave/scorers/coherence_scorer.py         |  7 ++++++-
 weave/scorers/context_relevance_scorer.py |  9 +++++++--
 weave/scorers/hallucination_scorer.py     | 24 ++++++++++++++++++++++-
 3 files changed, 36 insertions(+), 4 deletions(-)

diff --git a/weave/scorers/coherence_scorer.py b/weave/scorers/coherence_scorer.py
index 1a286513466..b3673978201 100644
--- a/weave/scorers/coherence_scorer.py
+++ b/weave/scorers/coherence_scorer.py
@@ -27,6 +27,7 @@ class CoherenceScorer(Scorer):
 
     device: str = None
     model_name_or_path: str = ""
+    model_max_length: int = 1024
     base_url: Optional[str] = None
     _classifier: Any = PrivateAttr()
     _label2id: dict[str, int] = PrivateAttr()
@@ -46,7 +47,11 @@ def model_post_init(self, __context: Any) -> None:
             )
 
         self._classifier = pipeline(
-            task="sentiment-analysis", model=self._local_model_path, device=self.device
+            task="sentiment-analysis",
+            model=self._local_model_path,
+            device=self.device,
+            max_length=self.model_max_length,
+            truncation=True
         )
         self._label2id = {
             "Completely Incoherent": 0,
diff --git a/weave/scorers/context_relevance_scorer.py b/weave/scorers/context_relevance_scorer.py
index 21d8b8f03f2..fb6994792ba 100644
--- a/weave/scorers/context_relevance_scorer.py
+++ b/weave/scorers/context_relevance_scorer.py
@@ -261,6 +261,7 @@ class ContextRelevanceScorer(Scorer):
     base_url: Optional[str] = None
     device: str = "cpu"
     threshold: float = 0.7
+    model_max_length: int = 1280
     _model: Any = PrivateAttr()
     _tokenizer: Any = PrivateAttr()
 
@@ -283,7 +284,10 @@ def model_post_init(self, __context: Any) -> None:
         self._model = AutoModelForTokenClassification.from_pretrained(
             self._local_model_path, device_map=self.device
         )
-        self._tokenizer = AutoTokenizer.from_pretrained(self._local_model_path)
+        self._tokenizer = AutoTokenizer.from_pretrained(
+            self._local_model_path,
+            model_max_length=self.model_max_length,
+        )
         self._model.eval()
 
         self.device = set_device(self.device)
@@ -298,7 +302,8 @@ def _score_document(
         input_text = query + f" {self._tokenizer.sep_token} " + document
         model_inputs = self._tokenizer(
             input_text,
-            truncation=True,
+            truncation=True,
+            max_length=self.model_max_length,
             padding=False,
             return_tensors="pt",
             return_special_tokens_mask=True
diff --git a/weave/scorers/hallucination_scorer.py b/weave/scorers/hallucination_scorer.py
index 4644b4e3bfd..752217da5c0 100644
--- a/weave/scorers/hallucination_scorer.py
+++ b/weave/scorers/hallucination_scorer.py
@@ -308,6 +308,7 @@ def model_post_init(self, __context) -> None:
                 trust_remote_code=True,
             ).to(self.device)
             self.tokenizer = self.llm_model.tokenzier
+            self.tokenizer.model_max_length = self.model_max_length
         else:
             self.llm_model = AutoModelForCausalLM.from_pretrained(
                 self._local_model_path, torch_dtype="bfloat16"
@@ -347,7 +348,28 @@ def score(self, query: str, context: str, output: str) -> dict:
             res = res["data"]
         else:
             if self.use_hhem:
-                pairs = [(query + "\n\n" + context, output)]
+                inps = query + "\n\n" + context
+                outs = output
+
+                inps_toks = self.tokenizer(inps, truncation=False)
+                outs_toks = self.tokenizer(outs, truncation=False)
+
+                len_inps = len(inps_toks.input_ids)
+                len_outs = len(outs_toks.input_ids)
+                if len_inps + len_outs > self.model_max_length:
+                    print(f"inps and outs > max_length: {len_inps + len_outs}")
+                    if len_outs < self.model_max_length - 1000:
+                        inp_remaining = self.model_max_length - (len_outs + 975)
+                        inps_input_ids = inps_toks.input_ids[:inp_remaining]
+                        out_input_ids = outs_toks.input_ids
+                    else:
+                        inps_input_ids = inps_toks.input_ids[:975]
+                        out_input_ids = outs_toks.input_ids[:self.model_max_length - 1025]
+
+                    inps = self.tokenizer.decode(inps_input_ids)
+                    outs = self.tokenizer.decode(out_input_ids)
+
+                pairs = [(inps, outs)]
                 pred = self.llm_model.predict(pairs)
                 score = pred.item()
                 return {