diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7812904e6..f88e71331 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,7 +1,9 @@
 # Changelog
 
 ## Unreleased
-...
+- Abstract `LanguageModel` class to integrate with LLMs from any API
+- Abstract `ChatModel` class to integrate with chat models from any API
+- Every `LanguageModel` supports echo to retrieve log probs for an expected completion given a prompt
 
 ### Breaking Changes
 ...
diff --git a/src/intelligence_layer/core/__init__.py b/src/intelligence_layer/core/__init__.py
index 520d88661..72082ba79 100644
--- a/src/intelligence_layer/core/__init__.py
+++ b/src/intelligence_layer/core/__init__.py
@@ -16,14 +16,18 @@
 from .instruct import Instruct as Instruct
 from .instruct import InstructInput as InstructInput
 from .model import AlephAlphaModel as AlephAlphaModel
+from .model import ChatModel as ChatModel
 from .model import CompleteInput as CompleteInput
 from .model import CompleteOutput as CompleteOutput
 from .model import ControlModel as ControlModel
 from .model import ExplainInput as ExplainInput
 from .model import ExplainOutput as ExplainOutput
+from .model import LanguageModel as LanguageModel
 from .model import Llama2InstructModel as Llama2InstructModel
+from .model import Llama3ChatModel as Llama3ChatModel
 from .model import Llama3InstructModel as Llama3InstructModel
 from .model import LuminousControlModel as LuminousControlModel
+from .model import Message as Message
 from .prompt_template import Cursor as Cursor
 from .prompt_template import PromptItemCursor as PromptItemCursor
 from .prompt_template import PromptRange as PromptRange
diff --git a/src/intelligence_layer/core/echo.py b/src/intelligence_layer/core/echo.py
index 6d9400e6a..f9e9d7cac 100644
--- a/src/intelligence_layer/core/echo.py
+++ b/src/intelligence_layer/core/echo.py
@@ -1,12 +1,10 @@
 from collections.abc import Sequence
 from typing import NewType
 
-from aleph_alpha_client import Prompt, Tokens
+from aleph_alpha_client import Prompt, Text
 from pydantic import BaseModel
-from tokenizers import Encoding  # type: ignore
 
-from intelligence_layer.core.model import AlephAlphaModel, CompleteInput
-from intelligence_layer.core.prompt_template import PromptTemplate
+from intelligence_layer.core.model import AlephAlphaModel
 from intelligence_layer.core.task import Task, Token
 from intelligence_layer.core.tracer.tracer import TaskSpan
 
@@ -73,56 +71,25 @@ def __init__(self, model: AlephAlphaModel) -> None:
         self._model = model
 
     def do_run(self, input: EchoInput, task_span: TaskSpan) -> EchoOutput:
