Skip to content

Commit

Permalink
Merge branch 'main' into store_json
Browse files Browse the repository at this point in the history
  • Loading branch information
kanesoban committed Dec 11, 2024
2 parents dd5a926 + d5a2ccf commit ec989ac
Show file tree
Hide file tree
Showing 20 changed files with 1,419 additions and 443 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -156,4 +156,7 @@ static
# database stuff
*db
*.db-shm
*.db-wal
*.db-wal

# ignore csvs in root dir
*.csv
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Integrated Alembic for managing chat history migrations
- Tool implementations without langchain or langgraph dependencies
- CRUDs.
- BlueNaas CRUD tools

### Fixed
- Migrate LLM Evaluation logic to scripts and add tests

## [0.3.3] - 30.10.2024

Expand Down
258 changes: 258 additions & 0 deletions src/neuroagent/scripts/avalidate_tool_calls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
"""Run validation on tool calls."""

import argparse
import asyncio
import json
import logging
from typing import Any

import aiohttp
import pandas as pd

# Configure the root logger once at import time so that per-prompt progress
# and API failures are visible when the script runs from the command line.
logging.basicConfig(
    level=logging.INFO,  # Set the logging level
    format="%(asctime)s - %(levelname)s - %(message)s",  # Define the log message format
)


async def fetch_tool_call(
    session: aiohttp.ClientSession,
    query: dict[str, Any],
    base_url: str,
    semaphore: asyncio.Semaphore,
) -> dict[str, Any]:
    """Run one test case against the API and validate its tool calls.

    Sends an asynchronous POST request with the test-case prompt and
    compares the tools the agent actually invoked against the expected,
    optional, and forbidden tool lists.

    Args:
    ----
        session (aiohttp.ClientSession): The aiohttp session used to make the HTTP request.
        query (dict): Test-case data: the prompt plus expected, optional,
            and forbidden tools.
        base_url (str): The base URL of the API.
        semaphore (asyncio.Semaphore): Caps the number of concurrent requests.

    Returns
    -------
    dict: One result row containing the prompt, actual and expected tool
        calls, and whether they match (with a reason when they do not).
    """
    async with semaphore:
        prompt = query["prompt"]
        expected_tool_calls = query["expected_tools"]
        optional_tools = query["optional_tools"]
        forbidden_tools = query["forbidden_tools"]

        logging.info(f"Testing prompt: {prompt}")

        # Send a request to the API.
        async with session.post(
            f"{base_url}/qa/run",
            headers={"Content-Type": "application/json"},
            json={
                "query": prompt,
            },
        ) as response:
            if response.status != 200:
                # Failure path: try to surface the server's own error message.
                try:
                    error_content = await response.json()
                    error_message = error_content.get("content", "Unknown error")
                except Exception as exc:
                    error_message = f"Failed to parse error message: {str(exc)}"

                error_info = {
                    "status_code": response.status,
                    "response_content": error_message,
                }
                logging.error(
                    f"API call failed for prompt: {prompt} with error: {error_info}"
                )
                return {
                    "Prompt": prompt,
                    "Actual": f"API call failed: {error_info}",
                    "Expected": expected_tool_calls,
                    "Optional": optional_tools,
                    "Forbidden": forbidden_tools,
                    "Match": "No",
                    "Reason": f"API call failed: {error_info}",
                }

            # Success path: extract the tool names that were actually called
            # and compare them against the expected sequence.
            steps = await response.json()
            called_tool_names = [
                step.get("tool_name", None) for step in steps.get("steps", [])
            ]
            expected_tool_names = [
                tool_call.get("tool_name", None) for tool_call in expected_tool_calls
            ]
            matched, reason = validate_tool(
                expected_tool_names,
                called_tool_names,
                optional_tools=optional_tools,
                forbidden_tools=forbidden_tools,
            )
            return {
                "Prompt": prompt,
                "Actual": called_tool_names,
                "Expected": expected_tool_names,
                "Optional": optional_tools,
                "Forbidden": forbidden_tools,
                "Match": "Yes" if matched else "No",
                "Reason": "N/A" if matched else reason,
            }


