diff --git a/notebooks/agents/sql_agent/Chinook.db b/notebooks/agents/sql_agent/Chinook.db new file mode 100644 index 00000000..f122837c Binary files /dev/null and b/notebooks/agents/sql_agent/Chinook.db differ diff --git a/notebooks/agents/sql_agent/README.md b/notebooks/agents/sql_agent/README.md new file mode 100644 index 00000000..164ed3e6 --- /dev/null +++ b/notebooks/agents/sql_agent/README.md @@ -0,0 +1 @@ +In this notebook we explore how to set up a Cohere ReAct Agent to answer questions over SQL databases using LangChain’s existing SQLDatabaseToolkit. \ No newline at end of file diff --git a/notebooks/agents/sql_agent/sql_agent.ipynb b/notebooks/agents/sql_agent/sql_agent.ipynb new file mode 100644 index 00000000..62ec38d8 --- /dev/null +++ b/notebooks/agents/sql_agent/sql_agent.ipynb @@ -0,0 +1,733 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "5c0d5271-8d35-4471-b667-265695cbf11a", + "metadata": {}, + "source": [ + "![Cohere-Logo-Color-RGB.png]()" + ] + }, + { + "cell_type": "markdown", + "id": "1d71f595-01ef-4c24-99eb-81b01eec048b", + "metadata": {}, + "source": [ + "# Notebook Overview\n", + "\n", + "## Motivation\n", + "Enterprise customers often store and handle information in relational databases, but querying such databases effectively requires bespoke knowledge of the underlying database structure as well as strong SQL coding skills. One way to address these challenges is to build an LLM agent capable of generating and executing SQL queries based on natural language. For example, if a user asks ``what are the top 4 rows in table X``, the agent should be able to generate ``SELECT * FROM X LIMIT 4``, execute this query, and return the output to the user. \n", + "\n", + "## Objective\n", + "In this notebook we explore how to set up a [Cohere ReAct Agent](https://github.com/langchain-ai/langchain-cohere/blob/main/libs/cohere/langchain_cohere/cohere_agent.py) to answer questions over SQL databases. 
We show how this can be done with LangChain's existing `SQLDatabaseToolkit`.\n", + "\n", + "## Table of Contents\n", + "\n", + "- [Toolkit Setup](#sec_step0)\n", + "- [SQL Agent](#sec_step1)\n", + "- [SQL Agent with context](#sec_step2)" + ] + }, + { + "cell_type": "markdown", + "id": "68e373fd", + "metadata": { + "collapsed": false + }, + "source": [ + "\n", + "# Toolkit Setup\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "bce0ca82-fc00-40bb-9d0e-a274bd2eb8a7", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "bce0ca82-fc00-40bb-9d0e-a274bd2eb8a7", + "outputId": "c6c45466-2abc-40fb-8e0e-823d5339e902" + }, + "outputs": [], + "source": [ + "from langchain.agents import AgentExecutor\n", + "from langchain_cohere.react_multi_hop.agent import create_cohere_react_agent\n", + "from langchain_core.prompts import ChatPromptTemplate\n", + "from langchain_cohere.chat_models import ChatCohere\n", + "from langchain_community.utilities.sql_database import SQLDatabase\n", + "from langchain_community.agent_toolkits import SQLDatabaseToolkit\n", + "import os" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "76287bf9-d98d-4a34-84c8-46134b85de53", + "metadata": {}, + "outputs": [], + "source": [ + "####################################################################################################\n", + "#\n", + "# Uncomment if you need to install the following packages\n", + "#\n", + "####################################################################################################\n", + "\n", + "#!pip install --quiet langchain langchain_cohere langchain_community langchain_experimental --upgrade" + ] + }, + { + "cell_type": "markdown", + "id": "c8e8405d-080c-4426-aa94-57d975727dae", + "metadata": {}, + "source": [ + "LangChain already provides a `SQLDatabaseToolkit` that consists of four tools for SQL query generation, execution, and validation. 
To use it, define your LLM and database and pass them into the toolkit.\n", + "\n", + "The toolkit provides the following tools:\n", + "- `sql_db_query`: executes SQL code on the database\n", + "- `sql_db_schema`: returns the schema and sample rows for a given list of tables\n", + "- `sql_db_list_tables`: lists the tables in the database\n", + "- `sql_db_query_checker`: validates the SQL query\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "6195f401-f0be-4a31-ab9f-80e722bb0fbb", + "metadata": {}, + "outputs": [], + "source": [ + "# set the Cohere API key (fill in your own key)\n", + "os.environ[\"COHERE_API_KEY\"] = \"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "936212a7-fda6-4e12-a5e3-2fdd373660ae", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "**List of pre-defined Langchain Tools**\n", + "['sql_db_query', 'sql_db_schema', 'sql_db_list_tables', 'sql_db_query_checker']\n", + "\n", + "**Context to pass to LLM on tables**\n", + "{'table_info': '\\nCREATE TABLE \"Album\" (\\n\\t\"AlbumId\" INTEGER NOT NULL, \\n\\t\"Title\" NVARCHAR(160) NOT NULL, \\n\\t\"ArtistId\" INTEGER NOT NULL, \\n\\tPRIMARY KEY (\"AlbumId\"), \\n\\tFOREIGN KEY(\"ArtistId\") REFERENCES \"Artist\" (\"ArtistId\")\\n)\\n\\n/*\\n3 rows from Album table:\\nAlbumId\\tTitle\\tArtistId\\n1\\tFor Those About To Rock We Salute You\\t1\\n2\\tBalls to the Wall\\t2\\n3\\tRestless and Wild\\t2\\n*/\\n\\n\\nCREATE TABLE \"Artist\" (\\n\\t\"ArtistId\" INTEGER NOT NULL, \\n\\t\"Name\" NVARCHAR(120), \\n\\tPRIMARY KEY (\"ArtistId\")\\n)\\n\\n/*\\n3 rows from Artist table:\\nArtistId\\tName\\n1\\tAC/DC\\n2\\tAccept\\n3\\tAerosmith\\n*/\\n\\n\\nCREATE TABLE \"Customer\" (\\n\\t\"CustomerId\" INTEGER NOT NULL, \\n\\t\"FirstName\" NVARCHAR(40) NOT NULL, \\n\\t\"LastName\" NVARCHAR(20) NOT NULL, \\n\\t\"Company\" NVARCHAR(80), \\n\\t\"Address\" NVARCHAR(70), \\n\\t\"City\" NVARCHAR(40), \\n\\t\"State\" NVARCHAR(40), 
\\n\\t\"Country\" NVARCHAR(40), \\n\\t\"PostalCode\" NVARCHAR(10), \\n\\t\"Phone\" NVARCHAR(24), \\n\\t\"Fax\" NVARCHAR(24), \\n\\t\"Email\" NVARCHAR(60) NOT NULL, \\n\\t\"SupportRepId\" INTEGER, \\n\\tPRIMARY KEY (\"CustomerId\"), \\n\\tFOREIGN KEY(\"SupportRepId\") REFERENCES \"Employee\" (\"EmployeeId\")\\n)\\n\\n/*\\n3 rows from Customer table:\\nCustomerId\\tFirstName\\tLastName\\tCompany\\tAddress\\tCity\\tState\\tCountry\\tPostalCode\\tPhone\\tFax\\tEmail\\tSupportRepId\\n1\\tLuís\\tGonçalves\\tEmbraer - Empresa Brasileira de Aeronáutica S.A.\\tAv. Brigadeiro Faria Lima, 2170\\tSão José dos Campos\\tSP\\tBrazil\\t12227-000\\t+55 (12) 3923-5555\\t+55 (12) 3923-5566\\tluisg@embraer.com.br\\t3\\n2\\tLeonie\\tKöhler\\tNone\\tTheodor-Heuss-Straße 34\\tStuttgart\\tNone\\tGermany\\t70174\\t+49 0711 2842222\\tNone\\tleonekohler@surfeu.de\\t5\\n3\\tFrançois\\tTremblay\\tNone\\t1498 rue Bélanger\\tMontréal\\tQC\\tCanada\\tH2G 1A7\\t+1 (514) 721-4711\\tNone\\tftremblay@gmail.com\\t3\\n*/\\n\\n\\nCREATE TABLE \"Employee\" (\\n\\t\"EmployeeId\" INTEGER NOT NULL, \\n\\t\"LastName\" NVARCHAR(20) NOT NULL, \\n\\t\"FirstName\" NVARCHAR(20) NOT NULL, \\n\\t\"Title\" NVARCHAR(30), \\n\\t\"ReportsTo\" INTEGER, \\n\\t\"BirthDate\" DATETIME, \\n\\t\"HireDate\" DATETIME, \\n\\t\"Address\" NVARCHAR(70), \\n\\t\"City\" NVARCHAR(40), \\n\\t\"State\" NVARCHAR(40), \\n\\t\"Country\" NVARCHAR(40), \\n\\t\"PostalCode\" NVARCHAR(10), \\n\\t\"Phone\" NVARCHAR(24), \\n\\t\"Fax\" NVARCHAR(24), \\n\\t\"Email\" NVARCHAR(60), \\n\\tPRIMARY KEY (\"EmployeeId\"), \\n\\tFOREIGN KEY(\"ReportsTo\") REFERENCES \"Employee\" (\"EmployeeId\")\\n)\\n\\n/*\\n3 rows from Employee table:\\nEmployeeId\\tLastName\\tFirstName\\tTitle\\tReportsTo\\tBirthDate\\tHireDate\\tAddress\\tCity\\tState\\tCountry\\tPostalCode\\tPhone\\tFax\\tEmail\\n1\\tAdams\\tAndrew\\tGeneral Manager\\tNone\\t1962-02-18 00:00:00\\t2002-08-14 00:00:00\\t11120 Jasper Ave NW\\tEdmonton\\tAB\\tCanada\\tT5K 2N1\\t+1 (780) 428-9482\\t+1 
(780) 428-3457\\tandrew@chinookcorp.com\\n2\\tEdwards\\tNancy\\tSales Manager\\t1\\t1958-12-08 00:00:00\\t2002-05-01 00:00:00\\t825 8 Ave SW\\tCalgary\\tAB\\tCanada\\tT2P 2T3\\t+1 (403) 262-3443\\t+1 (403) 262-3322\\tnancy@chinookcorp.com\\n3\\tPeacock\\tJane\\tSales Support Agent\\t2\\t1973-08-29 00:00:00\\t2002-04-01 00:00:00\\t1111 6 Ave SW\\tCalgary\\tAB\\tCanada\\tT2P 5M5\\t+1 (403) 262-3443\\t+1 (403) 262-6712\\tjane@chinookcorp.com\\n*/\\n\\n\\nCREATE TABLE \"Genre\" (\\n\\t\"GenreId\" INTEGER NOT NULL, \\n\\t\"Name\" NVARCHAR(120), \\n\\tPRIMARY KEY (\"GenreId\")\\n)\\n\\n/*\\n3 rows from Genre table:\\nGenreId\\tName\\n1\\tRock\\n2\\tJazz\\n3\\tMetal\\n*/\\n\\n\\nCREATE TABLE \"Invoice\" (\\n\\t\"InvoiceId\" INTEGER NOT NULL, \\n\\t\"CustomerId\" INTEGER NOT NULL, \\n\\t\"InvoiceDate\" DATETIME NOT NULL, \\n\\t\"BillingAddress\" NVARCHAR(70), \\n\\t\"BillingCity\" NVARCHAR(40), \\n\\t\"BillingState\" NVARCHAR(40), \\n\\t\"BillingCountry\" NVARCHAR(40), \\n\\t\"BillingPostalCode\" NVARCHAR(10), \\n\\t\"Total\" NUMERIC(10, 2) NOT NULL, \\n\\tPRIMARY KEY (\"InvoiceId\"), \\n\\tFOREIGN KEY(\"CustomerId\") REFERENCES \"Customer\" (\"CustomerId\")\\n)\\n\\n/*\\n3 rows from Invoice table:\\nInvoiceId\\tCustomerId\\tInvoiceDate\\tBillingAddress\\tBillingCity\\tBillingState\\tBillingCountry\\tBillingPostalCode\\tTotal\\n1\\t2\\t2021-01-01 00:00:00\\tTheodor-Heuss-Straße 34\\tStuttgart\\tNone\\tGermany\\t70174\\t1.98\\n2\\t4\\t2021-01-02 00:00:00\\tUllevålsveien 14\\tOslo\\tNone\\tNorway\\t0171\\t3.96\\n3\\t8\\t2021-01-03 00:00:00\\tGrétrystraat 63\\tBrussels\\tNone\\tBelgium\\t1000\\t5.94\\n*/\\n\\n\\nCREATE TABLE \"InvoiceLine\" (\\n\\t\"InvoiceLineId\" INTEGER NOT NULL, \\n\\t\"InvoiceId\" INTEGER NOT NULL, \\n\\t\"TrackId\" INTEGER NOT NULL, \\n\\t\"UnitPrice\" NUMERIC(10, 2) NOT NULL, \\n\\t\"Quantity\" INTEGER NOT NULL, \\n\\tPRIMARY KEY (\"InvoiceLineId\"), \\n\\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \\n\\tFOREIGN KEY(\"InvoiceId\") 
REFERENCES \"Invoice\" (\"InvoiceId\")\\n)\\n\\n/*\\n3 rows from InvoiceLine table:\\nInvoiceLineId\\tInvoiceId\\tTrackId\\tUnitPrice\\tQuantity\\n1\\t1\\t2\\t0.99\\t1\\n2\\t1\\t4\\t0.99\\t1\\n3\\t2\\t6\\t0.99\\t1\\n*/\\n\\n\\nCREATE TABLE \"MediaType\" (\\n\\t\"MediaTypeId\" INTEGER NOT NULL, \\n\\t\"Name\" NVARCHAR(120), \\n\\tPRIMARY KEY (\"MediaTypeId\")\\n)\\n\\n/*\\n3 rows from MediaType table:\\nMediaTypeId\\tName\\n1\\tMPEG audio file\\n2\\tProtected AAC audio file\\n3\\tProtected MPEG-4 video file\\n*/\\n\\n\\nCREATE TABLE \"Playlist\" (\\n\\t\"PlaylistId\" INTEGER NOT NULL, \\n\\t\"Name\" NVARCHAR(120), \\n\\tPRIMARY KEY (\"PlaylistId\")\\n)\\n\\n/*\\n3 rows from Playlist table:\\nPlaylistId\\tName\\n1\\tMusic\\n2\\tMovies\\n3\\tTV Shows\\n*/\\n\\n\\nCREATE TABLE \"PlaylistTrack\" (\\n\\t\"PlaylistId\" INTEGER NOT NULL, \\n\\t\"TrackId\" INTEGER NOT NULL, \\n\\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \\n\\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \\n\\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\\n)\\n\\n/*\\n3 rows from PlaylistTrack table:\\nPlaylistId\\tTrackId\\n1\\t3402\\n1\\t3389\\n1\\t3390\\n*/\\n\\n\\nCREATE TABLE \"Track\" (\\n\\t\"TrackId\" INTEGER NOT NULL, \\n\\t\"Name\" NVARCHAR(200) NOT NULL, \\n\\t\"AlbumId\" INTEGER, \\n\\t\"MediaTypeId\" INTEGER NOT NULL, \\n\\t\"GenreId\" INTEGER, \\n\\t\"Composer\" NVARCHAR(220), \\n\\t\"Milliseconds\" INTEGER NOT NULL, \\n\\t\"Bytes\" INTEGER, \\n\\t\"UnitPrice\" NUMERIC(10, 2) NOT NULL, \\n\\tPRIMARY KEY (\"TrackId\"), \\n\\tFOREIGN KEY(\"MediaTypeId\") REFERENCES \"MediaType\" (\"MediaTypeId\"), \\n\\tFOREIGN KEY(\"GenreId\") REFERENCES \"Genre\" (\"GenreId\"), \\n\\tFOREIGN KEY(\"AlbumId\") REFERENCES \"Album\" (\"AlbumId\")\\n)\\n\\n/*\\n3 rows from Track table:\\nTrackId\\tName\\tAlbumId\\tMediaTypeId\\tGenreId\\tComposer\\tMilliseconds\\tBytes\\tUnitPrice\\n1\\tFor Those About To Rock (We Salute You)\\t1\\t1\\t1\\tAngus Young, Malcolm Young, Brian 
Johnson\\t343719\\t11170334\\t0.99\\n2\\tBalls to the Wall\\t2\\t2\\t1\\tU. Dirkschneider, W. Hoffmann, H. Frank, P. Baltes, S. Kaufmann, G. Hoffmann\\t342562\\t5510424\\t0.99\\n3\\tFast As a Shark\\t3\\t2\\t1\\tF. Baltes, S. Kaufman, U. Dirkscneider & W. Hoffman\\t230619\\t3990994\\t0.99\\n*/', 'table_names': 'Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track'}\n" + ] + } + ], + "source": [ + "# database file and Cohere model used throughout the notebook\n", + "DB_NAME = 'Chinook.db'\n", + "MODEL = \"command-r-plus\"\n", + "\n", + "llm = ChatCohere(model=MODEL, temperature=0.1, verbose=True)\n", + "db = SQLDatabase.from_uri(f\"sqlite:///{DB_NAME}\")\n", + "\n", + "# build the toolkit, then collect its table context and tools\n", + "toolkit = SQLDatabaseToolkit(db=db, llm=llm)\n", + "context = toolkit.get_context()\n", + "tools = toolkit.get_tools()\n", + "\n", + "print('**List of pre-defined Langchain Tools**')\n", + "print([tool.name for tool in tools])\n", + "print('')\n", + "print('**Context to pass to LLM on tables**')\n", + "print(context)" + ] + }, + { + "cell_type": "markdown", + "id": "7554c397-76ae-46c8-9c76-b906f926febf", + "metadata": {}, + "source": [ + "\n", + "# SQL Agent\n", + "\n", + "We follow the general Cohere ReAct agent setup in LangChain to build our SQL agent." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "9be6f9eb-b1ba-4a4f-a5f3-5b775f0db69f", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# define the prompt template\n", + "prompt = ChatPromptTemplate.from_template(\"{input}\")\n", + "\n", + "# instantiate the ReAct agent\n", + "agent = create_cohere_react_agent(\n", + " llm=llm,\n", + " tools=tools,\n", + " prompt=prompt,\n", + ")\n", + "agent_executor = AgentExecutor(agent=agent,\n", + " tools=tools,\n", + " verbose=True,\n", + " return_intermediate_steps=True\n", + " )\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "5ebb2d00-315d-462c-919b-749027133d89", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001B[1m> Entering new AgentExecutor chain...\u001B[0m\n", + "\u001B[32;1m\u001B[1;3m\n", + "I will use the sql_db_list_tables tool to find out what tables are available.\n", + "{'tool_name': 'sql_db_list_tables', 'parameters': {'tool_input': ''}}\n", + "\u001B[0m\u001B[38;5;200m\u001B[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\u001B[0m\u001B[32;1m\u001B[1;3mRelevant Documents: 0\n", + "Cited Documents: 0\n", + "Answer: The following tables are available: Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track.\n", + "Grounded answer: The following tables are available: Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track.\u001B[0m\n", + "\n", + "\u001B[1m> Finished chain.\u001B[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "{'input': 'what tables are available?',\n", + " 'output': 'The following tables are available: Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track.',\n", + " 'citations': [CohereCitation(start=36, end=41, text='Album', 
documents=[{'output': 'Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track'}]),\n", + " CohereCitation(start=43, end=49, text='Artist', documents=[{'output': 'Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track'}]),\n", + " CohereCitation(start=51, end=59, text='Customer', documents=[{'output': 'Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track'}]),\n", + " CohereCitation(start=61, end=69, text='Employee', documents=[{'output': 'Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track'}]),\n", + " CohereCitation(start=71, end=76, text='Genre', documents=[{'output': 'Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track'}]),\n", + " CohereCitation(start=78, end=85, text='Invoice', documents=[{'output': 'Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track'}]),\n", + " CohereCitation(start=87, end=98, text='InvoiceLine', documents=[{'output': 'Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track'}]),\n", + " CohereCitation(start=100, end=109, text='MediaType', documents=[{'output': 'Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track'}]),\n", + " CohereCitation(start=111, end=119, text='Playlist', documents=[{'output': 'Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track'}]),\n", + " CohereCitation(start=121, end=134, text='PlaylistTrack', documents=[{'output': 'Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track'}]),\n", + " CohereCitation(start=136, end=141, text='Track', documents=[{'output': 'Album, Artist, Customer, Employee, Genre, 
Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track'}])],\n", + " 'intermediate_steps': [(AgentActionMessageLog(tool='sql_db_list_tables', tool_input={'tool_input': ''}, log=\"\\nI will use the sql_db_list_tables tool to find out what tables are available.\\n{'tool_name': 'sql_db_list_tables', 'parameters': {'tool_input': ''}}\\n\", message_log=[AIMessage(content='\\nPlan: I will use the sql_db_list_tables tool to find out what tables are available.\\nAction: ```json\\n[\\n {\\n \"tool_name\": \"sql_db_list_tables\",\\n \"parameters\": {\\n \"tool_input\": \"\"\\n }\\n }\\n]\\n```')]),\n", + " 'Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track')]}" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent_executor.invoke({\n", + " \"input\": 'what tables are available?',\n", + "})" + ] + }, + { + "cell_type": "markdown", + "id": "8a1fc000", + "metadata": { + "collapsed": false + }, + "source": [ + "The agent calls the `sql_db_list_tables` tool and reports every table in the database." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "77a3f5ac", + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001B[1m> Entering new AgentExecutor chain...\u001B[0m\n", + "\u001B[32;1m\u001B[1;3m\n", + "I will use the sql_db_schema tool to find the first row of the Playlist and Genre tables.\n", + "{'tool_name': 'sql_db_schema', 'parameters': {'table_names': 'Playlist, Genre'}}\n", + "\u001B[0m\u001B[33;1m\u001B[1;3m\n", + "CREATE TABLE \"Genre\" (\n", + "\t\"GenreId\" INTEGER NOT NULL, \n", + "\t\"Name\" NVARCHAR(120), \n", + "\tPRIMARY KEY (\"GenreId\")\n", + ")\n", + "\n", + "/*\n", + "3 rows from Genre table:\n", + "GenreId\tName\n", + "1\tRock\n", + "2\tJazz\n", + "3\tMetal\n", + "*/\n", + "\n", + "\n", + "CREATE TABLE \"Playlist\" (\n", + 
"\t\"PlaylistId\" INTEGER NOT NULL, \n", + "\t\"Name\" NVARCHAR(120), \n", + "\tPRIMARY KEY (\"PlaylistId\")\n", + ")\n", + "\n", + "/*\n", + "3 rows from Playlist table:\n", + "PlaylistId\tName\n", + "1\tMusic\n", + "2\tMovies\n", + "3\tTV Shows\n", + "*/\u001B[0m\u001B[32;1m\u001B[1;3mRelevant Documents: 0\n", + "Cited Documents: 0\n", + "Answer: Here is the first row of the Genre table:\n", + "\n", + "| GenreId | Name |\n", + "| --- | --- |\n", + "| 1 | Rock |\n", + "\n", + "Here is the first row of the Playlist table:\n", + "\n", + "| PlaylistId | Name |\n", + "| --- | --- |\n", + "| 1 | Music |\n", + "Grounded answer: Here is the first row of the Genre table:\n", + "\n", + "| GenreId | Name |\n", + "| --- | --- |\n", + "| 1 | Rock |\n", + "\n", + "Here is the first row of the Playlist table:\n", + "\n", + "| PlaylistId | Name |\n", + "| --- | --- |\n", + "| 1 | Music |\u001B[0m\n", + "\n", + "\u001B[1m> Finished chain.\u001B[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "{'input': 'show the first row of the Playlist and Genre tables?',\n", + " 'output': 'Here is the first row of the Genre table:\\n\\n| GenreId | Name |\\n| --- | --- |\\n| 1 | Rock |\\n\\nHere is the first row of the Playlist table:\\n\\n| PlaylistId | Name |\\n| --- | --- |\\n| 1 | Music |',\n", + " 'citations': [CohereCitation(start=45, end=52, text='GenreId', documents=[{'output': '\\nCREATE TABLE \"Genre\" (\\n\\t\"GenreId\" INTEGER NOT NULL, \\n\\t\"Name\" NVARCHAR(120), \\n\\tPRIMARY KEY (\"GenreId\")\\n)\\n\\n/*\\n3 rows from Genre table:\\nGenreId\\tName\\n1\\tRock\\n2\\tJazz\\n3\\tMetal\\n*/\\n\\n\\nCREATE TABLE \"Playlist\" (\\n\\t\"PlaylistId\" INTEGER NOT NULL, \\n\\t\"Name\" NVARCHAR(120), \\n\\tPRIMARY KEY (\"PlaylistId\")\\n)\\n\\n/*\\n3 rows from Playlist table:\\nPlaylistId\\tName\\n1\\tMusic\\n2\\tMovies\\n3\\tTV Shows\\n*/'}]),\n", + " CohereCitation(start=55, end=59, text='Name', documents=[{'output': '\\nCREATE TABLE \"Genre\" (\\n\\t\"GenreId\" INTEGER NOT NULL, 
\\n\\t\"Name\" NVARCHAR(120), \\n\\tPRIMARY KEY (\"GenreId\")\\n)\\n\\n/*\\n3 rows from Genre table:\\nGenreId\\tName\\n1\\tRock\\n2\\tJazz\\n3\\tMetal\\n*/\\n\\n\\nCREATE TABLE \"Playlist\" (\\n\\t\"PlaylistId\" INTEGER NOT NULL, \\n\\t\"Name\" NVARCHAR(120), \\n\\tPRIMARY KEY (\"PlaylistId\")\\n)\\n\\n/*\\n3 rows from Playlist table:\\nPlaylistId\\tName\\n1\\tMusic\\n2\\tMovies\\n3\\tTV Shows\\n*/'}]),\n", + " CohereCitation(start=78, end=79, text='1', documents=[{'output': '\\nCREATE TABLE \"Genre\" (\\n\\t\"GenreId\" INTEGER NOT NULL, \\n\\t\"Name\" NVARCHAR(120), \\n\\tPRIMARY KEY (\"GenreId\")\\n)\\n\\n/*\\n3 rows from Genre table:\\nGenreId\\tName\\n1\\tRock\\n2\\tJazz\\n3\\tMetal\\n*/\\n\\n\\nCREATE TABLE \"Playlist\" (\\n\\t\"PlaylistId\" INTEGER NOT NULL, \\n\\t\"Name\" NVARCHAR(120), \\n\\tPRIMARY KEY (\"PlaylistId\")\\n)\\n\\n/*\\n3 rows from Playlist table:\\nPlaylistId\\tName\\n1\\tMusic\\n2\\tMovies\\n3\\tTV Shows\\n*/'}]),\n", + " CohereCitation(start=82, end=86, text='Rock', documents=[{'output': '\\nCREATE TABLE \"Genre\" (\\n\\t\"GenreId\" INTEGER NOT NULL, \\n\\t\"Name\" NVARCHAR(120), \\n\\tPRIMARY KEY (\"GenreId\")\\n)\\n\\n/*\\n3 rows from Genre table:\\nGenreId\\tName\\n1\\tRock\\n2\\tJazz\\n3\\tMetal\\n*/\\n\\n\\nCREATE TABLE \"Playlist\" (\\n\\t\"PlaylistId\" INTEGER NOT NULL, \\n\\t\"Name\" NVARCHAR(120), \\n\\tPRIMARY KEY (\"PlaylistId\")\\n)\\n\\n/*\\n3 rows from Playlist table:\\nPlaylistId\\tName\\n1\\tMusic\\n2\\tMovies\\n3\\tTV Shows\\n*/'}]),\n", + " CohereCitation(start=138, end=148, text='PlaylistId', documents=[{'output': '\\nCREATE TABLE \"Genre\" (\\n\\t\"GenreId\" INTEGER NOT NULL, \\n\\t\"Name\" NVARCHAR(120), \\n\\tPRIMARY KEY (\"GenreId\")\\n)\\n\\n/*\\n3 rows from Genre table:\\nGenreId\\tName\\n1\\tRock\\n2\\tJazz\\n3\\tMetal\\n*/\\n\\n\\nCREATE TABLE \"Playlist\" (\\n\\t\"PlaylistId\" INTEGER NOT NULL, \\n\\t\"Name\" NVARCHAR(120), \\n\\tPRIMARY KEY (\"PlaylistId\")\\n)\\n\\n/*\\n3 rows from Playlist 
table:\\nPlaylistId\\tName\\n1\\tMusic\\n2\\tMovies\\n3\\tTV Shows\\n*/'}]),\n", + " CohereCitation(start=151, end=155, text='Name', documents=[{'output': '\\nCREATE TABLE \"Genre\" (\\n\\t\"GenreId\" INTEGER NOT NULL, \\n\\t\"Name\" NVARCHAR(120), \\n\\tPRIMARY KEY (\"GenreId\")\\n)\\n\\n/*\\n3 rows from Genre table:\\nGenreId\\tName\\n1\\tRock\\n2\\tJazz\\n3\\tMetal\\n*/\\n\\n\\nCREATE TABLE \"Playlist\" (\\n\\t\"PlaylistId\" INTEGER NOT NULL, \\n\\t\"Name\" NVARCHAR(120), \\n\\tPRIMARY KEY (\"PlaylistId\")\\n)\\n\\n/*\\n3 rows from Playlist table:\\nPlaylistId\\tName\\n1\\tMusic\\n2\\tMovies\\n3\\tTV Shows\\n*/'}]),\n", + " CohereCitation(start=174, end=175, text='1', documents=[{'output': '\\nCREATE TABLE \"Genre\" (\\n\\t\"GenreId\" INTEGER NOT NULL, \\n\\t\"Name\" NVARCHAR(120), \\n\\tPRIMARY KEY (\"GenreId\")\\n)\\n\\n/*\\n3 rows from Genre table:\\nGenreId\\tName\\n1\\tRock\\n2\\tJazz\\n3\\tMetal\\n*/\\n\\n\\nCREATE TABLE \"Playlist\" (\\n\\t\"PlaylistId\" INTEGER NOT NULL, \\n\\t\"Name\" NVARCHAR(120), \\n\\tPRIMARY KEY (\"PlaylistId\")\\n)\\n\\n/*\\n3 rows from Playlist table:\\nPlaylistId\\tName\\n1\\tMusic\\n2\\tMovies\\n3\\tTV Shows\\n*/'}]),\n", + " CohereCitation(start=178, end=183, text='Music', documents=[{'output': '\\nCREATE TABLE \"Genre\" (\\n\\t\"GenreId\" INTEGER NOT NULL, \\n\\t\"Name\" NVARCHAR(120), \\n\\tPRIMARY KEY (\"GenreId\")\\n)\\n\\n/*\\n3 rows from Genre table:\\nGenreId\\tName\\n1\\tRock\\n2\\tJazz\\n3\\tMetal\\n*/\\n\\n\\nCREATE TABLE \"Playlist\" (\\n\\t\"PlaylistId\" INTEGER NOT NULL, \\n\\t\"Name\" NVARCHAR(120), \\n\\tPRIMARY KEY (\"PlaylistId\")\\n)\\n\\n/*\\n3 rows from Playlist table:\\nPlaylistId\\tName\\n1\\tMusic\\n2\\tMovies\\n3\\tTV Shows\\n*/'}])],\n", + " 'intermediate_steps': [(AgentActionMessageLog(tool='sql_db_schema', tool_input={'table_names': 'Playlist, Genre'}, log=\"\\nI will use the sql_db_schema tool to find the first row of the Playlist and Genre tables.\\n{'tool_name': 'sql_db_schema', 'parameters': 
{'table_names': 'Playlist, Genre'}}\\n\", message_log=[AIMessage(content='\\nPlan: I will use the sql_db_schema tool to find the first row of the Playlist and Genre tables.\\nAction: ```json\\n[\\n {\\n \"tool_name\": \"sql_db_schema\",\\n \"parameters\": {\\n \"table_names\": \"Playlist, Genre\"\\n }\\n }\\n]\\n```')]),\n", + " '\\nCREATE TABLE \"Genre\" (\\n\\t\"GenreId\" INTEGER NOT NULL, \\n\\t\"Name\" NVARCHAR(120), \\n\\tPRIMARY KEY (\"GenreId\")\\n)\\n\\n/*\\n3 rows from Genre table:\\nGenreId\\tName\\n1\\tRock\\n2\\tJazz\\n3\\tMetal\\n*/\\n\\n\\nCREATE TABLE \"Playlist\" (\\n\\t\"PlaylistId\" INTEGER NOT NULL, \\n\\t\"Name\" NVARCHAR(120), \\n\\tPRIMARY KEY (\"PlaylistId\")\\n)\\n\\n/*\\n3 rows from Playlist table:\\nPlaylistId\\tName\\n1\\tMusic\\n2\\tMovies\\n3\\tTV Shows\\n*/')]}" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent_executor.invoke({\n", + " \"input\": 'show the first row of the Playlist and Genre tables?',\n", + "})" + ] + }, + { + "cell_type": "markdown", + "id": "3e77d393-7fd8-42df-8ed2-59badb3b6d5b", + "metadata": {}, + "source": [ + "Here we see that the agent passes a list of table names to the `sql_db_schema` tool, which returns each table's schema along with a few sample rows." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "4466e2b1-8cd7-4e50-aab4-8a6d9f809491", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001B[1m> Entering new AgentExecutor chain...\u001B[0m\n", + "\u001B[32;1m\u001B[1;3m\n", + "I will search for the number of invoices per country and then write an answer based on the results.\n", + "{'tool_name': 'sql_db_query', 'parameters': {'query': 'SELECT country, COUNT(*) AS invoice_count FROM invoices GROUP BY country ORDER BY invoice_count DESC'}}\n", + "\u001B[0m\u001B[36;1m\u001B[1;3mError: (sqlite3.OperationalError) no such table: invoices\n", + "[SQL: SELECT country, COUNT(*) AS invoice_count 
FROM invoices GROUP BY country ORDER BY invoice_count DESC]\n", + "(Background on this error at: https://sqlalche.me/e/20/e3q8)\u001B[0m\u001B[32;1m\u001B[1;3m\n", + "I received an error message. I will now use the sql_db_list_tables tool to find out the name of the table containing the invoices data.\n", + "{'tool_name': 'sql_db_list_tables', 'parameters': {'tool_input': ''}}\n", + "\u001B[0m\u001B[38;5;200m\u001B[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\u001B[0m\u001B[32;1m\u001B[1;3m\n", + "I found out that the table name is 'Invoice'. I will now rewrite my query and run it again.\n", + "{'tool_name': 'sql_db_query', 'parameters': {'query': 'SELECT country, COUNT(*) AS invoice_count FROM Invoice GROUP BY country ORDER BY invoice_count DESC'}}\n", + "\u001B[0m\u001B[36;1m\u001B[1;3mError: (sqlite3.OperationalError) no such column: country\n", + "[SQL: SELECT country, COUNT(*) AS invoice_count FROM Invoice GROUP BY country ORDER BY invoice_count DESC]\n", + "(Background on this error at: https://sqlalche.me/e/20/e3q8)\u001B[0m\u001B[32;1m\u001B[1;3m\n", + "I received another error message. 
I will now use the sql_db_schema tool to find out the column names in the 'Invoice' table.\n", + "{'tool_name': 'sql_db_schema', 'parameters': {'table_names': 'Invoice'}}\n", + "\u001B[0m\u001B[33;1m\u001B[1;3m\n", + "CREATE TABLE \"Invoice\" (\n", + "\t\"InvoiceId\" INTEGER NOT NULL, \n", + "\t\"CustomerId\" INTEGER NOT NULL, \n", + "\t\"InvoiceDate\" DATETIME NOT NULL, \n", + "\t\"BillingAddress\" NVARCHAR(70), \n", + "\t\"BillingCity\" NVARCHAR(40), \n", + "\t\"BillingState\" NVARCHAR(40), \n", + "\t\"BillingCountry\" NVARCHAR(40), \n", + "\t\"BillingPostalCode\" NVARCHAR(10), \n", + "\t\"Total\" NUMERIC(10, 2) NOT NULL, \n", + "\tPRIMARY KEY (\"InvoiceId\"), \n", + "\tFOREIGN KEY(\"CustomerId\") REFERENCES \"Customer\" (\"CustomerId\")\n", + ")\n", + "\n", + "/*\n", + "3 rows from Invoice table:\n", + "InvoiceId\tCustomerId\tInvoiceDate\tBillingAddress\tBillingCity\tBillingState\tBillingCountry\tBillingPostalCode\tTotal\n", + "1\t2\t2021-01-01 00:00:00\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t1.98\n", + "2\t4\t2021-01-02 00:00:00\tUllevålsveien 14\tOslo\tNone\tNorway\t0171\t3.96\n", + "3\t8\t2021-01-03 00:00:00\tGrétrystraat 63\tBrussels\tNone\tBelgium\t1000\t5.94\n", + "*/\u001B[0m\u001B[32;1m\u001B[1;3m\n", + "I found out that the column name for country is 'BillingCountry'. 
I will now rewrite my query and run it again.\n", + "{'tool_name': 'sql_db_query', 'parameters': {'query': 'SELECT BillingCountry AS country, COUNT(*) AS invoice_count FROM Invoice GROUP BY country ORDER BY invoice_count DESC'}}\n", + "\u001B[0m\u001B[36;1m\u001B[1;3m[('USA', 91), ('Canada', 56), ('France', 35), ('Brazil', 35), ('Germany', 28), ('United Kingdom', 21), ('Portugal', 14), ('Czech Republic', 14), ('India', 13), ('Sweden', 7), ('Spain', 7), ('Poland', 7), ('Norway', 7), ('Netherlands', 7), ('Italy', 7), ('Ireland', 7), ('Hungary', 7), ('Finland', 7), ('Denmark', 7), ('Chile', 7), ('Belgium', 7), ('Austria', 7), ('Australia', 7), ('Argentina', 7)]\u001B[0m\u001B[32;1m\u001B[1;3mRelevant Documents: 1,3,4\n", + "Cited Documents: 4\n", + "Answer: The countries with the most invoices are the USA (91), Canada (56), and France (35).\n", + "Grounded answer: The countries with the most invoices are the USA (91), Canada (56), and France (35).\u001B[0m\n", + "\n", + "\u001B[1m> Finished chain.\u001B[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "{'input': 'which countries have the most invoices?',\n", + " 'output': 'The countries with the most invoices are the USA (91), Canada (56), and France (35).',\n", + " 'citations': [CohereCitation(start=45, end=52, text='USA (91', documents=[{'output': \"[('USA', 91), ('Canada', 56), ('France', 35), ('Brazil', 35), ('Germany', 28), ('United Kingdom', 21), ('Portugal', 14), ('Czech Republic', 14), ('India', 13), ('Sweden', 7), ('Spain', 7), ('Poland', 7), ('Norway', 7), ('Netherlands', 7), ('Italy', 7), ('Ireland', 7), ('Hungary', 7), ('Finland', 7), ('Denmark', 7), ('Chile', 7), ('Belgium', 7), ('Austria', 7), ('Australia', 7), ('Argentina', 7)]\"}]),\n", + " CohereCitation(start=55, end=65, text='Canada (56', documents=[{'output': \"[('USA', 91), ('Canada', 56), ('France', 35), ('Brazil', 35), ('Germany', 28), ('United Kingdom', 21), ('Portugal', 14), ('Czech Republic', 14), ('India', 13), ('Sweden', 7), ('Spain', 7), 
('Poland', 7), ('Norway', 7), ('Netherlands', 7), ('Italy', 7), ('Ireland', 7), ('Hungary', 7), ('Finland', 7), ('Denmark', 7), ('Chile', 7), ('Belgium', 7), ('Austria', 7), ('Australia', 7), ('Argentina', 7)]\"}]),\n", + " CohereCitation(start=72, end=82, text='France (35', documents=[{'output': \"[('USA', 91), ('Canada', 56), ('France', 35), ('Brazil', 35), ('Germany', 28), ('United Kingdom', 21), ('Portugal', 14), ('Czech Republic', 14), ('India', 13), ('Sweden', 7), ('Spain', 7), ('Poland', 7), ('Norway', 7), ('Netherlands', 7), ('Italy', 7), ('Ireland', 7), ('Hungary', 7), ('Finland', 7), ('Denmark', 7), ('Chile', 7), ('Belgium', 7), ('Austria', 7), ('Australia', 7), ('Argentina', 7)]\"}])],\n", + " 'intermediate_steps': [(AgentActionMessageLog(tool='sql_db_query', tool_input={'query': 'SELECT country, COUNT(*) AS invoice_count FROM invoices GROUP BY country ORDER BY invoice_count DESC'}, log=\"\\nI will search for the number of invoices per country and then write an answer based on the results.\\n{'tool_name': 'sql_db_query', 'parameters': {'query': 'SELECT country, COUNT(*) AS invoice_count FROM invoices GROUP BY country ORDER BY invoice_count DESC'}}\\n\", message_log=[AIMessage(content='\\nPlan: I will search for the number of invoices per country and then write an answer based on the results.\\nAction: ```json\\n[\\n {\\n \"tool_name\": \"sql_db_query\",\\n \"parameters\": {\\n \"query\": \"SELECT country, COUNT(*) AS invoice_count FROM invoices GROUP BY country ORDER BY invoice_count DESC\"\\n }\\n }\\n]\\n```')]),\n", + " 'Error: (sqlite3.OperationalError) no such table: invoices\\n[SQL: SELECT country, COUNT(*) AS invoice_count FROM invoices GROUP BY country ORDER BY invoice_count DESC]\\n(Background on this error at: https://sqlalche.me/e/20/e3q8)'),\n", + " (AgentActionMessageLog(tool='sql_db_list_tables', tool_input={'tool_input': ''}, log=\"\\nI received an error message. 
I will now use the sql_db_list_tables tool to find out the name of the table containing the invoices data.\\n{'tool_name': 'sql_db_list_tables', 'parameters': {'tool_input': ''}}\\n\", message_log=[AIMessage(content='\\nReflection: I received an error message. I will now use the sql_db_list_tables tool to find out the name of the table containing the invoices data.\\nAction: ```json\\n[\\n {\\n \"tool_name\": \"sql_db_list_tables\",\\n \"parameters\": {\\n \"tool_input\": \"\"\\n }\\n }\\n]\\n```')]),\n", + " 'Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track'),\n", + " (AgentActionMessageLog(tool='sql_db_query', tool_input={'query': 'SELECT country, COUNT(*) AS invoice_count FROM Invoice GROUP BY country ORDER BY invoice_count DESC'}, log=\"\\nI found out that the table name is 'Invoice'. I will now rewrite my query and run it again.\\n{'tool_name': 'sql_db_query', 'parameters': {'query': 'SELECT country, COUNT(*) AS invoice_count FROM Invoice GROUP BY country ORDER BY invoice_count DESC'}}\\n\", message_log=[AIMessage(content='\\nReflection: I found out that the table name is \\'Invoice\\'. I will now rewrite my query and run it again.\\nAction: ```json\\n[\\n {\\n \"tool_name\": \"sql_db_query\",\\n \"parameters\": {\\n \"query\": \"SELECT country, COUNT(*) AS invoice_count FROM Invoice GROUP BY country ORDER BY invoice_count DESC\"\\n }\\n }\\n]\\n```')]),\n", + " 'Error: (sqlite3.OperationalError) no such column: country\\n[SQL: SELECT country, COUNT(*) AS invoice_count FROM Invoice GROUP BY country ORDER BY invoice_count DESC]\\n(Background on this error at: https://sqlalche.me/e/20/e3q8)'),\n", + " (AgentActionMessageLog(tool='sql_db_schema', tool_input={'table_names': 'Invoice'}, log=\"\\nI received another error message. 
I will now use the sql_db_schema tool to find out the column names in the 'Invoice' table.\\n{'tool_name': 'sql_db_schema', 'parameters': {'table_names': 'Invoice'}}\\n\", message_log=[AIMessage(content='\\nReflection: I received another error message. I will now use the sql_db_schema tool to find out the column names in the \\'Invoice\\' table.\\nAction: ```json\\n[\\n {\\n \"tool_name\": \"sql_db_schema\",\\n \"parameters\": {\\n \"table_names\": \"Invoice\"\\n }\\n }\\n]\\n```')]),\n", + " '\\nCREATE TABLE \"Invoice\" (\\n\\t\"InvoiceId\" INTEGER NOT NULL, \\n\\t\"CustomerId\" INTEGER NOT NULL, \\n\\t\"InvoiceDate\" DATETIME NOT NULL, \\n\\t\"BillingAddress\" NVARCHAR(70), \\n\\t\"BillingCity\" NVARCHAR(40), \\n\\t\"BillingState\" NVARCHAR(40), \\n\\t\"BillingCountry\" NVARCHAR(40), \\n\\t\"BillingPostalCode\" NVARCHAR(10), \\n\\t\"Total\" NUMERIC(10, 2) NOT NULL, \\n\\tPRIMARY KEY (\"InvoiceId\"), \\n\\tFOREIGN KEY(\"CustomerId\") REFERENCES \"Customer\" (\"CustomerId\")\\n)\\n\\n/*\\n3 rows from Invoice table:\\nInvoiceId\\tCustomerId\\tInvoiceDate\\tBillingAddress\\tBillingCity\\tBillingState\\tBillingCountry\\tBillingPostalCode\\tTotal\\n1\\t2\\t2021-01-01 00:00:00\\tTheodor-Heuss-Straße 34\\tStuttgart\\tNone\\tGermany\\t70174\\t1.98\\n2\\t4\\t2021-01-02 00:00:00\\tUllevålsveien 14\\tOslo\\tNone\\tNorway\\t0171\\t3.96\\n3\\t8\\t2021-01-03 00:00:00\\tGrétrystraat 63\\tBrussels\\tNone\\tBelgium\\t1000\\t5.94\\n*/'),\n", + " (AgentActionMessageLog(tool='sql_db_query', tool_input={'query': 'SELECT BillingCountry AS country, COUNT(*) AS invoice_count FROM Invoice GROUP BY country ORDER BY invoice_count DESC'}, log=\"\\nI found out that the column name for country is 'BillingCountry'. 
I will now rewrite my query and run it again.\\n{'tool_name': 'sql_db_query', 'parameters': {'query': 'SELECT BillingCountry AS country, COUNT(*) AS invoice_count FROM Invoice GROUP BY country ORDER BY invoice_count DESC'}}\\n\", message_log=[AIMessage(content='\\nReflection: I found out that the column name for country is \\'BillingCountry\\'. I will now rewrite my query and run it again.\\nAction: ```json\\n[\\n {\\n \"tool_name\": \"sql_db_query\",\\n \"parameters\": {\\n \"query\": \"SELECT BillingCountry AS country, COUNT(*) AS invoice_count FROM Invoice GROUP BY country ORDER BY invoice_count DESC\"\\n }\\n }\\n]\\n```')]),\n", + " \"[('USA', 91), ('Canada', 56), ('France', 35), ('Brazil', 35), ('Germany', 28), ('United Kingdom', 21), ('Portugal', 14), ('Czech Republic', 14), ('India', 13), ('Sweden', 7), ('Spain', 7), ('Poland', 7), ('Norway', 7), ('Netherlands', 7), ('Italy', 7), ('Ireland', 7), ('Hungary', 7), ('Finland', 7), ('Denmark', 7), ('Chile', 7), ('Belgium', 7), ('Austria', 7), ('Australia', 7), ('Argentina', 7)]\")]}" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent_executor.invoke({\n", + " \"input\": 'which countries have the most invoices?',\n", + "})" + ] + }, + { + "cell_type": "markdown", + "id": "634500b5-8980-4ee9-93a3-172130feeb5a", + "metadata": {}, + "source": [ + "The agent initially makes some errors because it jumps straight to answering the question with the sql_db_query tool, but it then realizes it needs to figure out what tables it has access to and what they look like. 
It then fixes the SQL code and is able to generate the right answer." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "95144085-9a57-4b7d-ac8e-c0996c0d3e81", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001B[1m> Entering new AgentExecutor chain...\u001B[0m\n", + "\u001B[32;1m\u001B[1;3m\n", + "I will search for the customer who has spent the most money and write an answer based on the results.\n", + "{'tool_name': 'sql_db_query', 'parameters': {'query': 'SELECT customer_name, SUM(total_price) AS total_spent FROM orders GROUP BY customer_name ORDER BY total_spent DESC LIMIT 1;'}}\n", + "\u001B[0m\u001B[36;1m\u001B[1;3mError: (sqlite3.OperationalError) no such table: orders\n", + "[SQL: SELECT customer_name, SUM(total_price) AS total_spent FROM orders GROUP BY customer_name ORDER BY total_spent DESC LIMIT 1;]\n", + "(Background on this error at: https://sqlalche.me/e/20/e3q8)\u001B[0m\u001B[32;1m\u001B[1;3m\n", + "I received an error message. I will now use the sql_db_list_tables tool to find out which tables are available.\n", + "{'tool_name': 'sql_db_list_tables', 'parameters': {'tool_input': ''}}\n", + "\u001B[0m\u001B[38;5;200m\u001B[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\u001B[0m\u001B[32;1m\u001B[1;3m\n", + "I found that there is a 'Customer' table and an 'Invoice' table. 
I will now use the sql_db_schema tool to find out more about these tables.\n", + "{'tool_name': 'sql_db_schema', 'parameters': {'table_names': 'Customer,Invoice'}}\n", + "\u001B[0m\u001B[33;1m\u001B[1;3m\n", + "CREATE TABLE \"Customer\" (\n", + "\t\"CustomerId\" INTEGER NOT NULL, \n", + "\t\"FirstName\" NVARCHAR(40) NOT NULL, \n", + "\t\"LastName\" NVARCHAR(20) NOT NULL, \n", + "\t\"Company\" NVARCHAR(80), \n", + "\t\"Address\" NVARCHAR(70), \n", + "\t\"City\" NVARCHAR(40), \n", + "\t\"State\" NVARCHAR(40), \n", + "\t\"Country\" NVARCHAR(40), \n", + "\t\"PostalCode\" NVARCHAR(10), \n", + "\t\"Phone\" NVARCHAR(24), \n", + "\t\"Fax\" NVARCHAR(24), \n", + "\t\"Email\" NVARCHAR(60) NOT NULL, \n", + "\t\"SupportRepId\" INTEGER, \n", + "\tPRIMARY KEY (\"CustomerId\"), \n", + "\tFOREIGN KEY(\"SupportRepId\") REFERENCES \"Employee\" (\"EmployeeId\")\n", + ")\n", + "\n", + "/*\n", + "3 rows from Customer table:\n", + "CustomerId\tFirstName\tLastName\tCompany\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\tSupportRepId\n", + "1\tLuís\tGonçalves\tEmbraer - Empresa Brasileira de Aeronáutica S.A.\tAv. 
Brigadeiro Faria Lima, 2170\tSão José dos Campos\tSP\tBrazil\t12227-000\t+55 (12) 3923-5555\t+55 (12) 3923-5566\tluisg@embraer.com.br\t3\n", + "2\tLeonie\tKöhler\tNone\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t+49 0711 2842222\tNone\tleonekohler@surfeu.de\t5\n", + "3\tFrançois\tTremblay\tNone\t1498 rue Bélanger\tMontréal\tQC\tCanada\tH2G 1A7\t+1 (514) 721-4711\tNone\tftremblay@gmail.com\t3\n", + "*/\n", + "\n", + "\n", + "CREATE TABLE \"Invoice\" (\n", + "\t\"InvoiceId\" INTEGER NOT NULL, \n", + "\t\"CustomerId\" INTEGER NOT NULL, \n", + "\t\"InvoiceDate\" DATETIME NOT NULL, \n", + "\t\"BillingAddress\" NVARCHAR(70), \n", + "\t\"BillingCity\" NVARCHAR(40), \n", + "\t\"BillingState\" NVARCHAR(40), \n", + "\t\"BillingCountry\" NVARCHAR(40), \n", + "\t\"BillingPostalCode\" NVARCHAR(10), \n", + "\t\"Total\" NUMERIC(10, 2) NOT NULL, \n", + "\tPRIMARY KEY (\"InvoiceId\"), \n", + "\tFOREIGN KEY(\"CustomerId\") REFERENCES \"Customer\" (\"CustomerId\")\n", + ")\n", + "\n", + "/*\n", + "3 rows from Invoice table:\n", + "InvoiceId\tCustomerId\tInvoiceDate\tBillingAddress\tBillingCity\tBillingState\tBillingCountry\tBillingPostalCode\tTotal\n", + "1\t2\t2021-01-01 00:00:00\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t1.98\n", + "2\t4\t2021-01-02 00:00:00\tUllevålsveien 14\tOslo\tNone\tNorway\t0171\t3.96\n", + "3\t8\t2021-01-03 00:00:00\tGrétrystraat 63\tBrussels\tNone\tBelgium\t1000\t5.94\n", + "*/\u001B[0m\u001B[32;1m\u001B[1;3m\n", + "I found that the 'Customer' table contains information about customers, including their names, addresses, and phone numbers. The 'Invoice' table contains information about invoices, including the customer ID, invoice date, and total amount. 
I will now write a new query to find the customer who has spent the most money.\n", + "{'tool_name': 'sql_db_query', 'parameters': {'query': 'SELECT c.FirstName, c.LastName, SUM(i.Total) AS total_spent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.CustomerId ORDER BY total_spent DESC LIMIT 1;'}}\n", + "\u001B[0m\u001B[36;1m\u001B[1;3m[('Helena', 'Holý', 49.62)]\u001B[0m\u001B[32;1m\u001B[1;3mRelevant Documents: 1,2,3\n", + "Cited Documents: 3\n", + "Answer: The best customer is Helena Holý, who has spent a total of 49.62.\n", + "Grounded answer: The best customer is Helena Holý, who has spent a total of 49.62.\u001B[0m\n", + "\n", + "\u001B[1m> Finished chain.\u001B[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "{'input': 'who is the best customer? The customer who has spent the most money is the best.',\n", + " 'output': 'The best customer is Helena Holý, who has spent a total of 49.62.',\n", + " 'citations': [CohereCitation(start=21, end=32, text='Helena Holý', documents=[{'output': \"[('Helena', 'Holý', 49.62)]\"}]),\n", + " CohereCitation(start=59, end=64, text='49.62', documents=[{'output': \"[('Helena', 'Holý', 49.62)]\"}])],\n", + " 'intermediate_steps': [(AgentActionMessageLog(tool='sql_db_query', tool_input={'query': 'SELECT customer_name, SUM(total_price) AS total_spent FROM orders GROUP BY customer_name ORDER BY total_spent DESC LIMIT 1;'}, log=\"\\nI will search for the customer who has spent the most money and write an answer based on the results.\\n{'tool_name': 'sql_db_query', 'parameters': {'query': 'SELECT customer_name, SUM(total_price) AS total_spent FROM orders GROUP BY customer_name ORDER BY total_spent DESC LIMIT 1;'}}\\n\", message_log=[AIMessage(content='\\nPlan: I will search for the customer who has spent the most money and write an answer based on the results.\\nAction: ```json\\n[\\n {\\n \"tool_name\": \"sql_db_query\",\\n \"parameters\": {\\n \"query\": \"SELECT customer_name, SUM(total_price) AS 
total_spent FROM orders GROUP BY customer_name ORDER BY total_spent DESC LIMIT 1;\"\\n }\\n }\\n]\\n```')]),\n", + " 'Error: (sqlite3.OperationalError) no such table: orders\\n[SQL: SELECT customer_name, SUM(total_price) AS total_spent FROM orders GROUP BY customer_name ORDER BY total_spent DESC LIMIT 1;]\\n(Background on this error at: https://sqlalche.me/e/20/e3q8)'),\n", + " (AgentActionMessageLog(tool='sql_db_list_tables', tool_input={'tool_input': ''}, log=\"\\nI received an error message. I will now use the sql_db_list_tables tool to find out which tables are available.\\n{'tool_name': 'sql_db_list_tables', 'parameters': {'tool_input': ''}}\\n\", message_log=[AIMessage(content='\\nReflection: I received an error message. I will now use the sql_db_list_tables tool to find out which tables are available.\\nAction: ```json\\n[\\n {\\n \"tool_name\": \"sql_db_list_tables\",\\n \"parameters\": {\\n \"tool_input\": \"\"\\n }\\n }\\n]\\n```')]),\n", + " 'Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track'),\n", + " (AgentActionMessageLog(tool='sql_db_schema', tool_input={'table_names': 'Customer,Invoice'}, log=\"\\nI found that there is a 'Customer' table and an 'Invoice' table. I will now use the sql_db_schema tool to find out more about these tables.\\n{'tool_name': 'sql_db_schema', 'parameters': {'table_names': 'Customer,Invoice'}}\\n\", message_log=[AIMessage(content='\\nReflection: I found that there is a \\'Customer\\' table and an \\'Invoice\\' table. 
I will now use the sql_db_schema tool to find out more about these tables.\\nAction: ```json\\n[\\n {\\n \"tool_name\": \"sql_db_schema\",\\n \"parameters\": {\\n \"table_names\": \"Customer,Invoice\"\\n }\\n }\\n]\\n```')]),\n", + " '\\nCREATE TABLE \"Customer\" (\\n\\t\"CustomerId\" INTEGER NOT NULL, \\n\\t\"FirstName\" NVARCHAR(40) NOT NULL, \\n\\t\"LastName\" NVARCHAR(20) NOT NULL, \\n\\t\"Company\" NVARCHAR(80), \\n\\t\"Address\" NVARCHAR(70), \\n\\t\"City\" NVARCHAR(40), \\n\\t\"State\" NVARCHAR(40), \\n\\t\"Country\" NVARCHAR(40), \\n\\t\"PostalCode\" NVARCHAR(10), \\n\\t\"Phone\" NVARCHAR(24), \\n\\t\"Fax\" NVARCHAR(24), \\n\\t\"Email\" NVARCHAR(60) NOT NULL, \\n\\t\"SupportRepId\" INTEGER, \\n\\tPRIMARY KEY (\"CustomerId\"), \\n\\tFOREIGN KEY(\"SupportRepId\") REFERENCES \"Employee\" (\"EmployeeId\")\\n)\\n\\n/*\\n3 rows from Customer table:\\nCustomerId\\tFirstName\\tLastName\\tCompany\\tAddress\\tCity\\tState\\tCountry\\tPostalCode\\tPhone\\tFax\\tEmail\\tSupportRepId\\n1\\tLuís\\tGonçalves\\tEmbraer - Empresa Brasileira de Aeronáutica S.A.\\tAv. 
Brigadeiro Faria Lima, 2170\\tSão José dos Campos\\tSP\\tBrazil\\t12227-000\\t+55 (12) 3923-5555\\t+55 (12) 3923-5566\\tluisg@embraer.com.br\\t3\\n2\\tLeonie\\tKöhler\\tNone\\tTheodor-Heuss-Straße 34\\tStuttgart\\tNone\\tGermany\\t70174\\t+49 0711 2842222\\tNone\\tleonekohler@surfeu.de\\t5\\n3\\tFrançois\\tTremblay\\tNone\\t1498 rue Bélanger\\tMontréal\\tQC\\tCanada\\tH2G 1A7\\t+1 (514) 721-4711\\tNone\\tftremblay@gmail.com\\t3\\n*/\\n\\n\\nCREATE TABLE \"Invoice\" (\\n\\t\"InvoiceId\" INTEGER NOT NULL, \\n\\t\"CustomerId\" INTEGER NOT NULL, \\n\\t\"InvoiceDate\" DATETIME NOT NULL, \\n\\t\"BillingAddress\" NVARCHAR(70), \\n\\t\"BillingCity\" NVARCHAR(40), \\n\\t\"BillingState\" NVARCHAR(40), \\n\\t\"BillingCountry\" NVARCHAR(40), \\n\\t\"BillingPostalCode\" NVARCHAR(10), \\n\\t\"Total\" NUMERIC(10, 2) NOT NULL, \\n\\tPRIMARY KEY (\"InvoiceId\"), \\n\\tFOREIGN KEY(\"CustomerId\") REFERENCES \"Customer\" (\"CustomerId\")\\n)\\n\\n/*\\n3 rows from Invoice table:\\nInvoiceId\\tCustomerId\\tInvoiceDate\\tBillingAddress\\tBillingCity\\tBillingState\\tBillingCountry\\tBillingPostalCode\\tTotal\\n1\\t2\\t2021-01-01 00:00:00\\tTheodor-Heuss-Straße 34\\tStuttgart\\tNone\\tGermany\\t70174\\t1.98\\n2\\t4\\t2021-01-02 00:00:00\\tUllevålsveien 14\\tOslo\\tNone\\tNorway\\t0171\\t3.96\\n3\\t8\\t2021-01-03 00:00:00\\tGrétrystraat 63\\tBrussels\\tNone\\tBelgium\\t1000\\t5.94\\n*/'),\n", + " (AgentActionMessageLog(tool='sql_db_query', tool_input={'query': 'SELECT c.FirstName, c.LastName, SUM(i.Total) AS total_spent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.CustomerId ORDER BY total_spent DESC LIMIT 1;'}, log=\"\\nI found that the 'Customer' table contains information about customers, including their names, addresses, and phone numbers. The 'Invoice' table contains information about invoices, including the customer ID, invoice date, and total amount. 
I will now write a new query to find the customer who has spent the most money.\\n{'tool_name': 'sql_db_query', 'parameters': {'query': 'SELECT c.FirstName, c.LastName, SUM(i.Total) AS total_spent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.CustomerId ORDER BY total_spent DESC LIMIT 1;'}}\\n\", message_log=[AIMessage(content='\\nReflection: I found that the \\'Customer\\' table contains information about customers, including their names, addresses, and phone numbers. The \\'Invoice\\' table contains information about invoices, including the customer ID, invoice date, and total amount. I will now write a new query to find the customer who has spent the most money.\\nAction: ```json\\n[\\n {\\n \"tool_name\": \"sql_db_query\",\\n \"parameters\": {\\n \"query\": \"SELECT c.FirstName, c.LastName, SUM(i.Total) AS total_spent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.CustomerId ORDER BY total_spent DESC LIMIT 1;\"\\n }\\n }\\n]\\n```')]),\n", + " \"[('Helena', 'Holý', 49.62)]\")]}" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent_executor.invoke({\n", + " \"input\": 'who is the best customer? The customer who has spent the most money is the best.',\n", + "})" + ] + }, + { + "cell_type": "markdown", + "id": "14ca1ab1-1deb-4c25-bc2d-6f047bcb49d9", + "metadata": {}, + "source": [ + "As you can see, the agent makes an error but is able to correct itself. It also manages to generate a SQL query that joins two tables in the database." + ] + }, + { + "cell_type": "markdown", + "id": "1c9ea4d9-8f92-4780-b33e-e348b2fa854f", + "metadata": {}, + "source": [ + "\n", + "# SQL Agent with context" + ] + }, + { + "cell_type": "markdown", + "id": "1883f0f6-c1b9-41d0-9049-68d2101f8a9e", + "metadata": {}, + "source": [ + "From our experiments, we have found that passing in additional context to the preamble can help reduce the initial failures. 
This context is provided by the SQLDBToolkit and contains the schema and the first 3 rows of each table in the database." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "fd781ff2-ee5f-4fef-b1b4-fde58562383c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'table_info': '\\nCREATE TABLE \"Album\" (\\n\\t\"AlbumId\" INTEGER NOT NULL, \\n\\t\"Title\" NVARCHAR(160) NOT NULL, \\n\\t\"ArtistId\" INTEGER NOT NULL, \\n\\tPRIMARY KEY (\"AlbumId\"), \\n\\tFOREIGN KEY(\"ArtistId\") REFERENCES \"Artist\" (\"ArtistId\")\\n)\\n\\n/*\\n3 rows from Album table:\\nAlbumId\\tTitle\\tArtistId\\n1\\tFor Those About To Rock We Salute You\\t1\\n2\\tBalls to the Wall\\t2\\n3\\tRestless and Wild\\t2\\n*/\\n\\n\\nCREATE TABLE \"Artist\" (\\n\\t\"ArtistId\" INTEGER NOT NULL, \\n\\t\"Name\" NVARCHAR(120), \\n\\tPRIMARY KEY (\"ArtistId\")\\n)\\n\\n/*\\n3 rows from Artist table:\\nArtistId\\tName\\n1\\tAC/DC\\n2\\tAccept\\n3\\tAerosmith\\n*/\\n\\n\\nCREATE TABLE \"Customer\" (\\n\\t\"CustomerId\" INTEGER NOT NULL, \\n\\t\"FirstName\" NVARCHAR(40) NOT NULL, \\n\\t\"LastName\" NVARCHAR(20) NOT NULL, \\n\\t\"Company\" NVARCHAR(80), \\n\\t\"Address\" NVARCHAR(70), \\n\\t\"City\" NVARCHAR(40), \\n\\t\"State\" NVARCHAR(40), \\n\\t\"Country\" NVARCHAR(40), \\n\\t\"PostalCode\" NVARCHAR(10), \\n\\t\"Phone\" NVARCHAR(24), \\n\\t\"Fax\" NVARCHAR(24), \\n\\t\"Email\" NVARCHAR(60) NOT NULL, \\n\\t\"SupportRepId\" INTEGER, \\n\\tPRIMARY KEY (\"CustomerId\"), \\n\\tFOREIGN KEY(\"SupportRepId\") REFERENCES \"Employee\" (\"EmployeeId\")\\n)\\n\\n/*\\n3 rows from Customer table:\\nCustomerId\\tFirstName\\tLastName\\tCompany\\tAddress\\tCity\\tState\\tCountry\\tPostalCode\\tPhone\\tFax\\tEmail\\tSupportRepId\\n1\\tLuís\\tGonçalves\\tEmbraer - Empresa Brasileira de Aeronáutica S.A.\\tAv. 
Brigadeiro Faria Lima, 2170\\tSão José dos Campos\\tSP\\tBrazil\\t12227-000\\t+55 (12) 3923-5555\\t+55 (12) 3923-5566\\tluisg@embraer.com.br\\t3\\n2\\tLeonie\\tKöhler\\tNone\\tTheodor-Heuss-Straße 34\\tStuttgart\\tNone\\tGermany\\t70174\\t+49 0711 2842222\\tNone\\tleonekohler@surfeu.de\\t5\\n3\\tFrançois\\tTremblay\\tNone\\t1498 rue Bélanger\\tMontréal\\tQC\\tCanada\\tH2G 1A7\\t+1 (514) 721-4711\\tNone\\tftremblay@gmail.com\\t3\\n*/\\n\\n\\nCREATE TABLE \"Employee\" (\\n\\t\"EmployeeId\" INTEGER NOT NULL, \\n\\t\"LastName\" NVARCHAR(20) NOT NULL, \\n\\t\"FirstName\" NVARCHAR(20) NOT NULL, \\n\\t\"Title\" NVARCHAR(30), \\n\\t\"ReportsTo\" INTEGER, \\n\\t\"BirthDate\" DATETIME, \\n\\t\"HireDate\" DATETIME, \\n\\t\"Address\" NVARCHAR(70), \\n\\t\"City\" NVARCHAR(40), \\n\\t\"State\" NVARCHAR(40), \\n\\t\"Country\" NVARCHAR(40), \\n\\t\"PostalCode\" NVARCHAR(10), \\n\\t\"Phone\" NVARCHAR(24), \\n\\t\"Fax\" NVARCHAR(24), \\n\\t\"Email\" NVARCHAR(60), \\n\\tPRIMARY KEY (\"EmployeeId\"), \\n\\tFOREIGN KEY(\"ReportsTo\") REFERENCES \"Employee\" (\"EmployeeId\")\\n)\\n\\n/*\\n3 rows from Employee table:\\nEmployeeId\\tLastName\\tFirstName\\tTitle\\tReportsTo\\tBirthDate\\tHireDate\\tAddress\\tCity\\tState\\tCountry\\tPostalCode\\tPhone\\tFax\\tEmail\\n1\\tAdams\\tAndrew\\tGeneral Manager\\tNone\\t1962-02-18 00:00:00\\t2002-08-14 00:00:00\\t11120 Jasper Ave NW\\tEdmonton\\tAB\\tCanada\\tT5K 2N1\\t+1 (780) 428-9482\\t+1 (780) 428-3457\\tandrew@chinookcorp.com\\n2\\tEdwards\\tNancy\\tSales Manager\\t1\\t1958-12-08 00:00:00\\t2002-05-01 00:00:00\\t825 8 Ave SW\\tCalgary\\tAB\\tCanada\\tT2P 2T3\\t+1 (403) 262-3443\\t+1 (403) 262-3322\\tnancy@chinookcorp.com\\n3\\tPeacock\\tJane\\tSales Support Agent\\t2\\t1973-08-29 00:00:00\\t2002-04-01 00:00:00\\t1111 6 Ave SW\\tCalgary\\tAB\\tCanada\\tT2P 5M5\\t+1 (403) 262-3443\\t+1 (403) 262-6712\\tjane@chinookcorp.com\\n*/\\n\\n\\nCREATE TABLE \"Genre\" (\\n\\t\"GenreId\" INTEGER NOT NULL, \\n\\t\"Name\" NVARCHAR(120), \\n\\tPRIMARY KEY 
(\"GenreId\")\\n)\\n\\n/*\\n3 rows from Genre table:\\nGenreId\\tName\\n1\\tRock\\n2\\tJazz\\n3\\tMetal\\n*/\\n\\n\\nCREATE TABLE \"Invoice\" (\\n\\t\"InvoiceId\" INTEGER NOT NULL, \\n\\t\"CustomerId\" INTEGER NOT NULL, \\n\\t\"InvoiceDate\" DATETIME NOT NULL, \\n\\t\"BillingAddress\" NVARCHAR(70), \\n\\t\"BillingCity\" NVARCHAR(40), \\n\\t\"BillingState\" NVARCHAR(40), \\n\\t\"BillingCountry\" NVARCHAR(40), \\n\\t\"BillingPostalCode\" NVARCHAR(10), \\n\\t\"Total\" NUMERIC(10, 2) NOT NULL, \\n\\tPRIMARY KEY (\"InvoiceId\"), \\n\\tFOREIGN KEY(\"CustomerId\") REFERENCES \"Customer\" (\"CustomerId\")\\n)\\n\\n/*\\n3 rows from Invoice table:\\nInvoiceId\\tCustomerId\\tInvoiceDate\\tBillingAddress\\tBillingCity\\tBillingState\\tBillingCountry\\tBillingPostalCode\\tTotal\\n1\\t2\\t2021-01-01 00:00:00\\tTheodor-Heuss-Straße 34\\tStuttgart\\tNone\\tGermany\\t70174\\t1.98\\n2\\t4\\t2021-01-02 00:00:00\\tUllevålsveien 14\\tOslo\\tNone\\tNorway\\t0171\\t3.96\\n3\\t8\\t2021-01-03 00:00:00\\tGrétrystraat 63\\tBrussels\\tNone\\tBelgium\\t1000\\t5.94\\n*/\\n\\n\\nCREATE TABLE \"InvoiceLine\" (\\n\\t\"InvoiceLineId\" INTEGER NOT NULL, \\n\\t\"InvoiceId\" INTEGER NOT NULL, \\n\\t\"TrackId\" INTEGER NOT NULL, \\n\\t\"UnitPrice\" NUMERIC(10, 2) NOT NULL, \\n\\t\"Quantity\" INTEGER NOT NULL, \\n\\tPRIMARY KEY (\"InvoiceLineId\"), \\n\\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \\n\\tFOREIGN KEY(\"InvoiceId\") REFERENCES \"Invoice\" (\"InvoiceId\")\\n)\\n\\n/*\\n3 rows from InvoiceLine table:\\nInvoiceLineId\\tInvoiceId\\tTrackId\\tUnitPrice\\tQuantity\\n1\\t1\\t2\\t0.99\\t1\\n2\\t1\\t4\\t0.99\\t1\\n3\\t2\\t6\\t0.99\\t1\\n*/\\n\\n\\nCREATE TABLE \"MediaType\" (\\n\\t\"MediaTypeId\" INTEGER NOT NULL, \\n\\t\"Name\" NVARCHAR(120), \\n\\tPRIMARY KEY (\"MediaTypeId\")\\n)\\n\\n/*\\n3 rows from MediaType table:\\nMediaTypeId\\tName\\n1\\tMPEG audio file\\n2\\tProtected AAC audio file\\n3\\tProtected MPEG-4 video file\\n*/\\n\\n\\nCREATE TABLE \"Playlist\" 
(\\n\\t\"PlaylistId\" INTEGER NOT NULL, \\n\\t\"Name\" NVARCHAR(120), \\n\\tPRIMARY KEY (\"PlaylistId\")\\n)\\n\\n/*\\n3 rows from Playlist table:\\nPlaylistId\\tName\\n1\\tMusic\\n2\\tMovies\\n3\\tTV Shows\\n*/\\n\\n\\nCREATE TABLE \"PlaylistTrack\" (\\n\\t\"PlaylistId\" INTEGER NOT NULL, \\n\\t\"TrackId\" INTEGER NOT NULL, \\n\\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \\n\\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \\n\\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\\n)\\n\\n/*\\n3 rows from PlaylistTrack table:\\nPlaylistId\\tTrackId\\n1\\t3402\\n1\\t3389\\n1\\t3390\\n*/\\n\\n\\nCREATE TABLE \"Track\" (\\n\\t\"TrackId\" INTEGER NOT NULL, \\n\\t\"Name\" NVARCHAR(200) NOT NULL, \\n\\t\"AlbumId\" INTEGER, \\n\\t\"MediaTypeId\" INTEGER NOT NULL, \\n\\t\"GenreId\" INTEGER, \\n\\t\"Composer\" NVARCHAR(220), \\n\\t\"Milliseconds\" INTEGER NOT NULL, \\n\\t\"Bytes\" INTEGER, \\n\\t\"UnitPrice\" NUMERIC(10, 2) NOT NULL, \\n\\tPRIMARY KEY (\"TrackId\"), \\n\\tFOREIGN KEY(\"MediaTypeId\") REFERENCES \"MediaType\" (\"MediaTypeId\"), \\n\\tFOREIGN KEY(\"GenreId\") REFERENCES \"Genre\" (\"GenreId\"), \\n\\tFOREIGN KEY(\"AlbumId\") REFERENCES \"Album\" (\"AlbumId\")\\n)\\n\\n/*\\n3 rows from Track table:\\nTrackId\\tName\\tAlbumId\\tMediaTypeId\\tGenreId\\tComposer\\tMilliseconds\\tBytes\\tUnitPrice\\n1\\tFor Those About To Rock (We Salute You)\\t1\\t1\\t1\\tAngus Young, Malcolm Young, Brian Johnson\\t343719\\t11170334\\t0.99\\n2\\tBalls to the Wall\\t2\\t2\\t1\\tU. Dirkschneider, W. Hoffmann, H. Frank, P. Baltes, S. Kaufmann, G. Hoffmann\\t342562\\t5510424\\t0.99\\n3\\tFast As a Shark\\t3\\t2\\t1\\tF. Baltes, S. Kaufman, U. Dirkscneider & W. 
Hoffman\\t230619\\t3990994\\t0.99\\n*/',\n", + " 'table_names': 'Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track'}" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "context" + ] + }, + { + "cell_type": "markdown", + "id": "d726ce1f-1b47-4b06-bd2e-8969c429dde9", + "metadata": {}, + "source": [ + "We can pass this context into the preamble and re-run a query to see how it performs." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "06913698-60ee-4b10-b28b-6462c565ad63", + "metadata": {}, + "outputs": [], + "source": [ + "preamble=\"\"\"## Task And Context\n", + "You use your advanced complex reasoning capabilities to help people by answering their questions and other requests interactively. You will be asked a very wide array of requests on all kinds of topics. You will be equipped with a wide range of search engines or similar tools to help you, which you use to research your answer. You may need to use multiple tools in parallel or sequentially to complete your task. 
You should focus on serving the user's needs as best you can, which will be wide-ranging.\n", + "\n", + "## Style Guide\n", + "Unless the user asks for a different style of answer, you should answer in full sentences, using proper grammar and spelling.\n", + "\n", + "## Additional Information\n", + "You are an expert who answers the user's question by creating SQL queries and executing them.\n", + "You are equipped with a number of relevant SQL tools.\n", + "\n", + "Here is information about the database:\n", + "{schema_info}\n", + "\"\"\".format(schema_info=context)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "ad461d18-6ae9-40b5-8c85-a48ab9aa6bc0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001B[1m> Entering new AgentExecutor chain...\u001B[0m\n", + "\u001B[32;1m\u001B[1;3m\n", + "I will write a SQL query to find the customer who has spent the most money.\n", + "{'tool_name': 'sql_db_query', 'parameters': {'query': 'SELECT c.CustomerId, c.FirstName, c.LastName, SUM(i.Total) AS TotalSpent\\nFROM Customer c\\nJOIN Invoice i ON c.CustomerId = i.CustomerId\\nGROUP BY c.CustomerId\\nORDER BY TotalSpent DESC\\nLIMIT 1;'}}\n", + "\u001B[0m\u001B[36;1m\u001B[1;3m[(6, 'Helena', 'Holý', 49.62)]\u001B[0m\u001B[32;1m\u001B[1;3mRelevant Documents: 0\n", + "Cited Documents: 0\n", + "Answer: The customer who has spent the most money is Helena Holý.\n", + "Grounded answer: The customer who has spent the most money is Helena Holý.\u001B[0m\n", + "\n", + "\u001B[1m> Finished chain.\u001B[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "{'input': 'provide the name of the best customer? The customer who has spent the most money is the best.',\n", + " 'preamble': '## Task And Context\\nYou use your advanced complex reasoning capabilities to help people by answering their questions and other requests interactively. 
You will be asked a very wide array of requests on all kinds of topics. You will be equipped with a wide range of search engines or similar tools to help you, which you use to research your answer. You may need to use multiple tools in parallel or sequentially to complete your task. You should focus on serving the user\\'s needs as best you can, which will be wide-ranging.\\n\\n## Style Guide\\nUnless the user asks for a different style of answer, you should answer in full sentences, using proper grammar and spelling.\\n\\n## Additional Information\\nYou are an expert who answers the user\\'s question by creating SQL queries and executing them.\\nYou are equipped with a number of relevant SQL tools.\\n\\nHere is information about the database:\\n{\\'table_info\\': \\'\\\\nCREATE TABLE \"Album\" (\\\\n\\\\t\"AlbumId\" INTEGER NOT NULL, \\\\n\\\\t\"Title\" NVARCHAR(160) NOT NULL, \\\\n\\\\t\"ArtistId\" INTEGER NOT NULL, \\\\n\\\\tPRIMARY KEY (\"AlbumId\"), \\\\n\\\\tFOREIGN KEY(\"ArtistId\") REFERENCES \"Artist\" (\"ArtistId\")\\\\n)\\\\n\\\\n/*\\\\n3 rows from Album table:\\\\nAlbumId\\\\tTitle\\\\tArtistId\\\\n1\\\\tFor Those About To Rock We Salute You\\\\t1\\\\n2\\\\tBalls to the Wall\\\\t2\\\\n3\\\\tRestless and Wild\\\\t2\\\\n*/\\\\n\\\\n\\\\nCREATE TABLE \"Artist\" (\\\\n\\\\t\"ArtistId\" INTEGER NOT NULL, \\\\n\\\\t\"Name\" NVARCHAR(120), \\\\n\\\\tPRIMARY KEY (\"ArtistId\")\\\\n)\\\\n\\\\n/*\\\\n3 rows from Artist table:\\\\nArtistId\\\\tName\\\\n1\\\\tAC/DC\\\\n2\\\\tAccept\\\\n3\\\\tAerosmith\\\\n*/\\\\n\\\\n\\\\nCREATE TABLE \"Customer\" (\\\\n\\\\t\"CustomerId\" INTEGER NOT NULL, \\\\n\\\\t\"FirstName\" NVARCHAR(40) NOT NULL, \\\\n\\\\t\"LastName\" NVARCHAR(20) NOT NULL, \\\\n\\\\t\"Company\" NVARCHAR(80), \\\\n\\\\t\"Address\" NVARCHAR(70), \\\\n\\\\t\"City\" NVARCHAR(40), \\\\n\\\\t\"State\" NVARCHAR(40), \\\\n\\\\t\"Country\" NVARCHAR(40), \\\\n\\\\t\"PostalCode\" NVARCHAR(10), \\\\n\\\\t\"Phone\" NVARCHAR(24), \\\\n\\\\t\"Fax\" NVARCHAR(24), 
\\\\n\\\\t\"Email\" NVARCHAR(60) NOT NULL, \\\\n\\\\t\"SupportRepId\" INTEGER, \\\\n\\\\tPRIMARY KEY (\"CustomerId\"), \\\\n\\\\tFOREIGN KEY(\"SupportRepId\") REFERENCES \"Employee\" (\"EmployeeId\")\\\\n)\\\\n\\\\n/*\\\\n3 rows from Customer table:\\\\nCustomerId\\\\tFirstName\\\\tLastName\\\\tCompany\\\\tAddress\\\\tCity\\\\tState\\\\tCountry\\\\tPostalCode\\\\tPhone\\\\tFax\\\\tEmail\\\\tSupportRepId\\\\n1\\\\tLuís\\\\tGonçalves\\\\tEmbraer - Empresa Brasileira de Aeronáutica S.A.\\\\tAv. Brigadeiro Faria Lima, 2170\\\\tSão José dos Campos\\\\tSP\\\\tBrazil\\\\t12227-000\\\\t+55 (12) 3923-5555\\\\t+55 (12) 3923-5566\\\\tluisg@embraer.com.br\\\\t3\\\\n2\\\\tLeonie\\\\tKöhler\\\\tNone\\\\tTheodor-Heuss-Straße 34\\\\tStuttgart\\\\tNone\\\\tGermany\\\\t70174\\\\t+49 0711 2842222\\\\tNone\\\\tleonekohler@surfeu.de\\\\t5\\\\n3\\\\tFrançois\\\\tTremblay\\\\tNone\\\\t1498 rue Bélanger\\\\tMontréal\\\\tQC\\\\tCanada\\\\tH2G 1A7\\\\t+1 (514) 721-4711\\\\tNone\\\\tftremblay@gmail.com\\\\t3\\\\n*/\\\\n\\\\n\\\\nCREATE TABLE \"Employee\" (\\\\n\\\\t\"EmployeeId\" INTEGER NOT NULL, \\\\n\\\\t\"LastName\" NVARCHAR(20) NOT NULL, \\\\n\\\\t\"FirstName\" NVARCHAR(20) NOT NULL, \\\\n\\\\t\"Title\" NVARCHAR(30), \\\\n\\\\t\"ReportsTo\" INTEGER, \\\\n\\\\t\"BirthDate\" DATETIME, \\\\n\\\\t\"HireDate\" DATETIME, \\\\n\\\\t\"Address\" NVARCHAR(70), \\\\n\\\\t\"City\" NVARCHAR(40), \\\\n\\\\t\"State\" NVARCHAR(40), \\\\n\\\\t\"Country\" NVARCHAR(40), \\\\n\\\\t\"PostalCode\" NVARCHAR(10), \\\\n\\\\t\"Phone\" NVARCHAR(24), \\\\n\\\\t\"Fax\" NVARCHAR(24), \\\\n\\\\t\"Email\" NVARCHAR(60), \\\\n\\\\tPRIMARY KEY (\"EmployeeId\"), \\\\n\\\\tFOREIGN KEY(\"ReportsTo\") REFERENCES \"Employee\" (\"EmployeeId\")\\\\n)\\\\n\\\\n/*\\\\n3 rows from Employee table:\\\\nEmployeeId\\\\tLastName\\\\tFirstName\\\\tTitle\\\\tReportsTo\\\\tBirthDate\\\\tHireDate\\\\tAddress\\\\tCity\\\\tState\\\\tCountry\\\\tPostalCode\\\\tPhone\\\\tFax\\\\tEmail\\\\n1\\\\tAdams\\\\tAndrew\\\\tGeneral 
Manager\\\\tNone\\\\t1962-02-18 00:00:00\\\\t2002-08-14 00:00:00\\\\t11120 Jasper Ave NW\\\\tEdmonton\\\\tAB\\\\tCanada\\\\tT5K 2N1\\\\t+1 (780) 428-9482\\\\t+1 (780) 428-3457\\\\tandrew@chinookcorp.com\\\\n2\\\\tEdwards\\\\tNancy\\\\tSales Manager\\\\t1\\\\t1958-12-08 00:00:00\\\\t2002-05-01 00:00:00\\\\t825 8 Ave SW\\\\tCalgary\\\\tAB\\\\tCanada\\\\tT2P 2T3\\\\t+1 (403) 262-3443\\\\t+1 (403) 262-3322\\\\tnancy@chinookcorp.com\\\\n3\\\\tPeacock\\\\tJane\\\\tSales Support Agent\\\\t2\\\\t1973-08-29 00:00:00\\\\t2002-04-01 00:00:00\\\\t1111 6 Ave SW\\\\tCalgary\\\\tAB\\\\tCanada\\\\tT2P 5M5\\\\t+1 (403) 262-3443\\\\t+1 (403) 262-6712\\\\tjane@chinookcorp.com\\\\n*/\\\\n\\\\n\\\\nCREATE TABLE \"Genre\" (\\\\n\\\\t\"GenreId\" INTEGER NOT NULL, \\\\n\\\\t\"Name\" NVARCHAR(120), \\\\n\\\\tPRIMARY KEY (\"GenreId\")\\\\n)\\\\n\\\\n/*\\\\n3 rows from Genre table:\\\\nGenreId\\\\tName\\\\n1\\\\tRock\\\\n2\\\\tJazz\\\\n3\\\\tMetal\\\\n*/\\\\n\\\\n\\\\nCREATE TABLE \"Invoice\" (\\\\n\\\\t\"InvoiceId\" INTEGER NOT NULL, \\\\n\\\\t\"CustomerId\" INTEGER NOT NULL, \\\\n\\\\t\"InvoiceDate\" DATETIME NOT NULL, \\\\n\\\\t\"BillingAddress\" NVARCHAR(70), \\\\n\\\\t\"BillingCity\" NVARCHAR(40), \\\\n\\\\t\"BillingState\" NVARCHAR(40), \\\\n\\\\t\"BillingCountry\" NVARCHAR(40), \\\\n\\\\t\"BillingPostalCode\" NVARCHAR(10), \\\\n\\\\t\"Total\" NUMERIC(10, 2) NOT NULL, \\\\n\\\\tPRIMARY KEY (\"InvoiceId\"), \\\\n\\\\tFOREIGN KEY(\"CustomerId\") REFERENCES \"Customer\" (\"CustomerId\")\\\\n)\\\\n\\\\n/*\\\\n3 rows from Invoice table:\\\\nInvoiceId\\\\tCustomerId\\\\tInvoiceDate\\\\tBillingAddress\\\\tBillingCity\\\\tBillingState\\\\tBillingCountry\\\\tBillingPostalCode\\\\tTotal\\\\n1\\\\t2\\\\t2021-01-01 00:00:00\\\\tTheodor-Heuss-Straße 34\\\\tStuttgart\\\\tNone\\\\tGermany\\\\t70174\\\\t1.98\\\\n2\\\\t4\\\\t2021-01-02 00:00:00\\\\tUllevålsveien 14\\\\tOslo\\\\tNone\\\\tNorway\\\\t0171\\\\t3.96\\\\n3\\\\t8\\\\t2021-01-03 00:00:00\\\\tGrétrystraat 
63\\\\tBrussels\\\\tNone\\\\tBelgium\\\\t1000\\\\t5.94\\\\n*/\\\\n\\\\n\\\\nCREATE TABLE \"InvoiceLine\" (\\\\n\\\\t\"InvoiceLineId\" INTEGER NOT NULL, \\\\n\\\\t\"InvoiceId\" INTEGER NOT NULL, \\\\n\\\\t\"TrackId\" INTEGER NOT NULL, \\\\n\\\\t\"UnitPrice\" NUMERIC(10, 2) NOT NULL, \\\\n\\\\t\"Quantity\" INTEGER NOT NULL, \\\\n\\\\tPRIMARY KEY (\"InvoiceLineId\"), \\\\n\\\\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \\\\n\\\\tFOREIGN KEY(\"InvoiceId\") REFERENCES \"Invoice\" (\"InvoiceId\")\\\\n)\\\\n\\\\n/*\\\\n3 rows from InvoiceLine table:\\\\nInvoiceLineId\\\\tInvoiceId\\\\tTrackId\\\\tUnitPrice\\\\tQuantity\\\\n1\\\\t1\\\\t2\\\\t0.99\\\\t1\\\\n2\\\\t1\\\\t4\\\\t0.99\\\\t1\\\\n3\\\\t2\\\\t6\\\\t0.99\\\\t1\\\\n*/\\\\n\\\\n\\\\nCREATE TABLE \"MediaType\" (\\\\n\\\\t\"MediaTypeId\" INTEGER NOT NULL, \\\\n\\\\t\"Name\" NVARCHAR(120), \\\\n\\\\tPRIMARY KEY (\"MediaTypeId\")\\\\n)\\\\n\\\\n/*\\\\n3 rows from MediaType table:\\\\nMediaTypeId\\\\tName\\\\n1\\\\tMPEG audio file\\\\n2\\\\tProtected AAC audio file\\\\n3\\\\tProtected MPEG-4 video file\\\\n*/\\\\n\\\\n\\\\nCREATE TABLE \"Playlist\" (\\\\n\\\\t\"PlaylistId\" INTEGER NOT NULL, \\\\n\\\\t\"Name\" NVARCHAR(120), \\\\n\\\\tPRIMARY KEY (\"PlaylistId\")\\\\n)\\\\n\\\\n/*\\\\n3 rows from Playlist table:\\\\nPlaylistId\\\\tName\\\\n1\\\\tMusic\\\\n2\\\\tMovies\\\\n3\\\\tTV Shows\\\\n*/\\\\n\\\\n\\\\nCREATE TABLE \"PlaylistTrack\" (\\\\n\\\\t\"PlaylistId\" INTEGER NOT NULL, \\\\n\\\\t\"TrackId\" INTEGER NOT NULL, \\\\n\\\\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \\\\n\\\\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \\\\n\\\\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\\\\n)\\\\n\\\\n/*\\\\n3 rows from PlaylistTrack table:\\\\nPlaylistId\\\\tTrackId\\\\n1\\\\t3402\\\\n1\\\\t3389\\\\n1\\\\t3390\\\\n*/\\\\n\\\\n\\\\nCREATE TABLE \"Track\" (\\\\n\\\\t\"TrackId\" INTEGER NOT NULL, \\\\n\\\\t\"Name\" NVARCHAR(200) NOT NULL, \\\\n\\\\t\"AlbumId\" INTEGER, 
\\\\n\\\\t\"MediaTypeId\" INTEGER NOT NULL, \\\\n\\\\t\"GenreId\" INTEGER, \\\\n\\\\t\"Composer\" NVARCHAR(220), \\\\n\\\\t\"Milliseconds\" INTEGER NOT NULL, \\\\n\\\\t\"Bytes\" INTEGER, \\\\n\\\\t\"UnitPrice\" NUMERIC(10, 2) NOT NULL, \\\\n\\\\tPRIMARY KEY (\"TrackId\"), \\\\n\\\\tFOREIGN KEY(\"MediaTypeId\") REFERENCES \"MediaType\" (\"MediaTypeId\"), \\\\n\\\\tFOREIGN KEY(\"GenreId\") REFERENCES \"Genre\" (\"GenreId\"), \\\\n\\\\tFOREIGN KEY(\"AlbumId\") REFERENCES \"Album\" (\"AlbumId\")\\\\n)\\\\n\\\\n/*\\\\n3 rows from Track table:\\\\nTrackId\\\\tName\\\\tAlbumId\\\\tMediaTypeId\\\\tGenreId\\\\tComposer\\\\tMilliseconds\\\\tBytes\\\\tUnitPrice\\\\n1\\\\tFor Those About To Rock (We Salute You)\\\\t1\\\\t1\\\\t1\\\\tAngus Young, Malcolm Young, Brian Johnson\\\\t343719\\\\t11170334\\\\t0.99\\\\n2\\\\tBalls to the Wall\\\\t2\\\\t2\\\\t1\\\\tU. Dirkschneider, W. Hoffmann, H. Frank, P. Baltes, S. Kaufmann, G. Hoffmann\\\\t342562\\\\t5510424\\\\t0.99\\\\n3\\\\tFast As a Shark\\\\t3\\\\t2\\\\t1\\\\tF. Baltes, S. Kaufman, U. Dirkscneider & W. 
Hoffman\\\\t230619\\\\t3990994\\\\t0.99\\\\n*/\\', \\'table_names\\': \\'Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\\'}\\n',\n", + " 'output': 'The customer who has spent the most money is Helena Holý.',\n", + " 'citations': [CohereCitation(start=45, end=56, text='Helena Holý', documents=[{'output': \"[(6, 'Helena', 'Holý', 49.62)]\"}])],\n", + " 'intermediate_steps': [(AgentActionMessageLog(tool='sql_db_query', tool_input={'query': 'SELECT c.CustomerId, c.FirstName, c.LastName, SUM(i.Total) AS TotalSpent\\nFROM Customer c\\nJOIN Invoice i ON c.CustomerId = i.CustomerId\\nGROUP BY c.CustomerId\\nORDER BY TotalSpent DESC\\nLIMIT 1;'}, log=\"\\nI will write a SQL query to find the customer who has spent the most money.\\n{'tool_name': 'sql_db_query', 'parameters': {'query': 'SELECT c.CustomerId, c.FirstName, c.LastName, SUM(i.Total) AS TotalSpent\\\\nFROM Customer c\\\\nJOIN Invoice i ON c.CustomerId = i.CustomerId\\\\nGROUP BY c.CustomerId\\\\nORDER BY TotalSpent DESC\\\\nLIMIT 1;'}}\\n\", message_log=[AIMessage(content='\\nPlan: I will write a SQL query to find the customer who has spent the most money.\\nAction: ```json\\n[\\n {\\n \"tool_name\": \"sql_db_query\",\\n \"parameters\": {\\n \"query\": \"SELECT c.CustomerId, c.FirstName, c.LastName, SUM(i.Total) AS TotalSpent\\\\nFROM Customer c\\\\nJOIN Invoice i ON c.CustomerId = i.CustomerId\\\\nGROUP BY c.CustomerId\\\\nORDER BY TotalSpent DESC\\\\nLIMIT 1;\"\\n }\\n }\\n]\\n```')]),\n", + " \"[(6, 'Helena', 'Holý', 49.62)]\")]}" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent_executor.invoke({\n", + " \"input\": 'provide the name of the best customer? 
The customer who has spent the most money is the best.',\n", + " \"preamble\": preamble\n", + "})" + ] + }, + { + "cell_type": "markdown", + "id": "b1cfcaaa-b1c0-4ded-baaa-7bbdb097be6e", + "metadata": {}, + "source": [ + "Great, we can see that passing the table information as additional context in the preamble avoids the error seen in the previous section and gets to the answer in a single tool call. Of course, this only works as long as the database has few tables and few columns per table, since the full schema has to fit in the preamble. We will follow up with more techniques to improve stability and scalability in another notebook." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "95a1feed-1480-497d-bacf-9ab2667b573c", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}