Skip to content

Commit

Permalink
WIP: Cache default client
Browse files Browse the repository at this point in the history
  • Loading branch information
NickyHavoc committed Feb 21, 2024
1 parent e138315 commit 74d4543
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 19 deletions.
48 changes: 37 additions & 11 deletions src/intelligence_layer/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class CompleteInput(BaseModel, CompletionRequest, frozen=True):
class CompleteOutput(BaseModel, CompletionResponse, frozen=True):
"""The output of a `Complete` task."""

# Base model protects namespace model_ but this is a field in the completion response
# BaseModel protects namespace "model_" but this is a field in the CompletionResponse
model_config = ConfigDict(protected_namespaces=())

@staticmethod
Expand All @@ -46,12 +46,13 @@ def generated_tokens(self) -> int:
class _Complete(Task[CompleteInput, CompleteOutput]):
"""Performs a completion request with access to all possible request parameters.
Only use this task if none of the higher-level tasks defined below work for
you, because your completion request does not fit the use cases the higher-level ones represent or
you need to control request parameters that are not exposed by them.
Only use this task for testing. It is wrapped by the AlephAlphaModel for sending
completion requests to the API.
Args:
client: Aleph Alpha client instance for running model related API calls.
model: The name of a valid model that can access an API using an implementation
of the AlephAlphaClientProtocol.
"""

def __init__(self, client: AlephAlphaClientProtocol, model: str) -> None:
Expand All @@ -70,14 +71,27 @@ def do_run(self, input: CompleteInput, task_span: TaskSpan) -> CompleteOutput:


class AlephAlphaModel(ABC):
"""Abstract base class for the implementation of any model that uses the Aleph Alpha client.
Any class of Aleph Alpha model is implemented on top of this base class. Exposes methods that
are available to all models, such as `complete` and `tokenize`. It is the central place for
all things that are physically interconnected with a model, such as its tokenizer or prompt
format used during training.
Args:
name: The name of a valid model that can access an API using an implementation
of the AlephAlphaClientProtocol.
client: Aleph Alpha client instance for running model related API calls.
"""

def __init__(
self,
model_name: str,
client: AlephAlphaClientProtocol = LimitedConcurrencyClient.from_token(),
name: str,
client: AlephAlphaClientProtocol,
) -> None:
self.name = name
self._client = client
self._complete = _Complete(self._client, model_name)
self.name = model_name
self._complete = _Complete(self._client, name)

def get_complete_task(self) -> Task[CompleteInput, CompleteOutput]:
return self._complete
Expand All @@ -102,7 +116,19 @@ def tokenize(self, text: str) -> Encoding:
return self.get_tokenizer().encode(text)


@lru_cache(maxsize=1)
def limited_concurrency_client_from_token() -> LimitedConcurrencyClient:
    """Build (once per process) the default client from the environment token.

    The `lru_cache` on a zero-argument function makes this a lazy singleton:
    the token lookup and client construction happen on first call only, and
    every subsequent call returns the same cached instance.
    """
    default_client = LimitedConcurrencyClient.from_token()
    return default_client


class LuminousControlModel(AlephAlphaModel):
"""An Aleph Alpha control model of the second generation.
Args:
name: The name of a valid second-generation control model.
client: Aleph Alpha client instance for running model related API calls.
"""

INSTRUCTION_PROMPT_TEMPLATE = PromptTemplate(
"""{% promptrange instruction %}{{instruction}}{% endpromptrange %}
{% if input %}
Expand All @@ -113,17 +139,17 @@ class LuminousControlModel(AlephAlphaModel):

def __init__(
self,
model: Literal[
name: Literal[
"luminous-base-control",
"luminous-extended-control",
"luminous-supreme-control",
"luminous-base-control-20240215",
"luminous-extended-control-20240215",
"luminous-supreme-control-20240215",
],
client: AlephAlphaClientProtocol = LimitedConcurrencyClient.from_token(),
client: AlephAlphaClientProtocol = limited_concurrency_client_from_token(),
) -> None:
super().__init__(model, client)
super().__init__(name, client)

def to_instruct_prompt(
self,
Expand Down
2 changes: 1 addition & 1 deletion tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

@fixture
def model(client: AlephAlphaClientProtocol) -> AlephAlphaModel:
return LuminousControlModel(client=client, model="luminous-base-control-20240215")
return LuminousControlModel(client=client, name="luminous-base-control-20240215")


def test_model_without_input(model: AlephAlphaModel, no_op_tracer: NoOpTracer) -> None:
Expand Down
10 changes: 4 additions & 6 deletions tests/core/test_text_highlight.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from aleph_alpha_client import Image
from pytest import fixture, raises

from intelligence_layer.core.model import AlephAlphaModel
from intelligence_layer.core.prompt_template import PromptTemplate, RichPrompt
from intelligence_layer.core.text_highlight import TextHighlight, TextHighlightInput
from intelligence_layer.core.tracer import NoOpTracer
from intelligence_layer.connectors import AlephAlphaClientProtocol
from intelligence_layer.core import AlephAlphaModel, PromptTemplate, RichPrompt, TextHighlight, TextHighlightInput, NoOpTracer


class AlephAlphaVanillaModel(AlephAlphaModel):
Expand All @@ -18,8 +16,8 @@ def to_instruct_prompt(


@fixture
def aleph_alpha_vanilla_model() -> AlephAlphaVanillaModel:
return AlephAlphaVanillaModel("luminous-base")
def aleph_alpha_vanilla_model(client: AlephAlphaClientProtocol) -> AlephAlphaVanillaModel:
return AlephAlphaVanillaModel("luminous-base", client)


@fixture
Expand Down
2 changes: 1 addition & 1 deletion tests/use_cases/summarize/test_recursive_summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_recursive_summarize_stops_after_one_chunk(
recursive_counting_client: RecursiveCountingClient,
) -> None:
model = LuminousControlModel(
model="luminous-base-control-20240215", client=recursive_counting_client
name="luminous-base-control-20240215", client=recursive_counting_client
)

long_context_high_compression_summarize = SteerableLongContextSummarize(
Expand Down

0 comments on commit 74d4543

Please sign in to comment.