[easy][extensions][py][hf] 3/n Local Inference model parser changes (#597)

## What


This diff is a copy-paste of the diff below, with one difference:
1. The ModelParser `id` and class name

The HuggingFace extension contains a `local_inference` and a
`remote_inference_client` package.

`remote_inference_client` and the core HuggingFace model parser in the
SDK share the same implementation; the only difference is the model parser
id. The extension's model parser id is `HuggingFaceTextGenerationClient`.

Note: both of these model parsers support text generation only. Other
HuggingFace tasks are not yet supported by aiconfig on the
inference endpoint.
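
For context, a minimal usage sketch of registering one of these parsers. The import path and config filename are assumptions; the constructor arguments and the `config.register_model_parser(parser)` call follow the class docstring in this diff:

```python
# Sketch only — the module path and config file are assumptions, not part of this commit.
from aiconfig import AIConfigRuntime

# Hypothetical import path for the extension's local-inference text generation parser.
from aiconfig_extension_hugging_face.local_inference.text_generation import (
    HuggingFaceTextGenerationParser,
)

# Create a parser for a specific text-generation model (per the class docstring).
parser = HuggingFaceTextGenerationParser(
    "mistralai/Mistral-7B-Instruct-v0.1", use_api_token=False
)

# Load an aiconfig and add the parser to its registry.
config = AIConfigRuntime.load("my_aiconfig.json")  # hypothetical config file
config.register_model_parser(parser)
```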

## Why

Some context on why there are two extensions: one was intended to serve
as a guide and an easy-to-follow example of writing a model parser. That
guide has yet to be written.

## Test plan

<img width="1396" alt="M"
src="https://github.com/lastmile-ai/aiconfig/assets/141073967/5638c958-aa3f-4ad4-af1c-d4e733ec5722">

---
Stack created with [Sapling](https://sapling-scm.com). Best reviewed
with
[ReviewStack](https://reviewstack.dev/lastmile-ai/aiconfig/pull/597).
* __->__ #597
* #588
* #587
Ankush-lastmile authored Dec 23, 2023
2 parents 307f791 + 108c62a commit 26c6443
Showing 2 changed files with 117 additions and 46 deletions.
@@ -1,25 +1,24 @@
import copy
from typing import TYPE_CHECKING, Any, Dict, Optional

from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
from aiconfig.model_parser import InferenceOptions
from aiconfig.util.config_utils import get_api_key_from_environment

# HuggingFace API imports
from huggingface_hub import InferenceClient
from huggingface_hub.inference._text_generation import (
TextGenerationResponse,
TextGenerationStreamResponse,
)

from aiconfig.schema import ExecuteResult, Output, Prompt, PromptMetadata
from aiconfig import CallbackEvent

from aiconfig.util.params import resolve_prompt

# ModelParser Utils
# Type hint imports
from aiconfig import (
ExecuteResult,
InferenceOptions,
Output,
ParameterizedModelParser,
Prompt,
PromptMetadata,
get_api_key_from_environment,
resolve_prompt,
)

# Circular Dependency Type Hints
if TYPE_CHECKING:
@@ -29,7 +28,7 @@
# Step 1: define Helpers


def refine_chat_completion_params(model_settings):
def refine_chat_completion_params(model_settings: dict[Any, Any]) -> dict[str, Any]:
"""
Refines the completion params for the HF text generation api. Removes any unsupported params.
The supported keys were found by looking at the HF text generation api. `huggingface_hub.InferenceClient.text_generation()`
@@ -93,7 +92,8 @@ def construct_stream_output(

index = 0 # HF Text Generation api doesn't support multiple outputs
delta = data
options.stream_callback(delta, accumulated_message, index)
if options and options.stream_callback:
options.stream_callback(delta, accumulated_message, index)

output = ExecuteResult(
**{
@@ -107,11 +107,10 @@ def construct_stream_output(
return output


def construct_regular_output(response, response_includes_details: bool) -> Output:
def construct_regular_output(response: TextGenerationResponse, response_includes_details: bool) -> Output:
metadata = {}
data = response
if response_includes_details:
response: TextGenerationResponse # Expect response to be of type TextGenerationResponse
data = response.generated_text
metadata = {"details": response.details}

@@ -126,9 +125,9 @@ def construct_regular_output(response, response_includes_details: bool) -> Output:
return output


class HuggingFaceTextGenerationClient(ParameterizedModelParser):
class HuggingFaceTextGenerationParser(ParameterizedModelParser):
"""
A model parser for HuggingFace text generation models using inference client API.
A model parser for HuggingFace text generation models.
"""

def __init__(self, model_id: str = None, use_api_token=False):
@@ -138,12 +137,12 @@ def __init__(self, model_id: str = None, use_api_token=False):
use_api_token (bool): Whether or not to use an API token. Set to False if you don't have an API key.
Returns:
HuggingFaceTextGenerationClient: The HuggingFaceTextGenerationClient object.
HuggingFaceTextParser: The HuggingFaceTextParser object.
Usage:
1. Create a new model parser object with the model ID of the model to use.
parser = HuggingFaceTextGenerationClient("mistralai/Mistral-7B-Instruct-v0.1", use_api_token=False)
parser = HuggingFaceTextParser("mistralai/Mistral-7B-Instruct-v0.1", use_api_token=False)
2. Add the model parser to the registry.
config.register_model_parser(parser)
@@ -166,14 +165,14 @@ def id(self) -> str:
"""
return "HuggingFaceTextGenerationClient"

def serialize(
async def serialize(
self,
prompt_name: str,
data: Any,
ai_config: "AIConfigRuntime",
parameters: Optional[Dict] = None,
**kwargs
) -> Prompt:
parameters: Optional[dict[Any, Any]] = None,
**kwargs,
) -> list[Prompt]:
"""
Defines how a prompt and model inference settings get serialized in the .aiconfig.
@@ -184,6 +183,19 @@ def serialize(
Returns:
str: Serialized representation of the prompt and inference settings.
"""
await ai_config.callback_manager.run_callbacks(
CallbackEvent(
"on_serialize_start",
__name__,
{
"prompt_name": prompt_name,
"data": data,
"parameters": parameters,
"kwargs": kwargs,
},
)
)

data = copy.deepcopy(data)

# assume data is completion params for HF text generation
@@ -192,23 +204,29 @@ def serialize(
# Prompt is handled, remove from data
data.pop("prompt", None)

model_metadata = ai_config.generate_model_metadata(data, self.id())
prompts = []

model_metadata = ai_config.get_model_metadata(data, self.id())
prompt = Prompt(
name=prompt_name,
input=prompt_input,
metadata=PromptMetadata(
model=model_metadata, parameters=parameters, **kwargs
),
)
return [prompt]

prompts.append(prompt)

await ai_config.callback_manager.run_callbacks(CallbackEvent("on_serialize_complete", __name__, {"result": prompts }))

return prompts

async def deserialize(
self,
prompt: Prompt,
aiconfig: "AIConfigRuntime",
options,
params: Optional[Dict] = {},
) -> Dict:
params: Optional[dict[Any, Any]] = {},
) -> dict[Any, Any]:
"""
Defines how to parse a prompt in the .aiconfig for a particular model
and constructs the completion params for that model.
@@ -219,6 +237,8 @@ async def deserialize(
Returns:
dict: Model-specific completion parameters.
"""
await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_start", __name__, {"prompt": prompt, "params": params}))

resolved_prompt = resolve_prompt(prompt, params, aiconfig)

# Build Completion data
@@ -228,10 +248,12 @@ async def deserialize(

completion_data["prompt"] = resolved_prompt

await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_data}))

return completion_data

async def run_inference(
self, prompt: Prompt, aiconfig, options, parameters
self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: dict[Any, Any]
) -> Output:
"""
Invoked to run a prompt in the .aiconfig. This method should perform
@@ -244,12 +266,24 @@ async def run_inference(
Returns:
InferenceResponse: The response from the model.
"""
completion_data = await self.deserialize(prompt, aiconfig, options, parameters)
await aiconfig.callback_manager.run_callbacks(
CallbackEvent(
"on_run_start",
__name__,
{"prompt": prompt, "options": options, "parameters": parameters},
)
)

completion_data = await self.deserialize(prompt, aiconfig, parameters)

# if stream enabled in runtime options and config, then stream. Otherwise don't stream.
stream = (options.stream if options else False) and (
not "stream" in completion_data or completion_data.get("stream") != False
)
stream = True # Default value
if options is not None and options.stream is not None:
stream = options.stream
elif "stream" in completion_data:
stream = completion_data["stream"]

completion_data["stream"] = stream

response = self.client.text_generation(**completion_data)
response_is_detailed = completion_data.get("details", False)
@@ -266,7 +300,10 @@ async def run_inference(
outputs.append(output)

prompt.outputs = outputs
return prompt.outputs

await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_run_complete", __name__, {"result": outputs}))

return outputs

def get_output_text(
self,
60 changes: 47 additions & 13 deletions python/src/aiconfig/default_parsers/hf.py
@@ -12,7 +8,7 @@
TextGenerationStreamResponse,
)

from aiconfig.schema import ExecuteResult, Output, Prompt
from aiconfig.schema import ExecuteResult, Output, Prompt, PromptMetadata
from aiconfig import CallbackEvent

from aiconfig.util.params import resolve_prompt

@@ -27,7 +28,7 @@
# Step 1: define Helpers


def refine_chat_completion_params(model_settings):
def refine_chat_completion_params(model_settings: dict[Any, Any]) -> dict[str, Any]:
"""
Refines the completion params for the HF text generation api. Removes any unsupported params.
The supported keys were found by looking at the HF text generation api. `huggingface_hub.InferenceClient.text_generation()`
@@ -106,11 +107,10 @@ def construct_stream_output(
return output


def construct_regular_output(response, response_includes_details: bool) -> Output:
def construct_regular_output(response: TextGenerationResponse, response_includes_details: bool) -> Output:
metadata = {}
data = response
if response_includes_details:
response: TextGenerationResponse # Expect response to be of type TextGenerationResponse
data = response.generated_text
metadata = {"details": response.details}

@@ -170,9 +170,9 @@ async def serialize(
prompt_name: str,
data: Any,
ai_config: "AIConfigRuntime",
parameters: Optional[Dict] = None,
parameters: Optional[dict[Any, Any]] = None,
**kwargs,
) -> Prompt:
) -> list[Prompt]:
"""
Defines how a prompt and model inference settings get serialized in the .aiconfig.
@@ -183,6 +183,19 @@ async def serialize(
Returns:
str: Serialized representation of the prompt and inference settings.
"""
await ai_config.callback_manager.run_callbacks(
CallbackEvent(
"on_serialize_start",
__name__,
{
"prompt_name": prompt_name,
"data": data,
"parameters": parameters,
"kwargs": kwargs,
},
)
)

data = copy.deepcopy(data)

# assume data is completion params for HF text generation
@@ -191,6 +204,8 @@ def serialize(
# Prompt is handled, remove from data
data.pop("prompt", None)

prompts = []

model_metadata = ai_config.get_model_metadata(data, self.id())
prompt = Prompt(
name=prompt_name,
@@ -199,15 +214,19 @@ async def serialize(
model=model_metadata, parameters=parameters, **kwargs
),
)
return [prompt]

prompts.append(prompt)

await ai_config.callback_manager.run_callbacks(CallbackEvent("on_serialize_complete", __name__, {"result": prompts }))

return prompts

async def deserialize(
self,
prompt: Prompt,
aiconfig: "AIConfigRuntime",
options,
params: Optional[Dict] = {},
) -> Dict:
params: Optional[dict[Any, Any]] = {},
) -> dict[Any, Any]:
"""
Defines how to parse a prompt in the .aiconfig for a particular model
and constructs the completion params for that model.
@@ -218,6 +237,8 @@ async def deserialize(
Returns:
dict: Model-specific completion parameters.
"""
await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_start", __name__, {"prompt": prompt, "params": params}))

resolved_prompt = resolve_prompt(prompt, params, aiconfig)

# Build Completion data
@@ -227,10 +248,12 @@ async def deserialize(

completion_data["prompt"] = resolved_prompt

await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_data}))

return completion_data

async def run_inference(
self, prompt: Prompt, aiconfig, options, parameters
self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: dict[Any, Any]
) -> Output:
"""
Invoked to run a prompt in the .aiconfig. This method should perform
@@ -243,7 +266,15 @@ async def run_inference(
Returns:
InferenceResponse: The response from the model.
"""
completion_data = await self.deserialize(prompt, aiconfig, options, parameters)
await aiconfig.callback_manager.run_callbacks(
CallbackEvent(
"on_run_start",
__name__,
{"prompt": prompt, "options": options, "parameters": parameters},
)
)

completion_data = await self.deserialize(prompt, aiconfig, parameters)

# if stream enabled in runtime options and config, then stream. Otherwise don't stream.
stream = True # Default value
@@ -269,7 +300,10 @@ async def run_inference(
outputs.append(output)

prompt.outputs = outputs
return prompt.outputs

await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_run_complete", __name__, {"result": outputs}))

return outputs

def get_output_text(
self,

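With these changes, `run_inference` defaults to streaming unless the runtime options or the prompt's completion params turn it off, and callback events now fire around serialize, deserialize, and run. A rough driver sketch follows; the config file and prompt name are hypothetical, and the `InferenceOptions` fields mirror the `stream_callback(delta, accumulated_message, index)` usage shown in `construct_stream_output` above:

```python
# Sketch only — file and prompt names are placeholders, not part of this commit.
import asyncio

from aiconfig import AIConfigRuntime, InferenceOptions


def print_delta(delta, accumulated_message, index):
    # Mirrors the stream_callback signature used by construct_stream_output.
    print(delta, end="", flush=True)


async def main() -> None:
    config = AIConfigRuntime.load("my_aiconfig.json")  # hypothetical config file

    # stream=True is now also the default when neither the options nor the
    # prompt's completion params specify "stream".
    options = InferenceOptions(stream=True, stream_callback=print_delta)
    outputs = await config.run("my_prompt", options=options)  # hypothetical prompt name
    print(outputs)


asyncio.run(main())
```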