-        # We tokenize the prompt separately so we don't have an overlap in the tokens.
-        # If we don't do this, the end of the prompt and expected completion can be merged into unexpected tokens.
-        expected_completion_tokens = self._tokenize(input.expected_completion)
-        prompt_template = PromptTemplate(self.PROMPT_TEMPLATE_STR)
-        prompt = prompt_template.to_rich_prompt(
-            prompt=prompt_template.embed_prompt(input.prompt),
-            expected_completion=prompt_template.placeholder(
-                Tokens.from_token_ids(
-                    [token.token_id for token in expected_completion_tokens]
-                )
-            ),
-        )
-        output = self._model.complete(
-            CompleteInput(
-                prompt=prompt,
-                maximum_tokens=0,
-                log_probs=0,
-                tokens=True,
-                echo=True,
-            ),
-            task_span,
-        )
-        assert output.completions[0].log_probs
-        log_prob_dicts = output.completions[0].log_probs[
-            -len(expected_completion_tokens) :
-        ]
-        tokens_with_prob = []
-        for token, log_prob in zip(
-            expected_completion_tokens, log_prob_dicts, strict=True
-        ):
-            assert token.token in log_prob
-            tokens_with_prob.append(
-                TokenWithLogProb(
-                    token=token,
-                    prob=LogProb(log_prob.get(token.token, 0.0) or 0.0),
-                )
+        if len(input.prompt.items) != 1:
+            raise NotImplementedError(
+                "`Echo` currently only supports prompts with one item."
             )
-        return EchoOutput(tokens_with_log_probs=tokens_with_prob)
 
-    def _tokenize(self, text: str) -> Sequence[Token]:
-        # Turns the expected output into list of token ids. Important so that we know how many tokens
-        # the label is and can retrieve the last N log probs for the label
-        tokenizer = self._model.get_tokenizer()
-        if tokenizer.pre_tokenizer:
-            tokenizer.pre_tokenizer.add_prefix_space = False
-        encoding: Encoding = tokenizer.encode(text)
-        return [
-            Token(
-                token=tokenizer.decode([token_id], skip_special_tokens=False),
-                token_id=token_id,
+        if not isinstance(input.prompt.items[0], Text):
+            raise NotImplementedError(
+                "`Echo` currently only supports prompts that are of type `Text`."
             )
-            for token_id in encoding.ids
+
+        echo_output = self._model.echo(
+            input.prompt.items[0].text, input.expected_completion, task_span
+        )
+
+        tokens_with_prob = [
+            TokenWithLogProb(
+                token=token,
+                prob=LogProb(log_prob or 0.0),
+            )
+            for token, log_prob in echo_output
         ]
+        return EchoOutput(tokens_with_log_probs=tokens_with_prob)
diff --git a/src/intelligence_layer/core/model.py b/src/intelligence_layer/core/model.py
index 4c9a2907b..2ad31655a 100644
--- a/src/intelligence_layer/core/model.py
+++ b/src/intelligence_layer/core/model.py
@@ -1,14 +1,19 @@
 import typing
 import warnings
 from abc import ABC, abstractmethod
+from collections.abc import Sequence
+from copy import deepcopy
 from functools import lru_cache
-from typing import ClassVar, Optional
+from typing import Any, ClassVar, Literal, Optional
 
 from aleph_alpha_client import (
     CompletionRequest,
     CompletionResponse,
     ExplanationRequest,
     ExplanationResponse,
+    Prompt,
+    Text,
+    Tokens,
 )
 from pydantic import BaseModel, ConfigDict
 from tokenizers import Encoding, Tokenizer  # type: ignore
@@ -18,7 +23,7 @@
     LimitedConcurrencyClient,
 )
 from intelligence_layer.core.prompt_template import PromptTemplate, RichPrompt
-from intelligence_layer.core.task import Task
+from intelligence_layer.core.task import Task, Token
 from intelligence_layer.core.tracer.tracer import TaskSpan, Tracer
 
 
@@ -130,8 +135,59 @@ def limited_concurrency_client_from_env() -> LimitedConcurrencyClient:
     return LimitedConcurrencyClient.from_env()
 
 
-class AlephAlphaModel:
-    """Abstract base class for the implementation of any model that uses the Aleph Alpha client.
+class LanguageModel(ABC):
+    """Abstract base class to implement any LLM."""
+
+    def __init__(self, name: str) -> None:
+        self.name = name
+
+    @abstractmethod
+    def generate(self, prompt: str, tracer: Tracer) -> str:
+        """A completion function that takes a prompt and generates a completion.
+
+        Args:
+            prompt: The prompt to generate a completion for
+            tracer: Valid instance of a tracer
+
+        Returns:
+            An LLM completion
+        """
+        ...
+
+    @abstractmethod
+    def echo(
+        self, prompt: str, expected_completion: str, tracer: Tracer
+    ) -> Sequence[tuple[Any, Optional[float]]]:
+        """Echoes the log probs for each token of an expected completion given a prompt.
+
+        Args:
+            prompt: The prompt to echo
+            expected_completion: The expected completion to get log probs for
+            tracer: Valid instance of a tracer
+
+        Returns:
+            A list of tuples with token identifier and log probability
+        """
+        ...
+
+
+class Message(BaseModel):
+    role: Literal["system", "user", "assistant"]
+    content: str
+
+
+class ChatModel(LanguageModel):
+    """Abstract base class to implement any model that supports chat."""
+
+    @abstractmethod
+    def generate_chat(
+        self, messages: list[Message], response_prefix: str | None, tracer: Tracer
+    ) -> str:
+        pass
+
+
+class AlephAlphaModel(LanguageModel):
+    """Model class for any model that uses the Aleph Alpha client.
 
     Any class of Aleph Alpha model is implemented on top of this base class. Exposes methods that are
     available to all models, such as `complete` and `tokenize`. It is the central place for
@@ -146,11 +202,9 @@ class AlephAlphaModel:
     """
 
     def __init__(
-        self,
-        name: str,
-        client: Optional[AlephAlphaClientProtocol] = None,
+        self, name: str, client: Optional[AlephAlphaClientProtocol] = None
     ) -> None:
-        self.name = name
+        super().__init__(name)
         self._client = (
             limited_concurrency_client_from_env() if client is None else client
         )
@@ -164,6 +218,55 @@ def __init__(
             )
         self._explain = _Explain(self._client, name)
 
+    def generate(self, prompt: str, tracer: Tracer) -> str:
+        complete_input = CompleteInput(prompt=Prompt.from_text(prompt))
+        return self._complete.run(complete_input, tracer).completion
+
+    def echo(
+        self, prompt: str, expected_completion: str, tracer: Tracer
+    ) -> Sequence[tuple[Token, Optional[float]]]:
+        expected_completion_encoding: Encoding = self.tokenize(
+            expected_completion, whitespace_prefix=False
+        )
+        expected_completion_tokens = [
+            Token(
+                token=self.get_tokenizer().decode(
+                    [token_id], skip_special_tokens=False
+                ),
+                token_id=token_id,
+            )
+            for token_id in expected_completion_encoding.ids
+        ]
+
+        aa_prompt = Prompt(
+            items=[Text(prompt, []), Tokens(expected_completion_encoding.ids, [])]
+        )
+
+        logprob_index = 0
+        output = self._complete.run(
+            CompleteInput(
+                prompt=aa_prompt,
+                maximum_tokens=0,
+                log_probs=logprob_index,
+                tokens=True,
+                echo=True,
+            ),
+            tracer,
+        )
+        assert output.completions[0].log_probs
+
+        return [
+            (
+                token,
+                list(log_prob_dict.values())[logprob_index] or 0.0,
+            )
+            for token, log_prob_dict in zip(
+                expected_completion_tokens,
+                output.completions[0].log_probs[-len(expected_completion_tokens) :],
+                strict=False,
+            )
+        ]
+
     @property
     def context_size(self) -> int:
         # needed for proper caching without memory leaks
@@ -186,8 +289,19 @@ def get_tokenizer(self) -> Tokenizer:
             return _cached_tokenizer(self._client, self.name)
         return _tokenizer(self._client, self.name)
 
-    def tokenize(self, text: str) -> Encoding:
-        return self.get_tokenizer().encode(text)
+    def get_tokenizer_no_whitespace_prefix(self) -> Tokenizer:
+        # needed for proper caching without memory leaks
+        if isinstance(self._client, typing.Hashable):
+            return _cached_tokenizer_no_whitespace_prefix(self._client, self.name)
+        return _tokenizer_no_whitespace_prefix(self._client, self.name)
+
+    def tokenize(self, text: str, whitespace_prefix: bool = True) -> Encoding:
+        tokenizer = (
+            self.get_tokenizer()
+            if whitespace_prefix
+            else self.get_tokenizer_no_whitespace_prefix()
+        )
+        return tokenizer.encode(text)
 
 
 @lru_cache(maxsize=5)
@@ -199,6 +313,27 @@ def _tokenizer(client: AlephAlphaClientProtocol, name: str) -> Tokenizer:
     return client.tokenizer(name)
 
 
+@lru_cache(maxsize=5)
+def _cached_tokenizer_no_whitespace_prefix(
+    client: AlephAlphaClientProtocol, name: str
+) -> Tokenizer:
+    return _tokenizer_no_whitespace_prefix(client, name)
+
+
+def _tokenizer_no_whitespace_prefix(
+    client: AlephAlphaClientProtocol, name: str
+) -> Tokenizer:
+    tokenizer = client.tokenizer(name)
+    if tokenizer.pre_tokenizer:
+        copied_tokenizer = deepcopy(tokenizer)
+        copied_tokenizer.pre_tokenizer.add_prefix_space = False
+        return copied_tokenizer
+
+    raise ValueError(
+        "Tokenizer does not support `.pre_tokenizer` and thus `.add_prefix_space` option."
+    )
+
+
 @lru_cache(maxsize=10)
 def _cached_context_size(client: AlephAlphaClientProtocol, name: str) -> int:
     return _context_size(client, name)
@@ -219,7 +354,7 @@ def _context_size(client: AlephAlphaClientProtocol, name: str) -> int:
     return context_size
 
 
-class ControlModel(ABC, AlephAlphaModel):
+class ControlModel(AlephAlphaModel, ABC):
     RECOMMENDED_MODELS: ClassVar[list[str]] = []
 
     def __init__(
@@ -370,6 +505,8 @@ class Llama3InstructModel(ControlModel):
     RECOMMENDED_MODELS: ClassVar[list[str]] = [
         "llama-3-8b-instruct",
         "llama-3-70b-instruct",
+        "llama-3.1-8b-instruct",
+        "llama-3.1-70b-instruct",
     ]
 
     def __init__(
@@ -383,20 +520,6 @@ def __init__(
     def eot_token(self) -> str:
         return "<|eot_id|>"
 
-    def _add_eot_token_to_stop_sequences(self, input: CompleteInput) -> CompleteInput:
-        # remove this once the API supports the llama-3 EOT_TOKEN
-        params = input.__dict__
-        if isinstance(params["stop_sequences"], list):
-            if self.eot_token not in params["stop_sequences"]:
-                params["stop_sequences"].append(self.eot_token)
-        else:
-            params["stop_sequences"] = [self.eot_token]
-        return CompleteInput(**params)
-
-    def complete(self, input: CompleteInput, tracer: Tracer) -> CompleteOutput:
-        input_with_eot = self._add_eot_token_to_stop_sequences(input)
-        return super().complete(input_with_eot, tracer)
-
     def to_instruct_prompt(
         self,
         instruction: str,
@@ -406,3 +529,65 @@ def __init__(
         return self.INSTRUCTION_PROMPT_TEMPLATE.to_rich_prompt(
             instruction=instruction, input=input, response_prefix=response_prefix
         )
+
+
+class AlephAlphaChatModel(ChatModel, AlephAlphaModel):
+    """Abstract base class for any model that supports chat and runs via the Aleph Alpha API."""
+
+    @abstractmethod
+    def to_chat_prompt(
+        self, messages: list[Message], response_prefix: str | None
+    ) -> RichPrompt: ...
+
+    def generate_chat(
+        self, messages: list[Message], response_prefix: str | None, tracer: Tracer
+    ) -> str:
+        prompt = self.to_chat_prompt(messages, response_prefix)
+        prompt_item = prompt.items[0]
+        assert isinstance(prompt_item, Text)
+
+        return self.generate(prompt_item.text, tracer)
+
+
+class Llama3ChatModel(AlephAlphaChatModel):
+    """Chat model to be used for `llama-3-*` and `llama-3.1-*` models.
+
+    Args:
+        name: The name of a valid llama-3 model.
+            Defaults to `llama-3.1-8b-instruct`
+        client: Aleph Alpha client instance for running model related API calls.
+            Defaults to :class:`LimitedConcurrencyClient`
+    """
+
+    CHAT_PROMPT_TEMPLATE = PromptTemplate(
+        """<|begin_of_text|>{% for message in messages %}<|start_header_id|>{{message.role}}<|end_header_id|>
+
+{% promptrange instruction %}{{message.content}}{% endpromptrange %}<|eot_id|>{% endfor %}<|start_header_id|>assistant<|end_header_id|>
+
+{% if response_prefix %}{{response_prefix}}{% endif %}"""
+    )
+
+    RECOMMENDED_MODELS: ClassVar[list[str]] = [
+        "llama-3-8b-instruct",
+        "llama-3-70b-instruct",
+        "llama-3.1-8b-instruct",
+        "llama-3.1-70b-instruct",
+    ]
+
+    def __init__(
+        self,
+        name: str = "llama-3.1-8b-instruct",
+        client: Optional[AlephAlphaClientProtocol] = None,
+    ) -> None:
+        super().__init__(name, client)
+
+    @property
+    def eot_token(self) -> str:
+        return "<|eot_id|>"
+
+    def to_chat_prompt(
+        self, messages: list[Message], response_prefix: str | None = None
+    ) -> RichPrompt:
+        return self.CHAT_PROMPT_TEMPLATE.to_rich_prompt(
+            messages=[m.model_dump() for m in messages], response_prefix=response_prefix
+        )
diff --git a/src/intelligence_layer/examples/qa/single_chunk_qa.py b/src/intelligence_layer/examples/qa/single_chunk_qa.py
index c52bf9d88..aa6337cea 100644
--- a/src/intelligence_layer/examples/qa/single_chunk_qa.py
+++ b/src/intelligence_layer/examples/qa/single_chunk_qa.py
@@ -193,7 +193,7 @@ def _shift_highlight_ranges_to_input(
     def _get_no_answer_logit_bias(
         self, no_answer_str: str, no_answer_logit_bias: float
     ) -> dict[int, float]:
-        return {self._model.tokenize(no_answer_str).ids[0]: no_answer_logit_bias}
+        return {self._model.tokenize(no_answer_str, True).ids[0]: no_answer_logit_bias}
 
     def _generate_answer(
         self,
diff --git a/tests/core/test_echo.py b/tests/core/test_echo.py
index 6951a39bb..1c7945efc 100644
--- a/tests/core/test_echo.py
+++ b/tests/core/test_echo.py
@@ -77,9 +77,7 @@ def __init__(
 def tokenize_completion(
     expected_output: str, aleph_alpha_model: AlephAlphaModel
 ) -> Sequence[Token]:
-    tokenizer = aleph_alpha_model.get_tokenizer()
-    assert tokenizer.pre_tokenizer
-    tokenizer.pre_tokenizer.add_prefix_space = False
+    tokenizer = aleph_alpha_model.get_tokenizer_no_whitespace_prefix()
     encoding: tokenizers.Encoding = tokenizer.encode(expected_output)
     return [
         Token(
diff --git a/tests/core/test_model.py b/tests/core/test_model.py
index 57122cda2..cd44ed444 100644
--- a/tests/core/test_model.py
+++ b/tests/core/test_model.py
@@ -13,8 +13,10 @@
     ControlModel,
     ExplainInput,
     Llama2InstructModel,
+    Llama3ChatModel,
     Llama3InstructModel,
     LuminousControlModel,
+    Message,
     NoOpTracer,
 )
 from intelligence_layer.core.model import _cached_context_size, _cached_tokenizer
@@ -157,3 +159,46 @@ def test_context_size_caching_works() -> None:
     _cached_context_size.cache_clear()
     different_result = another_model_instance.context_size
     assert context_size is not different_result
+
+
+def test_chat_model_can_produce_chat_prompt() -> None:
+    client = DummyModelClient()  # type: ignore
+    model = Llama3ChatModel("llama-3.1-8b-instruct", client)
+    messages = [
+        Message(role="system", content="You are a nice assistant."),
+        Message(role="user", content="What's 2+2?"),
+    ]
+    response_prefix = "The answer is"
+
+    prompt = model.to_chat_prompt(messages=messages, response_prefix=response_prefix)
+
+    assert isinstance(prompt.items[0], Text)
+
+    text_in_prompt = prompt.items[0].text
+
+    assert (
+        text_in_prompt
+        == """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+
+You are a nice assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+What's 2+2?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+The answer is"""
+    )
+
+
+def test_aleph_alpha_model_can_echo(
+    model: ControlModel, no_op_tracer: NoOpTracer
+) -> None:
+    prompt = "2 + 2 is"
+    expected_completion = " 4"
+
+    echo_output = model.echo(prompt, expected_completion, no_op_tracer)
+
+    assert len(echo_output) == 1
+    assert echo_output[0][0].token == expected_completion
+
+    log_probability = echo_output[0][1]
+    assert log_probability
+    assert 0 > log_probability > -5
diff --git a/tests/examples/qa/test_single_chunk_qa.py b/tests/examples/qa/test_single_chunk_qa.py
index 436dcb810..2bd6e11a2 100644
--- a/tests/examples/qa/test_single_chunk_qa.py
+++ b/tests/examples/qa/test_single_chunk_qa.py
@@ -80,16 +80,10 @@ def test_qa_with_logit_bias_for_no_answer(
         question="When did he lose his mother?",
     )
     output = single_chunk_qa.run(input, NoOpTracer())
-
-    # on CI, this is tokenized as "nonononono" rather than "no no no no no"
-    # Likely, this is because some test changes the tokenizer state to remove the whitespace
-    # We should fix this, but for now, I'll assert both
-    acceptable_answers = [
-        " ".join([first_token] * max_tokens),
-        first_token * max_tokens,
-    ]
     answer = output.answer
-    assert answer == acceptable_answers[0] or answer == acceptable_answers[1]
+
+    assert answer
+    assert "no" in answer.split()[0]
 
 
 def test_qa_highlights_will_not_become_out_of_bounds(