async def validate_tool_calls_async(
    base_url: str,
    data_file: str,
    output_file: str = "tool_call_evaluation.csv",
    max_concurrent_requests: int = 10,
) -> None:
    """
    Run asynchronous tool call tests and save the results to a CSV file.

    Args:
    ----
        base_url (str): The base URL of the API.
        data_file (str): The path to the JSON file containing test case data.
        output_file (str): The name of the output CSV file where the results will
            be saved. Defaults to 'tool_call_evaluation.csv'.
        max_concurrent_requests (int): Maximum number of concurrent API requests.
            Defaults to 10.

    Returns
    -------
    None: This function does not return any value. It writes the results to a
        CSV file.
    """
    # Explicit encoding so the test-case file parses identically on every platform.
    with open(data_file, encoding="utf-8") as f:
        tool_calls_data = json.load(f)

    # The semaphore is shared by all tasks to bound concurrent API requests.
    semaphore = asyncio.Semaphore(max_concurrent_requests)

    async with aiohttp.ClientSession() as session:
        tasks = [
            fetch_tool_call(session, query, base_url, semaphore)
            for query in tool_calls_data
        ]
        # gather preserves input order, so rows line up with the test cases.
        # (The previous `results_list = []` initialisation was a dead store
        # and has been removed.)
        results_list = await asyncio.gather(*tasks)

    results_df = pd.DataFrame(results_list)
    results_df.to_csv(output_file, index=False)


def validate_tool(
    required_tools: list[str],
    actual_tool_calls: list[str],
    optional_tools: list[str],
    forbidden_tools: list[str],
) -> tuple[bool, str]:
    """
    Validate the sequence of tool calls against required, optional, and forbidden tools.

    Args:
    ----
        required_tools (List): A list of tools that must be called in the specified order.
        actual_tool_calls (List): A list of tools that were actually called.
        optional_tools (List): A list of tools that can be called but are not required.
        forbidden_tools (List): A list of tools that must not be called.

    Returns
    -------
    tuple: A tuple containing a boolean and a string message. The boolean is True if the
        validation is successful, otherwise False. The string message provides details
        about the validation result.
    """
    # Any forbidden tool in the trace fails the whole case immediately.
    inter = set(actual_tool_calls) & set(forbidden_tools)
    if inter:
        return False, f"Forbidden tool(s) called: {inter}"

    # Walk the actual calls, advancing a cursor through the required sequence.
    progress = 0
    for called in actual_tool_calls:
        # The next required tool in order: advance the cursor.
        if progress < len(required_tools) and called == required_tools[progress]:
            progress += 1
            continue
        # Optional tools, or an immediate repeat of the last matched required
        # tool, are tolerated anywhere.
        if called in optional_tools:
            continue
        if progress > 0 and called == required_tools[progress - 1]:
            continue
        # Repeats of any earlier-matched required tool are also tolerated.
        if called in required_tools[:progress]:
            continue
        return False, f"Unexpected tool called: {called}"

    # The cursor must have consumed the full required sequence.
    if progress != len(required_tools):
        return False, "Not all required tools were called"

    return True, "All required tools called correctly"


def main() -> None:
    """
    Execute the tool call validation process.

    Parses the command-line arguments (base URL, port, test-data file path,
    and output CSV file name), then delegates to validate_tool_calls_async
    to run the validation and write the results.

    This is the entry point when the script is run directly from the
    command line.
    """
    parser = argparse.ArgumentParser(
        description="Run tool call tests and save results."
    )
    # (flag, type, default, help) — one row per CLI option.
    cli_options = (
        ("--base_url", str, "http://localhost", "Base URL for the API"),
        ("--port", int, 8000, "Port number for the API"),
        (
            "--data",
            str,
            "tests/data/tool_calls.json",
            "Path to the JSON file containing test case data",
        ),
        (
            "--output",
            str,
            "tool_call_evaluation.csv",
            "Output CSV file for results",
        ),
    )
    for flag, flag_type, default, help_text in cli_options:
        parser.add_argument(flag, type=flag_type, default=default, help=help_text)
    args = parser.parse_args()

    # Construct the full base URL from host and port.
    full_url = f"{args.base_url}:{args.port}"
    asyncio.run(validate_tool_calls_async(full_url, args.data, args.output))


