rawResponse -> raw_response for snake_case consistency (#657)
# rawResponse -> raw_response for snake_case consistency

Just updating the `rawResponse` property name to `raw_response` for
snake_case consistency in the schema. All other changes here are just
auto-formatting.
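
A minimal sketch of what a parser produces with the renamed key — illustrative only; the `build_output` helper name and the import path are assumptions, not part of this change:

```python
# Sketch: constructing an output whose metadata carries the provider response
# under the renamed "raw_response" key (previously "rawResponse").
from aiconfig.schema import ExecuteResult  # import path assumed


def build_output(generated_text: str, raw_response: dict) -> ExecuteResult:
    return ExecuteResult(
        output_type="execute_result",
        data=generated_text,
        execution_count=0,
        metadata={"raw_response": raw_response},
    )
```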

## Testing
`pytest`
```
=========================================================== 88 passed, 38 warnings in 3.28s ===========================================================
```

`yarn test`
```
(aiconfig) ryanholinshead@Ryans-MBP typescript % yarn test 
yarn run v1.22.19
$ jest --runInBand
 PASS  __tests__/parsers/hf/hf.test.ts
  HuggingFaceTextGeneration ModelParser
    ✓ uses HuggingFace API token from environment variable if it exists (9 ms)
    ✓ serializing params to config prompt (1 ms)
    ✓ serialize callbacks (1 ms)
    ✓ deserializing config prompt to params (1 ms)
    ✓ deserialize callbacks (1 ms)
    ✓ run prompt, non-streaming (4 ms)
    ✓ run prompt, streaming (2 ms)
    ✓ run callbacks (1 ms)

 PASS  __tests__/config.test.ts
  Loading an AIConfig
    ✓ loading a basic chatgpt query config (6 ms)
    ✓ loading a prompt chain (4 ms)
    ✓ deserialize and re-serialize a prompt chain (2 ms)
    ✓ serialize a prompt chain with different settings (1 ms)

 PASS  __tests__/parsers/palm-text/palm.test.ts
  PaLM Text ModelParser
    ✓ serializing params to config prompt (4 ms)
    ✓ deserializing params to config (2 ms)
    ✓ run prompt, non-streaming (87 ms)

 PASS  __tests__/testProgramaticallyCreateConfig.ts
  test Get Global Settings
    ✓ Retrieving global setting from AIConfig with 1 model (1 ms)
  ExtractOverrideSettings function
    ✓ Should return initial settings when no global settings are defined
    ✓ Should return an override when initial settings differ from global settings (1 ms)
    ✓ Should return empty override when global settings match initial settings
    ✓ Should return empty override when Global settings defined and initial settings are empty

 PASS  __tests__/testSave.ts
  AIConfigRuntime save()
    ✓ saves the config and checks if the config json doesn't have key filePath (4 ms)

Test Suites: 5 passed, 5 total
Tests:       21 passed, 21 total
Snapshots:   0 total
Time:        2.303 s
Ran all test suites.
✨  Done in 3.58s.
```

Ran `npx ts-node function-call-stream.ts` and confirmed correct output
metadata with raw_response, e.g.:
```
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "kind": "tool_calls",
            "value": [
              {
                "type": "function",
                "function": {
                  "name": "list",
                  "arguments": "{\n\"genre\": \"historical\"\n}"
                }
              }
            ]
          },
          "execution_count": 0,
          "metadata": {
            "finish_reason": "function_call",
            "raw_response": {
              "role": "assistant",
              "content": null,
              "function_call": {
                "name": "list",
                "arguments": "{\n\"genre\": \"historical\"\n}"
              }
            }
          }
        }
      ]
```
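
For downstream code, the only visible difference is the metadata key used to look up the raw provider response. A minimal sketch of reading it back after a run — the config file, prompt name, and exact `run` call here are assumptions, not part of this change:

```python
import asyncio

from aiconfig import AIConfigRuntime  # top-level export; import path assumed


async def main() -> None:
    # Hypothetical config file and prompt name, for illustration only.
    config = AIConfigRuntime.load("books.aiconfig.json")
    outputs = await config.run("list_books", params={})
    for output in outputs:
        metadata = output.metadata or {}
        # The unmodified provider response now lives under "raw_response"
        # (previously "rawResponse").
        print(metadata.get("raw_response"))


if __name__ == "__main__":
    asyncio.run(main())
```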
rholinshead authored Dec 29, 2023
2 parents 164f0b0 + 32b36ce commit 8137a6c
Showing 11 changed files with 148 additions and 282 deletions.
111 changes: 43 additions & 68 deletions extensions/Gemini/python/src/aiconfig_extension_gemini/Gemini.py
@@ -71,9 +71,7 @@ def construct_regular_outputs(response: "AsyncGenerateContentResponse") -> list[
return output_list


async def construct_stream_outputs(
response: "AsyncGenerateContentResponse", options: InferenceOptions
) -> list[Output]:
async def construct_stream_outputs(response: "AsyncGenerateContentResponse", options: InferenceOptions) -> list[Output]:
"""
Construct Outputs while also streaming the response with stream callback
@@ -132,17 +130,17 @@ async def serialize(
The data passed in is the completion params a user would use to call the Gemini API directly.
If the user wanted to call the Gemini API directly, they might do something like this:
```
model = genai.GenerativeModel('gemini-pro')
completion_params = {"contents": "Hello"}
model.generate_content(**completion_params)
# Note: The above line is the same as doing this:
# Note: The above line is the same as doing this:
model.generate_content(contents="Hello")
```
* Important: The contents field is what contains the input data. In this case, prompt input would be the contents field.
Args:
prompt (str): The prompt to be serialized.
@@ -152,7 +150,7 @@ async def serialize(
str: Serialized representation of the prompt and inference settings.
Sample Usage:
1.
1.
completion_params = {"contents": "Hello"}
serialized_prompts = await ai_config.serialize("prompt", completion_params, "gemini-pro")
@@ -188,7 +186,7 @@ async def serialize(
data = copy.deepcopy(data)
contents = data.pop("contents", None)

model_name = self.model.model_name[len("models/"):]
model_name = self.model.model_name[len("models/") :]
model_metadata = ai_config.get_model_metadata(data, model_name)

prompts = []
@@ -202,7 +200,9 @@ async def serialize(
# }
contents_is_role_dict = isinstance(contents, dict) and "role" in contents and "parts"
# Multi Turn means that the contents is a list of dicts with alternating role and parts. See for more info: https://ai.google.dev/tutorials/python_quickstart#multi-turn_conversations
contents_is_multi_turn = isinstance(contents, list) and all(isinstance(item, dict) and "role" in item and "parts" in item for item in contents)
contents_is_multi_turn = isinstance(contents, list) and all(
isinstance(item, dict) and "role" in item and "parts" in item for item in contents
)

if contents is None:
raise ValueError("No contents found in data. Gemini api request requires a contents field")
@@ -226,14 +226,21 @@ async def serialize(
outputs = [
ExecuteResult(
**{
"output_type": "execute_result",
"output_type": "execute_result",
"data": model_message_parts[0],
"metadata": {"rawResponse": model_message},
"metadata": {"raw_response": model_message},
}
)
]
i += 1
prompt = Prompt(**{"name": f'{prompt_name}_{len(prompts) + 1}', "input": user_message_parts, "metadata": {"model": model_metadata}, "outputs": outputs})
prompt = Prompt(
**{
"name": f"{prompt_name}_{len(prompts) + 1}",
"input": user_message_parts,
"metadata": {"model": model_metadata},
"outputs": outputs,
}
)
prompts.append(prompt)
i += 1
else:
@@ -255,11 +262,7 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params:
Returns:
dict: Model-specific completion parameters.
"""
await aiconfig.callback_manager.run_callbacks(
CallbackEvent(
"on_deserialize_start", __name__, {"prompt": prompt, "params": params}
)
)
await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_start", __name__, {"prompt": prompt, "params": params}))

# Build Completion data
model_settings = self.get_model_settings(prompt, aiconfig)
@@ -270,9 +273,9 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params:
messages = self._construct_chat_history(prompt, aiconfig, params)

resolved_prompt = resolve_prompt(prompt, params, aiconfig)

messages.append({"role": "user", "parts": [{"text": resolved_prompt}]})

completion_data["contents"] = messages
else:
# If contents is already set, do not construct chat history. TODO: @Ankush-lastmile
@@ -287,15 +290,13 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params:
# This is checking attributes and not a dict like object. in schema.py, PromptInput allows arbitrary attributes/data, and gets serialized as an attribute because it is a pydantic type
if not hasattr(prompt_input, "contents"):
# The source code show cases this more than the docs. This curl request docs similar to python sdk: https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#request_body
raise ValueError("Unable to deserialize input. Prompt input type is not a string, Gemini Model Parser expects prompt input to contain a 'contents' field as expected by Gemini API")
raise ValueError(
"Unable to deserialize input. Prompt input type is not a string, Gemini Model Parser expects prompt input to contain a 'contents' field as expected by Gemini API"
)

completion_data['contents'] = parameterize_supported_gemini_input_data(prompt_input.contents, prompt, aiconfig, params)
completion_data["contents"] = parameterize_supported_gemini_input_data(prompt_input.contents, prompt, aiconfig, params)

await aiconfig.callback_manager.run_callbacks(
CallbackEvent(
"on_deserialize_complete", __name__, {"output": completion_data}
)
)
await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_data}))
return completion_data

async def run_inference(
@@ -347,9 +348,7 @@ async def run_inference(
outputs = construct_regular_outputs(response)

prompt.outputs = outputs
await aiconfig.callback_manager.run_callbacks(
CallbackEvent("on_run_complete", __name__, {"result": prompt.outputs})
)
await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_run_complete", __name__, {"result": prompt.outputs}))
return prompt.outputs

def get_output_text(
@@ -375,27 +374,21 @@ def get_output_text(
# We use this method in the model parser functionality itself,
# so we must error to let the user know non-string
# formats are not supported
error_message = \
"""
error_message = """
We currently only support chats where prior messages from a user are only a
single message string instance, please see docstring for more details:
https://github.com/lastmile-ai/aiconfig/blob/v1.1.8/extensions/Gemini/python/src/aiconfig_extension_gemini/Gemini.py#L31-L33
"""
raise ValueError(error_message)

def _construct_chat_history(
self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: Dict
) -> List:
def _construct_chat_history(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: Dict) -> List:
"""
Constructs the chat history for the model
"""
messages = []
# Default to always use chat context
remember_chat_context = not hasattr(
prompt.metadata, "remember_chat_context"
) or (
hasattr(prompt.metadata, "remember_chat_context")
and prompt.metadata.remember_chat_context != False
remember_chat_context = not hasattr(prompt.metadata, "remember_chat_context") or (
hasattr(prompt.metadata, "remember_chat_context") and prompt.metadata.remember_chat_context != False
)
if remember_chat_context:
# handle chat history. check previous prompts for the same model. if same model, add prompt and its output to completion data if it has a completed output
@@ -404,29 +397,21 @@ def _construct_chat_history(
if previous_prompt.name == prompt.name:
break

previous_prompt_is_same_model = aiconfig.get_model_name(
previous_prompt
) == aiconfig.get_model_name(prompt)
previous_prompt_is_same_model = aiconfig.get_model_name(previous_prompt) == aiconfig.get_model_name(prompt)
if previous_prompt_is_same_model:
previous_prompt_template = resolve_prompt(
previous_prompt, params, aiconfig
)
previous_prompt_template = resolve_prompt(previous_prompt, params, aiconfig)
previous_prompt_output = aiconfig.get_latest_output(previous_prompt)
previous_prompt_output_text = self.get_output_text(
previous_prompt, aiconfig, previous_prompt_output
)
previous_prompt_output_text = self.get_output_text(previous_prompt, aiconfig, previous_prompt_output)

messages.append(
{"role": "user", "parts": [{"text": previous_prompt_template}]}
)
messages.append({"role": "user", "parts": [{"text": previous_prompt_template}]})
messages.append(
{
"role": "model",
"parts": [{"text": previous_prompt_output_text}],
}
)

return messages
return messages

def get_prompt_template(self, prompt: Prompt, aiConfig: "AIConfigRuntime") -> str:
"""
@@ -445,26 +430,18 @@ def get_prompt_template(self, prompt: Prompt, aiConfig: "AIConfigRuntime") -> st
elif isinstance(contents, list):
return " ".join(contents)
elif isinstance(contents, dict):
parts= contents["parts"]
parts = contents["parts"]
if isinstance(parts, str):
return parts
elif isinstance(parts, list):
return " ".join(parts)
else:
raise Exception(
f"Cannot get prompt template string from prompt input: {prompt.input}"
)
raise Exception(f"Cannot get prompt template string from prompt input: {prompt.input}")
else:
raise Exception(
f"Cannot get prompt template string from prompt input: {prompt.input}"
)
raise Exception(f"Cannot get prompt template string from prompt input: {prompt.input}")



else:
raise Exception(
f"Cannot get prompt template string from prompt input: {prompt.input}"
)
raise Exception(f"Cannot get prompt template string from prompt input: {prompt.input}")


def refine_chat_completion_params(model_settings):
@@ -522,9 +499,7 @@ def contains_prompt_template(prompt: Prompt):
"""
Check if a prompt's input is a valid string.
"""
return isinstance(prompt.input, str) or (
hasattr(prompt.input, "data") and isinstance(prompt.input.data, str)
)
return isinstance(prompt.input, str) or (hasattr(prompt.input, "data") and isinstance(prompt.input.data, str))


AIConfigRuntime.register_model_parser(GeminiModelParser("gemini-pro"), "gemini-pro")
@@ -109,7 +109,7 @@ def construct_stream_output(


def construct_regular_output(response: TextGenerationResponse, response_includes_details: bool) -> Output:
metadata = {"rawResponse": response}
metadata = {"raw_response": response}
if response_includes_details:
metadata["details"] = response.details

@@ -209,15 +209,13 @@ async def serialize(
prompt = Prompt(
name=prompt_name,
input=prompt_input,
metadata=PromptMetadata(
model=model_metadata, parameters=parameters, **kwargs
),
metadata=PromptMetadata(model=model_metadata, parameters=parameters, **kwargs),
)

prompts.append(prompt)
await ai_config.callback_manager.run_callbacks(CallbackEvent("on_serialize_complete", __name__, {"result": prompts }))

await ai_config.callback_manager.run_callbacks(CallbackEvent("on_serialize_complete", __name__, {"result": prompts}))

return prompts

async def deserialize(
@@ -251,9 +249,7 @@ async def deserialize(

return completion_data

async def run_inference(
self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: dict[Any, Any]
) -> List[Output]:
async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: dict[Any, Any]) -> List[Output]:
"""
Invoked to run a prompt in the .aiconfig. This method should perform
the actual model inference based on the provided prompt and inference settings.
2 changes: 1 addition & 1 deletion extensions/HuggingFace/typescript/hf.ts
@@ -311,7 +311,7 @@ function constructOutput(response: TextGenerationOutput): Output {
output_type: "execute_result",
data: response.generated_text,
execution_count: 0,
metadata: { rawResponse: response },
metadata: { raw_response: response },
} as ExecuteResult;

return output;
18 changes: 7 additions & 11 deletions python/src/aiconfig/default_parsers/hf.py
@@ -109,14 +109,14 @@ def construct_stream_output(


def construct_regular_output(response: TextGenerationResponse, response_includes_details: bool) -> Output:
metadata = {"rawResponse": response}
metadata = {"raw_response": response}
if response_includes_details:
metadata["details"] = response.details

output = ExecuteResult(
**{
"output_type": "execute_result",
"data": response.generated_text or '',
"data": response.generated_text or "",
"execution_count": 0,
"metadata": metadata,
}
@@ -209,15 +209,13 @@ async def serialize(
prompt = Prompt(
name=prompt_name,
input=prompt_input,
metadata=PromptMetadata(
model=model_metadata, parameters=parameters, **kwargs
),
metadata=PromptMetadata(model=model_metadata, parameters=parameters, **kwargs),
)

prompts.append(prompt)
await ai_config.callback_manager.run_callbacks(CallbackEvent("on_serialize_complete", __name__, {"result": prompts }))

await ai_config.callback_manager.run_callbacks(CallbackEvent("on_serialize_complete", __name__, {"result": prompts}))

return prompts

async def deserialize(
@@ -251,9 +249,7 @@ async def deserialize(

return completion_data

async def run_inference(
self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: dict[Any, Any]
) -> List[Output]:
async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: dict[Any, Any]) -> List[Output]:
"""
Invoked to run a prompt in the .aiconfig. This method should perform
the actual model inference based on the provided prompt and inference settings.