diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/remote_inference_client/text_generation.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/remote_inference_client/text_generation.py index e1d3cbb49..691c0c454 100644 --- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/remote_inference_client/text_generation.py +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/remote_inference_client/text_generation.py @@ -1,6 +1,10 @@ import copy from typing import TYPE_CHECKING, Any, Dict, Optional +from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser +from aiconfig.model_parser import InferenceOptions +from aiconfig.util.config_utils import get_api_key_from_environment + # HuggingFace API imports from huggingface_hub import InferenceClient from huggingface_hub.inference._text_generation import ( @@ -8,18 +12,13 @@ TextGenerationStreamResponse, ) +from aiconfig.schema import ExecuteResult, Output, Prompt, PromptMetadata +from aiconfig import CallbackEvent + +from aiconfig.util.params import resolve_prompt + # ModelParser Utils # Type hint imports -from aiconfig import ( - ExecuteResult, - InferenceOptions, - Output, - ParameterizedModelParser, - Prompt, - PromptMetadata, - get_api_key_from_environment, - resolve_prompt, -) # Circuluar Dependency Type Hints if TYPE_CHECKING: @@ -29,7 +28,7 @@ # Step 1: define Helpers -def refine_chat_completion_params(model_settings): +def refine_chat_completion_params(model_settings: dict[Any, Any]) -> dict[str, Any]: """ Refines the completion params for the HF text generation api. Removes any unsupported params. The supported keys were found by looking at the HF text generation api. `huggingface_hub.InferenceClient.text_generation()` @@ -93,7 +92,8 @@ def construct_stream_output( index = 0 # HF Text Generation api doesn't support multiple outputs delta = data - options.stream_callback(delta, accumulated_message, index) + if options and options.stream_callback: + options.stream_callback(delta, accumulated_message, index) output = ExecuteResult( **{ @@ -107,11 +107,10 @@ def construct_stream_output( return output -def construct_regular_output(response, response_includes_details: bool) -> Output: +def construct_regular_output(response: TextGenerationResponse, response_includes_details: bool) -> Output: metadata = {} data = response if response_includes_details: - response: TextGenerationResponse # Expect response to be of type TextGenerationResponse data = response.generated_text metadata = {"details": response.details} @@ -126,9 +125,9 @@ def construct_regular_output(response, response_includes_details: bool) -> Outpu return output -class HuggingFaceTextGenerationClient(ParameterizedModelParser): +class HuggingFaceTextGenerationParser(ParameterizedModelParser): """ - A model parser for HuggingFace text generation models using inference client API. + A model parser for HuggingFace text generation models. """ def __init__(self, model_id: str = None, use_api_token=False): @@ -138,12 +137,12 @@ def __init__(self, model_id: str = None, use_api_token=False): no_token (bool): Whether or not to require an API token. Set to False if you don't have an api key. Returns: - HuggingFaceTextGenerationClient: The HuggingFaceTextGenerationClient object. + HuggingFaceTextParser: The HuggingFaceTextParser object. Usage: 1. Create a new model parser object with the model ID of the model to use. 
- parser = HuggingFaceTextGenerationClient("mistralai/Mistral-7B-Instruct-v0.1", use_api_token=False) + parser = HuggingFaceTextParser("mistralai/Mistral-7B-Instruct-v0.1", use_api_token=False) 2. Add the model parser to the registry. config.register_model_parser(parser) @@ -166,14 +165,14 @@ def id(self) -> str: """ return "HuggingFaceTextGenerationClient" - def serialize( + async def serialize( self, prompt_name: str, data: Any, ai_config: "AIConfigRuntime", - parameters: Optional[Dict] = None, - **kwargs - ) -> Prompt: + parameters: Optional[dict[Any, Any]] = None, + **kwargs, + ) -> list[Prompt]: """ Defines how a prompt and model inference settings get serialized in the .aiconfig. @@ -184,6 +183,19 @@ def serialize( Returns: str: Serialized representation of the prompt and inference settings. """ + await ai_config.callback_manager.run_callbacks( + CallbackEvent( + "on_serialize_start", + __name__, + { + "prompt_name": prompt_name, + "data": data, + "parameters": parameters, + "kwargs": kwargs, + }, + ) + ) + data = copy.deepcopy(data) # assume data is completion params for HF text generation @@ -192,7 +204,9 @@ def serialize( # Prompt is handled, remove from data data.pop("prompt", None) - model_metadata = ai_config.generate_model_metadata(data, self.id()) + prompts = [] + + model_metadata = ai_config.get_model_metadata(data, self.id()) prompt = Prompt( name=prompt_name, input=prompt_input, @@ -200,15 +214,19 @@ def serialize( model=model_metadata, parameters=parameters, **kwargs ), ) - return [prompt] + + prompts.append(prompt) + + await ai_config.callback_manager.run_callbacks(CallbackEvent("on_serialize_complete", __name__, {"result": prompts })) + + return prompts async def deserialize( self, prompt: Prompt, aiconfig: "AIConfigRuntime", - options, - params: Optional[Dict] = {}, - ) -> Dict: + params: Optional[dict[Any, Any]] = {}, + ) -> dict[Any, Any]: """ Defines how to parse a prompt in the .aiconfig for a particular model and constructs the completion params for that model. @@ -219,6 +237,8 @@ async def deserialize( Returns: dict: Model-specific completion parameters. """ + await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_start", __name__, {"prompt": prompt, "params": params})) + resolved_prompt = resolve_prompt(prompt, params, aiconfig) # Build Completion data @@ -228,10 +248,12 @@ async def deserialize( completion_data["prompt"] = resolved_prompt + await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_data})) + return completion_data async def run_inference( - self, prompt: Prompt, aiconfig, options, parameters + self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: dict[Any, Any] ) -> Output: """ Invoked to run a prompt in the .aiconfig. This method should perform @@ -244,12 +266,24 @@ async def run_inference( Returns: InferenceResponse: The response from the model. """ - completion_data = await self.deserialize(prompt, aiconfig, options, parameters) + await aiconfig.callback_manager.run_callbacks( + CallbackEvent( + "on_run_start", + __name__, + {"prompt": prompt, "options": options, "parameters": parameters}, + ) + ) + + completion_data = await self.deserialize(prompt, aiconfig, parameters) # if stream enabled in runtime options and config, then stream. Otherwise don't stream. 
- stream = (options.stream if options else False) and ( - not "stream" in completion_data or completion_data.get("stream") != False - ) + stream = True # Default value + if options is not None and options.stream is not None: + stream = options.stream + elif "stream" in completion_data: + stream = completion_data["stream"] + + completion_data["stream"] = stream response = self.client.text_generation(**completion_data) response_is_detailed = completion_data.get("details", False) @@ -266,7 +300,10 @@ async def run_inference( outputs.append(output) prompt.outputs = outputs - return prompt.outputs + + await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_run_complete", __name__, {"result": outputs})) + + return outputs def get_output_text( self, diff --git a/python/src/aiconfig/default_parsers/hf.py b/python/src/aiconfig/default_parsers/hf.py index 5154e0993..af224c69e 100644 --- a/python/src/aiconfig/default_parsers/hf.py +++ b/python/src/aiconfig/default_parsers/hf.py @@ -12,7 +12,8 @@ TextGenerationStreamResponse, ) -from aiconfig.schema import ExecuteResult, Output, Prompt +from aiconfig.schema import ExecuteResult, Output, Prompt, PromptMetadata +from aiconfig import CallbackEvent from aiconfig.util.params import resolve_prompt @@ -27,7 +28,7 @@ # Step 1: define Helpers -def refine_chat_completion_params(model_settings): +def refine_chat_completion_params(model_settings: dict[Any, Any]) -> dict[str, Any]: """ Refines the completion params for the HF text generation api. Removes any unsupported params. The supported keys were found by looking at the HF text generation api. `huggingface_hub.InferenceClient.text_generation()` @@ -106,11 +107,10 @@ def construct_stream_output( return output -def construct_regular_output(response, response_includes_details: bool) -> Output: +def construct_regular_output(response: TextGenerationResponse, response_includes_details: bool) -> Output: metadata = {} data = response if response_includes_details: - response: TextGenerationResponse # Expect response to be of type TextGenerationResponse data = response.generated_text metadata = {"details": response.details} @@ -170,9 +170,9 @@ async def serialize( prompt_name: str, data: Any, ai_config: "AIConfigRuntime", - parameters: Optional[Dict] = None, + parameters: Optional[dict[Any, Any]] = None, **kwargs, - ) -> Prompt: + ) -> list[Prompt]: """ Defines how a prompt and model inference settings get serialized in the .aiconfig. @@ -183,6 +183,19 @@ async def serialize( Returns: str: Serialized representation of the prompt and inference settings. 
""" + await ai_config.callback_manager.run_callbacks( + CallbackEvent( + "on_serialize_start", + __name__, + { + "prompt_name": prompt_name, + "data": data, + "parameters": parameters, + "kwargs": kwargs, + }, + ) + ) + data = copy.deepcopy(data) # assume data is completion params for HF text generation @@ -191,6 +204,8 @@ async def serialize( # Prompt is handled, remove from data data.pop("prompt", None) + prompts = [] + model_metadata = ai_config.get_model_metadata(data, self.id()) prompt = Prompt( name=prompt_name, @@ -199,15 +214,19 @@ async def serialize( model=model_metadata, parameters=parameters, **kwargs ), ) - return [prompt] + + prompts.append(prompt) + + await ai_config.callback_manager.run_callbacks(CallbackEvent("on_serialize_complete", __name__, {"result": prompts })) + + return prompts async def deserialize( self, prompt: Prompt, aiconfig: "AIConfigRuntime", - options, - params: Optional[Dict] = {}, - ) -> Dict: + params: Optional[dict[Any, Any]] = {}, + ) -> dict[Any, Any]: """ Defines how to parse a prompt in the .aiconfig for a particular model and constructs the completion params for that model. @@ -218,6 +237,8 @@ async def deserialize( Returns: dict: Model-specific completion parameters. """ + await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_start", __name__, {"prompt": prompt, "params": params})) + resolved_prompt = resolve_prompt(prompt, params, aiconfig) # Build Completion data @@ -227,10 +248,12 @@ async def deserialize( completion_data["prompt"] = resolved_prompt + await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_data})) + return completion_data async def run_inference( - self, prompt: Prompt, aiconfig, options, parameters + self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: dict[Any, Any] ) -> Output: """ Invoked to run a prompt in the .aiconfig. This method should perform @@ -243,7 +266,15 @@ async def run_inference( Returns: InferenceResponse: The response from the model. """ - completion_data = await self.deserialize(prompt, aiconfig, options, parameters) + await aiconfig.callback_manager.run_callbacks( + CallbackEvent( + "on_run_start", + __name__, + {"prompt": prompt, "options": options, "parameters": parameters}, + ) + ) + + completion_data = await self.deserialize(prompt, aiconfig, parameters) # if stream enabled in runtime options and config, then stream. Otherwise don't stream. stream = True # Default value @@ -269,7 +300,10 @@ async def run_inference( outputs.append(output) prompt.outputs = outputs - return prompt.outputs + + await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_run_complete", __name__, {"result": outputs})) + + return outputs def get_output_text( self,