if __name__ == "__main__":
    # Validate tool calls against expected outcomes and log the results.
    #
    # The script reads a set of prompts and their expected tool calls,
    # executes the tool calls, compares the actual calls with the expected
    # ones, logs whether they match, and saves the results to a CSV file.
    #
    # Usage:
    #     python avalidate_tool_calls.py --output <output_csv_file>
    #
    # Arguments:
    #     --output: The name of the output CSV file where the results will
    #               be saved. Defaults to 'tool_call_evaluation.csv'.
    #
    # Intended to be run as a standalone module.
    main()
5 changes: 5 additions & 0 deletions swarm_copy/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
To generate the pydantic models from an API openapi.json spec, run the following command:
```bash
pip install datamodel-code-generator
datamodel-codegen --enum-field-as-literal=all --target-python-version=3.10 --use-annotated --reuse-model --input-file-type=openapi --url=TARGET_URL/openapi.json --output=OUTPUT --output-model-type=pydantic_v2.BaseModel
```
2 changes: 1 addition & 1 deletion swarm_copy/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class SettingsGetMEModel(BaseModel):
class SettingsBlueNaaS(BaseModel):
"""BlueNaaS settings."""

url: str = "https://openbluebrain.com/api/bluenaas/simulation/single-neuron/run"
url: str = "https://openbluebrain.com/api/bluenaas"
model_config = ConfigDict(frozen=True)


Expand Down
19 changes: 15 additions & 4 deletions swarm_copy/app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,17 @@
from swarm_copy.run import AgentsRoutine
from swarm_copy.tools import (
ElectrophysFeatureTool,
GetMEModelTool,
GetMorphoTool,
GetTracesTool,
KGMorphoFeatureTool,
LiteratureSearchTool,
MEModelGetAllTool,
MEModelGetOneTool,
MorphologyFeatureTool,
ResolveEntitiesTool,
SCSGetAllTool,
SCSGetOneTool,
SCSPostTool,
)
from swarm_copy.utils import RegionMeta, get_file_from_KG

Expand Down Expand Up @@ -186,8 +190,8 @@ async def get_vlab_and_project(
}
elif not settings.keycloak.validate_token:
vlab_and_project = {
"vlab_id": "430108e9-a81d-4b13-b7b6-afca00195908",
"project_id": "eff09ea1-be16-47f0-91b6-52a3ea3ee575",
"vlab_id": "32c83739-f39c-49d1-833f-58c981ebd2a2",
"project_id": "123251a1-be18-4146-87b5-5ca2f8bfaf48",
}
else:
thread_id = request.path_params.get("thread_id")
Expand Down Expand Up @@ -237,9 +241,13 @@ def get_starting_agent(
You must always specify in your answers from which brain regions the information is extracted.
Do no blindly repeat the brain region requested by the user, use the output of the tools instead.""",
tools=[
SCSGetAllTool,
SCSGetOneTool,
SCSPostTool,
MEModelGetAllTool,
MEModelGetOneTool,
LiteratureSearchTool,
ElectrophysFeatureTool,
GetMEModelTool,
GetMorphoTool,
KGMorphoFeatureTool,
MorphologyFeatureTool,
Expand All @@ -261,6 +269,8 @@ def get_context_variables(
return {
"starting_agent": starting_agent,
"token": token,
"vlab_id": "32c83739-f39c-49d1-833f-58c981ebd2a2", # New god account vlab. Replaced by actual id in endpoint for now. Meant for usage without history
"project_id": "123251a1-be18-4146-87b5-5ca2f8bfaf48", # New god account proj. Replaced by actual id in endpoint for now. Meant for usage without history
"retriever_k": settings.tools.literature.retriever_k,
"reranker_k": settings.tools.literature.reranker_k,
"use_reranker": settings.tools.literature.use_reranker,
Expand All @@ -274,6 +284,7 @@ def get_context_variables(
"trace_search_size": settings.tools.trace.search_size,
"kg_sparql_url": settings.knowledge_graph.sparql_url,
"kg_class_view_url": settings.knowledge_graph.class_view_url,
"bluenaas_url": settings.tools.bluenaas.url,
"httpx_client": httpx_client,
}

Expand Down
Loading

0 comments on commit ec989ac

Please sign in to comment.