
feat: Llama3InstructModel
NickyHavoc committed Apr 24, 2024
1 parent da7a2b3 commit 0c26b8f
Showing 3 changed files with 70 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/intelligence_layer/core/__init__.py
@@ -19,6 +19,7 @@
from .model import CompleteInput as CompleteInput
from .model import CompleteOutput as CompleteOutput
from .model import ControlModel as ControlModel
from .model import Llama3InstructModel as Llama3InstructModel
from .model import LuminousControlModel as LuminousControlModel
from .prompt_template import Cursor as Cursor
from .prompt_template import PromptItemCursor as PromptItemCursor
55 changes: 55 additions & 0 deletions src/intelligence_layer/core/model.py
@@ -228,3 +228,58 @@ def to_instruct_prompt(
        return self.INSTRUCTION_PROMPT_TEMPLATE.to_rich_prompt(
            instruction=instruction, input=input, response_prefix=response_prefix
        )


class Llama3InstructModel(ControlModel):
    """A llama-3-*-instruct model.

    Args:
        name: The name of a valid llama-3 model.
        client: Aleph Alpha client instance for running model related API calls.
            Defaults to the :class:`LimitedConcurrencyClient`.
    """

    INSTRUCTION_PROMPT_TEMPLATE = PromptTemplate(
        """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
{% promptrange instruction %}{{instruction}}{% endpromptrange %}{% if input %}
{% promptrange input %}{{input}}{% endpromptrange %}{% endif %}<|eot_id|><|start_header_id|>assistant<|end_header_id|>{% if response_prefix %}
{{response_prefix}}{% endif %}"""
    )
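    # Illustrative rendering (an assumption based on default liquid whitespace
    # handling, not part of the original diff):
    # to_instruct_prompt("Who likes pizza?", "Marc likes pizza.") should yield roughly:
    #   <|begin_of_text|><|start_header_id|>user<|end_header_id|>
    #   Who likes pizza?
    #   Marc likes pizza.<|eot_id|><|start_header_id|>assistant<|end_header_id|>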
    EOT_TOKEN = "<|eot_id|>"

    def __init__(
        self,
        name: Literal[
            "llama-3-8b-instruct",
            "llama-3-70b-instruct",
        ] = "llama-3-8b-instruct",
        client: Optional[AlephAlphaClientProtocol] = None,
    ) -> None:
        super().__init__(name, client)

    def _add_eot_token_to_stop_sequences(self, input: CompleteInput) -> CompleteInput:
        # remove this once the API supports the llama-3 EOT_TOKEN
        params = input.__dict__
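        # note: __dict__ is a shallow view, so the append below also mutates
        # the caller's stop_sequences list in place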
        if isinstance(params["stop_sequences"], list):
            if self.EOT_TOKEN not in params["stop_sequences"]:
                params["stop_sequences"].append(self.EOT_TOKEN)
        else:
            params["stop_sequences"] = [self.EOT_TOKEN]
        return CompleteInput(**params)

    def complete(self, input: CompleteInput, tracer: Tracer) -> CompleteOutput:
        input_with_eot = self._add_eot_token_to_stop_sequences(input)
        return super().complete(input_with_eot, tracer)

    def to_instruct_prompt(
        self,
        instruction: str,
        input: Optional[str] = None,
        response_prefix: Optional[str] = None,
    ) -> RichPrompt:
        return self.INSTRUCTION_PROMPT_TEMPLATE.to_rich_prompt(
            instruction=instruction, input=input, response_prefix=response_prefix
        )
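
A minimal usage sketch for the new model (illustrative, not part of the diff; assumes an Aleph Alpha token is configured so the default LimitedConcurrencyClient can reach the API):

from intelligence_layer.core import CompleteInput, Llama3InstructModel, NoOpTracer

# defaults to "llama-3-8b-instruct" and the default client
model = Llama3InstructModel()
prompt = model.to_instruct_prompt(
    "Summarize the following text.",
    "The quick brown fox jumps over the lazy dog.",
)
# complete() appends <|eot_id|> to the stop sequences before calling the API
output = model.complete(CompleteInput(prompt=prompt), NoOpTracer())
print(output.completion)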
14 changes: 14 additions & 0 deletions tests/core/test_model.py
@@ -8,6 +8,7 @@
    AlephAlphaModel,
    CompleteInput,
    ControlModel,
    Llama3InstructModel,
    LuminousControlModel,
    NoOpTracer,
)
@@ -53,3 +54,16 @@ def test_explain(model: ControlModel, no_op_tracer: NoOpTracer) -> None:
    )
    output = model.explain(explain_input, no_op_tracer)
    assert output.explanations[0].items[0].scores[5].score > 1


def test_llama_3_model_works(no_op_tracer: NoOpTracer) -> None:
    llama_3_model = Llama3InstructModel()

    prompt = llama_3_model.to_instruct_prompt(
        "Who likes pizza?",
        "Marc and Jessica had pizza together. However, Marc hated it. He only agreed to the date because Jessica likes pizza so much.",
    )

    complete_input = CompleteInput(prompt=prompt)
    output = llama_3_model.complete(complete_input, no_op_tracer)
    assert "Jessica" in output.completion
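
This test can be run in isolation with pytest tests/core/test_model.py::test_llama_3_model_works; note that it exercises the live completion API end to end.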
