WIP, fix almost all the linting
NickyHavoc committed Feb 21, 2024
1 parent 24887f1 commit f55b162
Showing 23 changed files with 192 additions and 276 deletions.
6 changes: 4 additions & 2 deletions run.py
@@ -8,6 +8,7 @@
AlephAlphaClientProtocol,
LimitedConcurrencyClient,
)
from intelligence_layer.core.model import AlephAlphaModel
from intelligence_layer.core.tracer import NoOpTracer
from intelligence_layer.use_cases.classify.classify import (
ClassifyInput,
@@ -30,8 +31,9 @@ def client() -> AlephAlphaClientProtocol:

@app.post("/classify")
async def classify(
classify_input: ClassifyInput, client: AlephAlphaClientProtocol = Depends(client)
classify_input: ClassifyInput,
luminous_control_model: AlephAlphaModel = Depends(client),
) -> SingleLabelClassifyOutput:
classify = PromptBasedClassify(client)
classify = PromptBasedClassify(luminous_control_model)
classify_output = classify.run(classify_input, NoOpTracer())
return classify_output
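
With this change the /classify route resolves an AlephAlphaModel rather than a raw client. As a rough illustration, a request against the updated endpoint might look like the sketch below; the host, port, and exact JSON field names are assumptions inferred from ClassifyInput, not something this commit defines.

# Sketch only: exercising the updated /classify route over HTTP.
# Assumes the requests package, a locally running app on port 8000, and that
# ClassifyInput serializes to {"chunk": ..., "labels": [...]}.
import requests

response = requests.post(
    "http://localhost:8000/classify",
    json={"chunk": "I love this product!", "labels": ["positive", "negative"]},
)
print(response.json())
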
8 changes: 5 additions & 3 deletions src/intelligence_layer/core/__init__.py
@@ -7,10 +7,8 @@
from .chunk import Chunk as Chunk
from .chunk import ChunkInput as ChunkInput
from .chunk import ChunkOutput as ChunkOutput
from .chunk import ChunkOverlapTask as ChunkOverlapTask
from .chunk import ChunkTask as ChunkTask
from .complete import Complete as Complete
from .complete import CompleteInput as CompleteInput
from .complete import CompleteOutput as CompleteOutput
from .detect_language import DetectLanguage as DetectLanguage
from .detect_language import DetectLanguageInput as DetectLanguageInput
from .detect_language import DetectLanguageOutput as DetectLanguageOutput
@@ -26,6 +24,10 @@
)
from .intelligence_app import AuthService as AuthService
from .intelligence_app import IntelligenceApp as IntelligenceApp
from .model import AlephAlphaModel as AlephAlphaModel
from .model import CompleteInput as CompleteInput
from .model import CompleteOutput as CompleteOutput
from .model import LuminousControlModel as LuminousControlModel
from .prompt_template import Cursor as Cursor
from .prompt_template import PromptItemCursor as PromptItemCursor
from .prompt_template import PromptRange as PromptRange
66 changes: 0 additions & 66 deletions src/intelligence_layer/core/complete.py

This file was deleted.

49 changes: 17 additions & 32 deletions src/intelligence_layer/core/echo.py
@@ -8,6 +8,7 @@
from intelligence_layer.connectors.limited_concurrency_client import (
AlephAlphaClientProtocol,
)
from intelligence_layer.core.model import AlephAlphaModel, CompleteInput
from intelligence_layer.core.prompt_template import PromptTemplate
from intelligence_layer.core.task import Task, Token
from intelligence_layer.core.tracer import TaskSpan
@@ -27,12 +28,10 @@ class EchoInput(BaseModel):
prompt: The input text that serves as the starting point for the LLM.
expected_completion: The desired completion based on the prompt.
The likelihood of the tokens in this will be examined.
model: A valid Aleph Alpha model name.
"""

prompt: Prompt
expected_completion: str
model: str


class EchoOutput(BaseModel):
@@ -75,17 +74,14 @@ class EchoTask(Task[EchoInput, EchoOutput]):

PROMPT_TEMPLATE_STR: str = "{{prompt}}{{expected_completion}}"

def __init__(self, client: AlephAlphaClientProtocol) -> None:
def __init__(self, model: AlephAlphaModel) -> None:
super().__init__()
self._client = client
self._completion = Complete(client=client)
self._model = model

def do_run(self, input: EchoInput, task_span: TaskSpan) -> EchoOutput:
# We tokenize the prompt separately so we don't have an overlap in the tokens.
# If we don't do this, the end of the prompt and expected completion can be merged into unexpected tokens.
expected_completion_tokens = self._tokenize(
input.expected_completion, input.model
)
expected_completion_tokens = self._tokenize(input.expected_completion)
prompt_template = PromptTemplate(self.PROMPT_TEMPLATE_STR)
prompt = prompt_template.to_rich_prompt(
prompt=prompt_template.embed_prompt(input.prompt),
@@ -95,13 +91,18 @@ def do_run(self, input: EchoInput, task_span: TaskSpan) -> EchoOutput:
)
),
)
completion_input = CompleteInput(
request=self._completion_request(prompt=prompt),
model=input.model,
output = self._model.complete(
CompleteInput(
prompt=prompt,
maximum_tokens=0,
log_probs=0,
tokens=True,
echo=True,
),
task_span,
)
output = self._completion.run(completion_input, task_span)
assert output.response.completions[0].log_probs
log_prob_dicts = output.response.completions[0].log_probs[
assert output.completions[0].log_probs
log_prob_dicts = output.completions[0].log_probs[
-len(expected_completion_tokens) :
]
tokens_with_prob = []
@@ -117,22 +118,10 @@ def do_run(self, input: EchoInput, task_span: TaskSpan) -> EchoOutput:
)
return EchoOutput(tokens_with_log_probs=tokens_with_prob)

def _completion_request(
self,
prompt: Prompt,
) -> CompletionRequest:
return CompletionRequest(
prompt=prompt,
maximum_tokens=0,
log_probs=0,
tokens=True,
echo=True,
)

def _tokenize(self, text: str, model: str) -> Sequence[Token]:
def _tokenize(self, text: str) -> Sequence[Token]:
# Turns the expected output into list of token ids. Important so that we know how many tokens
# the label is and can retrieve the last N log probs for the label
tokenizer = self.tokenizer(model)
tokenizer = self._model.get_tokenizer()
if tokenizer.pre_tokenizer:
tokenizer.pre_tokenizer.add_prefix_space = False
encoding: Encoding = tokenizer.encode(text)
@@ -143,7 +132,3 @@ def _tokenize(self, text: str, model: str) -> Sequence[Token]:
)
for token_id in encoding.ids
]

@lru_cache
def tokenizer(self, model: str) -> Tokenizer:
return self._client.tokenizer(model)
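
For orientation, here is a minimal sketch of the model-centric completion call that EchoTask now relies on. The model name, the tracer passed in place of a task span, and the attributes read off the output are assumptions and are not part of this diff.

# Minimal sketch of AlephAlphaModel.complete(...) usage, assuming AA_TOKEN is
# available to the model's default client. Not taken from this commit.
from intelligence_layer.core.model import CompleteInput, LuminousControlModel
from intelligence_layer.core.tracer import NoOpTracer

model = LuminousControlModel("luminous-base-control-20240215")
prompt = model.to_instruct_prompt(instruction="Reply with a single word: hello.")
# Passing a NoOpTracer here is an assumption; EchoTask above passes its task_span.
output = model.complete(CompleteInput(prompt=prompt, maximum_tokens=8), NoOpTracer())
print(output.completions[0].completion)
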
30 changes: 30 additions & 0 deletions src/intelligence_layer/core/instruct.py
@@ -0,0 +1,30 @@
from typing import Optional
from pydantic import BaseModel
from intelligence_layer.core.model import AlephAlphaModel, CompleteInput, CompleteOutput

from intelligence_layer.core.task import Task
from intelligence_layer.core.tracer import TaskSpan


class InstructInput(BaseModel):
instruction: str
input: Optional[str] = None
response_prefix: Optional[str] = None
maximum_tokens: int = 128


class Instruct(Task[InstructInput, CompleteOutput]):
def __init__(self, model: AlephAlphaModel) -> None:
super().__init__()
self._model = model

def do_run(self, input: InstructInput, task_span: TaskSpan) -> CompleteOutput:
prompt = self._model.to_instruct_prompt(
instruction=input.instruction,
input=input.input,
response_prefix=input.response_prefix
)
return self._model.complete(CompleteInput(
prompt=prompt,
maximum_tokens=input.maximum_tokens
), task_span)
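
A possible way to call the new Instruct task, purely as a sketch; the model name, the example instruction, and the attribute read from the output are assumptions that this commit does not show.

# Hypothetical usage of the new Instruct task; assumes AA_TOKEN is available
# to the model's default client.
from intelligence_layer.core.instruct import Instruct, InstructInput
from intelligence_layer.core.model import LuminousControlModel
from intelligence_layer.core.tracer import NoOpTracer

task = Instruct(LuminousControlModel("luminous-base-control-20240215"))
output = task.run(
    InstructInput(
        instruction="Summarize the text in one sentence.",
        input="The quick brown fox jumps over the lazy dog.",
    ),
    NoOpTracer(),
)
print(output.completions[0].completion)
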
1 change: 1 addition & 0 deletions src/intelligence_layer/core/model.py
@@ -60,6 +60,7 @@ def __init__(self, client: AlephAlphaClientProtocol, model: str) -> None:
self._model = model

def do_run(self, input: CompleteInput, task_span: TaskSpan) -> CompleteOutput:
task_span.log("Model", self._model)
return CompleteOutput.from_completion_response(
self._client.complete(
request=input,
@@ -10,6 +10,8 @@
Question,
RecordData,
)
from intelligence_layer.core.instruct import InstructInput
from intelligence_layer.core.model import CompleteOutput
from intelligence_layer.evaluation import (
ArgillaEvaluationRepository,
DatasetRepository,
@@ -39,7 +41,7 @@ class AggregatedInstructComparison(BaseModel):
class InstructComparisonArgillaEvaluator(
ArgillaEvaluator[
InstructInput,
PromptOutput,
CompleteOutput,
None,
AggregatedInstructComparison,
]
@@ -91,11 +93,11 @@ def __init__(
def _to_record(
self,
example: Example[InstructInput, None],
*example_outputs: SuccessfulExampleOutput[PromptOutput],
*example_outputs: SuccessfulExampleOutput[CompleteOutput],
) -> Sequence[RecordData]:
def create_record_data(
first: SuccessfulExampleOutput[PromptOutput],
second: SuccessfulExampleOutput[PromptOutput],
first: SuccessfulExampleOutput[CompleteOutput],
second: SuccessfulExampleOutput[CompleteOutput],
) -> RecordData:
if random.choice([True, False]):
first, second = second, first
26 changes: 9 additions & 17 deletions src/intelligence_layer/use_cases/classify/prompt_based_classify.py
@@ -9,6 +9,7 @@
AlephAlphaClientProtocol,
)
from intelligence_layer.core.echo import EchoInput, EchoTask, TokenWithLogProb
from intelligence_layer.core.model import AlephAlphaModel, LuminousControlModel
from intelligence_layer.core.prompt_template import RichPrompt
from intelligence_layer.core.task import Task, Token
from intelligence_layer.core.tracer import TaskSpan
@@ -64,30 +65,23 @@ class PromptBasedClassify(Task[ClassifyInput, SingleLabelClassifyOutput]):
>>> output = task.run(input, tracer)
"""

PROMPT_TEMPLATE: str = """### Instruction:
Identify a class that describes the text adequately.
Reply with only the class label.
### Input:
{{text}}
### Response:"""
INSTRUCTION: str = """Identify a class that describes the text adequately.
Reply with only the class label."""

def __init__(
self, client: AlephAlphaClientProtocol, model: str = "luminous-base-control"
self,
model: AlephAlphaModel = LuminousControlModel("luminous-base-control-20240215"),
) -> None:
super().__init__()
self._client = client
self._echo_task = EchoTask(client)
self.model = model
self._echo_task = EchoTask(model)
self._model = model

def do_run(
self, input: ClassifyInput, task_span: TaskSpan
) -> SingleLabelClassifyOutput:
log_probs_per_label = self._log_probs_per_label(
text_to_classify=input.chunk,
labels=input.labels,
model=self.model,
task_span=task_span,
)
task_span.log("Log probs per label", log_probs_per_label)
@@ -101,17 +95,15 @@ def _log_probs_per_label(
self,
text_to_classify: str,
labels: frozenset[str],
model: str,
task_span: TaskSpan,
) -> Mapping[str, Sequence[TokenWithLogProb]]:
prompt = PromptTemplate(template_str=self.PROMPT_TEMPLATE).to_prompt(
text=text_to_classify
prompt = self._model.to_instruct_prompt(
instruction=self.INSTRUCTION, input=text_to_classify
)
inputs = (
EchoInput(
prompt=prompt,
expected_completion=self._prepare_label_for_echo_task(label),
model=model,
)
for label in labels
)
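
With the defaulted model argument, the classifier can now be constructed without a client. A minimal sketch of standalone usage follows, assuming AA_TOKEN is available to the default model's client and that SingleLabelClassifyOutput exposes per-label scores (the scores attribute is not shown in this diff).

# Sketch of the new client-free construction introduced above.
from intelligence_layer.core.tracer import NoOpTracer
from intelligence_layer.use_cases.classify.classify import ClassifyInput
from intelligence_layer.use_cases.classify.prompt_based_classify import PromptBasedClassify

task = PromptBasedClassify()  # defaults to LuminousControlModel("luminous-base-control-20240215")
output = task.run(
    ClassifyInput(
        chunk="The battery died after two days.",
        labels=frozenset({"praise", "complaint"}),
    ),
    NoOpTracer(),
)
print(output.scores)  # attribute name is an assumption
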
16 changes: 6 additions & 10 deletions src/intelligence_layer/use_cases/intelligence_starter_app.py
@@ -9,6 +9,7 @@
LimitedConcurrencyClient,
)
from intelligence_layer.core import IntelligenceApp
from intelligence_layer.core.model import LuminousControlModel
from intelligence_layer.use_cases.classify.classify import ClassifyInput
from intelligence_layer.use_cases.classify.prompt_based_classify import (
PromptBasedClassify,
@@ -24,26 +25,21 @@


class IntelligenceStarterApp(IntelligenceApp):
def __init__(self, fast_api_app: FastAPI, client: AlephAlphaClientProtocol) -> None:
def __init__(self, fast_api_app: FastAPI) -> None:
super().__init__(fast_api_app)
prompt_based_classify = PromptBasedClassify(client)
prompt_based_classify = PromptBasedClassify()
self.register_task(prompt_based_classify, ClassifyInput, "/classify")
long_chunk_qa = LongContextQa(client)
long_chunk_qa = LongContextQa()
self.register_task(long_chunk_qa, LongContextQaInput, "/qa")
summarize = SteerableLongContextSummarize(
client, max_generated_tokens=128, max_tokens_per_chunk=512
max_generated_tokens=512, max_tokens_per_chunk=1024
)
self.register_task(summarize, LongContextSummarizeInput, "/summarize")


def main() -> None:
load_dotenv()
aa_token = os.getenv("AA_TOKEN")
assert aa_token
aa_client = Client(aa_token)
client = LimitedConcurrencyClient(aa_client)
fast_api = FastAPI()
app = IntelligenceStarterApp(fast_api, client)
app = IntelligenceStarterApp(fast_api)
app.serve()


2 changes: 1 addition & 1 deletion src/intelligence_layer/use_cases/qa/long_context_qa.py
@@ -72,7 +72,7 @@ class LongContextQa(Task[LongContextQaInput, MultipleChunkQaOutput]):

def __init__(
self,
max_tokens_per_chunk: int = 512,
max_tokens_per_chunk: int = 1024,
k: int = 4,
model: AlephAlphaModel = LuminousControlModel(
"luminous-supreme-control-20240215"