diff --git a/notebooks/agents/agentic-RAG/agentic_multi_stage_rag_native.ipynb b/notebooks/agents/agentic-RAG/agentic_multi_stage_rag_native.ipynb
new file mode 100644
index 00000000..a83954dd
--- /dev/null
+++ b/notebooks/agents/agentic-RAG/agentic_multi_stage_rag_native.ipynb
@@ -0,0 +1,758 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "# Agentic Multi-Stage RAG with Cohere Tools API \n",
+ "\n",
+ "## Motivation \n",
+ "\n",
+ "Retrieval augmented generation (RAG) has been a go-to use case that enterprises have been adopting with large language models (LLMs). Even though it works well in general, there are edge cases where this can fail. Most commonly, when the retrieved document mentions the query but actually refers to another document, the model will fail to generate the correct answer. \n",
+ "\n",
+ "We propose an agentic RAG system that leverages tool use to continue to retrieve documents if correct ones were not retrieved at first try. This is ideal for use cases where accuracy is a top priority and latency is not. For example, lawyers trying to find the most accurate answer from their contracts are willing to wait a few more seconds to get the answer instead of getting wrong answers fast. \n",
+ "\n",
+ "\n",
+ "## Objective\n",
+ "\n",
+ "This notebook, we will explore how we can build a simple agentic RAG using Cohere's native API. We have prepared a fake dataset to demonstrate the use case. \n",
+ "We ask three questions that require different depths of retrieval. We will see how the model answers the question between simple and agentic RAG. \n",
+ "\n",
+ "\n",
+ "## Disclaimer \n",
+ "\n",
+ "One of the challenges in building a RAG system is that it has many moving pieces: vector database, type of embedding model, use of reranker, number of retrieved documents, chunking strategy, and more. These components can make debugging and evaluating RAG systems difficult. Since this notebook focuses on the concept of agentic RAG, it will simplify other parts of the RAG system. For example, we will only retrieve top 1 docuemnt to demonstrate what happens when retrieved document does not contain the answer needed. \n",
+ "\n",
+ "\n",
+ "## Result \n",
+ "\n",
+ "|Type | Question | Simple Rag | Agentic Rag | \n",
+ "|---|--------------|---------|---------|\n",
+ "|Single-stage retrieval| Is there a state level law for wearing helmets? | There is currently no state law requiring the use of helmets when riding a bicycle. However, some cities and counties do require helmet use. | There is currently no state law requiring helmet use. However, some cities and counties do require helmet use with bicycles. | \n",
+ "|Multi-stage retrieval|I live in orting, do I need to wear a helmet with a bike?|In the state of Washington, there is no law requiring you to wear a helmet when riding a bike. However, some cities and counties do require helmet use, so it is worth checking your local laws.|Yes, you do need to wear a helmet with a bike in Orting if you are under 17.|\n",
+ "\n",
+ "As you will see more below, the multi-stage retrieval is achieved by adding a new function `reference_extractor()` that extracts other references in the documents and updating the instruction so the agent continues to retrieve more documents. \n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 56,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "from pprint import pprint\n",
+ "\n",
+ "import cohere\n",
+ "import pandas as pd\n",
+ "from sklearn.metrics.pairwise import cosine_similarity"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 161,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "cohere version: 5.5.1\n"
+ ]
+ }
+ ],
+ "source": [
+ "# versions\n",
+ "print('cohere version:', cohere.__version__)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Setup "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "COHERE_API_KEY = os.environ.get(\"CO_API_KEY\")\n",
+ "COHERE_MODEL = 'command-r-plus'\n",
+ "co = cohere.Client(api_key=COHERE_API_KEY)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Data \n",
+ "\n",
+ "We leveraged data from [Washington Department of Transportation](https://wsdot.wa.gov/travel/bicycling-walking/bicycling-washington/bicyclist-laws-safety) and modified to fit the need of this demo. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 279,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "documents = [\n",
+ " {\n",
+ " \"title\": \"Bicycle law\",\n",
+ " \"body\": \"\"\"\n",
+ " Traffic Infractions and fees - For all information related to bicycle traffic infractions such as not wearing a helmet and fee information, please visit Section 3b for more information.\n",
+ " Riding on the road - When riding on a roadway, a cyclist has all the rights and responsibilities of a vehicle driver (RCW 46.61.755). Bicyclists who violate traffic laws may be ticketed (RCW 46.61.750).\n",
+ " Roads closed to bicyclists - Some designated sections of the state's limited access highway system may be closed to bicyclists. See the permanent bike restrictions map for more information. In addition, local governments may adopt ordinances banning cycling on specific roads or on sidewalks within business districts.\n",
+ " Children bicycling - Parents or guardians may not knowingly permit bicycle traffic violations by their ward (RCW 46.61.700).\n",
+ " Riding side by side - Bicyclists may ride side by side, but not more than two abreast (RCW 46.61.770).\n",
+ " Riding at night - For night bicycle riding, a white front light (not a reflector) visible for 500 feet and a red rear reflector are required. A red rear light may be used in addition to the required reflector (RCW 46.61.780).\n",
+ " Shoulder vs. bike lane - Bicyclists may choose to ride on the path, bike lane, shoulder or travel lane as suits their safety needs (RCW 46.61.770).\n",
+ " Bicycle helmets - Currently, there is no state law requiring helmet use. However, some cities and counties do require helmets. For specific information along with location for bicycle helmet law please reference to section 21a.\n",
+ " Bicycle equipment - Bicycles must be equipped with a white front light visible for 500 feet and a red rear reflector (RCW 46.61.780). A red rear light may be used in addition to the required reflector.\n",
+ "\"\"\",\n",
+ " },\n",
+ " {\n",
+ " \"title\": \"Bicycle helmet requirement\",\n",
+ " \"body\": \"Currently, there is no state law requiring helmet use. However, some cities and counties do require helmet use with bicycles. Here is a list of those locations and when the laws were enacted. For specific information along with location for bicycle helmet law please reference to section 21a.\",\n",
+ " },\n",
+ " {\n",
+ " \"title\": \"Section 21a\",\n",
+ " \"body\": \"\"\"helmet rules by location: These are city and county level rules. The following group must wear helmets.\n",
+ " Location name | Who is affected | Effective date\n",
+ " Aberdeen | All ages | 2001\n",
+ " Bainbridge Island | All ages | 2001\n",
+ " Bellevue | All ages | 2001\n",
+ " Bremerton | All ages | 2000\n",
+ " DuPont | All ages | 2008\n",
+ " Eatonville | All ages | 1996\n",
+ " Fircrest | All ages | 1995\n",
+ " Gig Harbor | All ages | 1996\n",
+ " Kent | All ages | 1999\n",
+ " Lynnwood | All ages | 2004\n",
+ " Lakewood | All ages | 1996\n",
+ " Milton | All ages | 1997\n",
+ " Orting | Under 17 | 1997\n",
+ "\n",
+ " For fines and rules, you will be charged in according with Section 3b of the law.\n",
+ " \"\"\",\n",
+ " },\n",
+ " {\n",
+ " \"title\": \"Section 3b\",\n",
+ " \"body\": \"\"\"Traffic infraction - A person operating a bicycle upon a roadway or highway shall be subject to the provisions of this chapter relating to traffic infractions.\n",
+ " 1. Stop for people in crosswalks. Every intersection is a crosswalk - It’s the law. Drivers must stop for pedestrians at intersections, whether it’s an unmarked or marked crosswalk, and bicyclists in crosswalks are considered pedestrians. Also, it is illegal to pass another vehicle stopped for someone at a crosswalk. In Washington, the leading action motorists take that results in them hitting someone is a failure to yield to pedestrians.\n",
+ " 2. Put the phone down. Hand-held cell phone use and texting is prohibited for all Washington drivers and may result in a $136 fine for first offense, $235 on the second distracted-driving citation.\n",
+ " 3. Helmets are required for all bicyclists according to the state and municipal laws. If you are in a group required to wear a helmet but do not wear it you can be fined $48. # If you are the parent or legal guardian of a child under 17 and knowingly allow them to ride without a helmet, you can be fined $136.\n",
+ "\"\"\",\n",
+ " },\n",
+ "]\n",
+ "db = pd.DataFrame(documents)\n",
+ "# comebine title and body\n",
+ "db[\"combined\"] = \"Title: \" + db[\"title\"] + \"\\n\" + \"Body: \" + db[\"body\"]\n",
+ "# generate embedding\n",
+ "embeddings = co.embed(\n",
+ " texts=db.combined.tolist(), model=\"embed-english-v3.0\", input_type=\"search_document\"\n",
+ ")\n",
+ "db[\"embeddings\"] = embeddings.embeddings\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 329,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " title | \n",
+ " body | \n",
+ " combined | \n",
+ " embeddings | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " Bicycle law | \n",
+ " \\n Traffic Infractions and fees - For a... | \n",
+ " Title: Bicycle law\\nBody: \\n Traffic In... | \n",
+ " [-0.024673462, -0.034729004, 0.0418396, 0.0121... | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " Bicycle helmet requirement | \n",
+ " Currently, there is no state law requiring hel... | \n",
+ " Title: Bicycle helmet requirement\\nBody: Curre... | \n",
+ " [-0.019180298, -0.037384033, 0.0027389526, -0.... | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " Section 21a | \n",
+ " helmet rules by location: These are city and c... | \n",
+ " Title: Section 21a\\nBody: helmet rules by loca... | \n",
+ " [0.031097412, 0.0007619858, -0.023010254, -0.0... | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " Section 3b | \n",
+ " Traffic infraction - A person operating a bicy... | \n",
+ " Title: Section 3b\\nBody: Traffic infraction - ... | \n",
+ " [0.015602112, -0.016143799, 0.032958984, 0.000... | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " title \\\n",
+ "0 Bicycle law \n",
+ "1 Bicycle helmet requirement \n",
+ "2 Section 21a \n",
+ "3 Section 3b \n",
+ "\n",
+ " body \\\n",
+ "0 \\n Traffic Infractions and fees - For a... \n",
+ "1 Currently, there is no state law requiring hel... \n",
+ "2 helmet rules by location: These are city and c... \n",
+ "3 Traffic infraction - A person operating a bicy... \n",
+ "\n",
+ " combined \\\n",
+ "0 Title: Bicycle law\\nBody: \\n Traffic In... \n",
+ "1 Title: Bicycle helmet requirement\\nBody: Curre... \n",
+ "2 Title: Section 21a\\nBody: helmet rules by loca... \n",
+ "3 Title: Section 3b\\nBody: Traffic infraction - ... \n",
+ "\n",
+ " embeddings \n",
+ "0 [-0.024673462, -0.034729004, 0.0418396, 0.0121... \n",
+ "1 [-0.019180298, -0.037384033, 0.0027389526, -0.... \n",
+ "2 [0.031097412, 0.0007619858, -0.023010254, -0.0... \n",
+ "3 [0.015602112, -0.016143799, 0.032958984, 0.000... "
+ ]
+ },
+ "execution_count": 329,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "db"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Tools \n",
+ "\n",
+ "Following functions and tools will be used in the subsequent tasks. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 220,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def retrieve_documents(query: str, n=1) -> dict:\n",
+ " \"\"\"\n",
+ " Function to retrieve documents a given query.\n",
+ "\n",
+ " Steps:\n",
+ " 1. Embed the query\n",
+ " 2. Calculate cosine similarity between the query embedding and the embeddings of the documents\n",
+ " 3. Return the top n documents with the highest similarity scores\n",
+ " \"\"\"\n",
+ " query_emb = co.embed(\n",
+ " texts=[query], model=\"embed-english-v3.0\", input_type=\"search_query\"\n",
+ " )\n",
+ "\n",
+ " similarity_scores = cosine_similarity(\n",
+ " [query_emb.embeddings[0]], db.embeddings.tolist()\n",
+ " )\n",
+ " similarity_scores = similarity_scores[0]\n",
+ "\n",
+ " top_indices = similarity_scores.argsort()[::-1][:n]\n",
+ " top_matches = db.iloc[top_indices]\n",
+ "\n",
+ " return {\"top_matched_document\": top_matches.combined}\n",
+ "\n",
+ "\n",
+ "functions_map = {\n",
+ " \"retrieve_documents\": retrieve_documents,\n",
+ "}\n",
+ "\n",
+ "tools = [\n",
+ " {\n",
+ " \"name\": \"retrieve_documents\",\n",
+ " \"description\": \"given a query, retrieve documents from a database to answer user's question\",\n",
+ " \"parameter_definitions\": {\n",
+ " \"query\": {\"description\": \"query\", \"type\": \"str\", \"required\": True}\n",
+ " },\n",
+ " }\n",
+ "]\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## RAG function \n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 237,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def simple_rag(query, db):\n",
+ " \"\"\"\n",
+ " Given user's query, retrieve top documents and generate response using documents parameter.\n",
+ " \"\"\"\n",
+ " top_matched_document = retrieve_documents(query)[\"top_matched_document\"]\n",
+ "\n",
+ " print(\"top_matched_document\", top_matched_document)\n",
+ "\n",
+ " output = co.chat(\n",
+ " message=query, model=COHERE_MODEL, documents=[top_matched_document]\n",
+ " )\n",
+ "\n",
+ " return output.text\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Agentic RAG - cohere_agent()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 325,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def cohere_agent(\n",
+ " message: str,\n",
+ " preamble: str,\n",
+ " tools: list[dict],\n",
+ " force_single_step=False,\n",
+ " verbose: bool = False,\n",
+ " temperature: float = 0.3,\n",
+ ") -> str:\n",
+ " \"\"\"\n",
+ " Function to handle multi-step tool use api.\n",
+ "\n",
+ " Args:\n",
+ " message (str): The message to send to the Cohere AI model.\n",
+ " preamble (str): The preamble or context for the conversation.\n",
+ " tools (list of dict): List of tools to use in the conversation.\n",
+ " verbose (bool, optional): Whether to print verbose output. Defaults to False.\n",
+ "\n",
+ " Returns:\n",
+ " str: The final response from the call.\n",
+ " \"\"\"\n",
+ "\n",
+ " counter = 1\n",
+ "\n",
+ " response = co.chat(\n",
+ " model=COHERE_MODEL,\n",
+ " message=message,\n",
+ " preamble=preamble,\n",
+ " tools=tools,\n",
+ " force_single_step=force_single_step,\n",
+ " temperature=temperature,\n",
+ " )\n",
+ "\n",
+ " if verbose:\n",
+ " print(f\"\\nrunning 0th step.\")\n",
+ " print(response.text)\n",
+ "\n",
+ " while response.tool_calls:\n",
+ " tool_results = []\n",
+ "\n",
+ " if verbose:\n",
+ " print(f\"\\nrunning {counter}th step.\")\n",
+ "\n",
+ " for tool_call in response.tool_calls:\n",
+ " output = functions_map[tool_call.name](**tool_call.parameters)\n",
+ " outputs = [output]\n",
+ " tool_results.append({\"call\": tool_call, \"outputs\": outputs})\n",
+ "\n",
+ " if verbose:\n",
+ " print(\n",
+ " f\"= running tool {tool_call.name}, with parameters: \\n{tool_call.parameters}\"\n",
+ " )\n",
+ " print(f\"== tool results:\")\n",
+ " pprint(output)\n",
+ "\n",
+ " response = co.chat(\n",
+ " model=COHERE_MODEL,\n",
+ " message=\"\",\n",
+ " chat_history=response.chat_history,\n",
+ " preamble=preamble,\n",
+ " tools=tools,\n",
+ " force_single_step=force_single_step,\n",
+ " tool_results=tool_results,\n",
+ " temperature=temperature,\n",
+ " )\n",
+ "\n",
+ " if verbose:\n",
+ " print(response.text)\n",
+ " counter += 1\n",
+ "\n",
+ " return response.text\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Question 1 - single-stage retrieval \n",
+ "\n",
+ "Here we are asking a question that can be answered easily with single-stage retrieval. Both regular and agentic RAG should be able to answer this question easily. Below is the comparsion of the response. \n",
+ "\n",
+ "| Question | Simple Rag | Agentic Rag | \n",
+ "|--------------|---------|---------|\n",
+ "| Is there a state level law for wearing helmets? | There is currently no state law requiring the use of helmets when riding a bicycle. However, some cities and counties do require helmet use. | There is currently no state law requiring helmet use. However, some cities and counties do require helmet use with bicycles. | \n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 179,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "question1 = \"Is there a state level law for wearing helmets?\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Simple RAG "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 180,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "top_matched_document 1 Title: Bicycle helmet requirement\\nBody: Curre...\n",
+ "Name: combined, dtype: object\n",
+ "There is currently no state law requiring the use of helmets when riding a bicycle. However, some cities and counties do require helmet use.\n"
+ ]
+ }
+ ],
+ "source": [
+ "output = simple_rag(question1, db)\n",
+ "print(output)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Agentic RAG "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 183,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "running 0th step.\n",
+ "I will search for 'state level law for wearing helmets' in the documents provided and write an answer based on what I find.\n",
+ "\n",
+ "running 1th step.\n",
+ "= running tool retrieve_documents, with parameters: \n",
+ "{'query': 'state level law for wearing helmets'}\n",
+ "== tool results:\n",
+ "{'top_matched_document': 1 Title: Bicycle helmet requirement\\nBody: Curre...\n",
+ "Name: combined, dtype: object}\n",
+ "There is currently no state law requiring helmet use. However, some cities and counties do require helmet use with bicycles.\n"
+ ]
+ }
+ ],
+ "source": [
+ "preamble = \"\"\"\n",
+ "You are an expert assistant that helps users answers question about legal documents and policies.\n",
+ "Use the provided documents to answer questions about an employee's specific situation.\n",
+ "\"\"\"\n",
+ "\n",
+ "output = cohere_agent(question1, preamble, tools, verbose=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Question 2 - double-stage retrieval \n",
+ "\n",
+ "The second question requires a double-stage retrieval because top matched document references another document. You will see below that the agentic RAG is unable to produce the correct answer initially. But when given proper tools and instructions, it finds the correct answer. \n",
+ "\n",
+ "\n",
+ "| Question | Simple Rag | Agentic Rag | \n",
+ "|--------------|---------|---------|\n",
+ "|I live in orting, do I need to wear a helmet with a bike?|In the state of Washington, there is no law requiring you to wear a helmet when riding a bike. However, some cities and counties do require helmet use, so it is worth checking your local laws.|Yes, you do need to wear a helmet with a bike in Orting if you are under 17.|"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 188,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "question2 = \"I live in orting, do I need to wear a helmet with a bike?\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Simple RAG "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 189,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "top_matched_document 1 Title: Bicycle helmet requirement\\nBody: Curre...\n",
+ "Name: combined, dtype: object\n",
+ "In the state of Washington, there is no law requiring you to wear a helmet when riding a bike. However, some cities and counties do require helmet use, so it is worth checking your local laws.\n"
+ ]
+ }
+ ],
+ "source": [
+ "output = simple_rag(question2, db)\n",
+ "print(output)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Agentic RAG\n",
+ "\n",
+ "Produces same quality answer as the simple rag."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 190,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "running 0th step.\n",
+ "I will search for 'helmet with a bike' and then write an answer.\n",
+ "\n",
+ "running 1th step.\n",
+ "= running tool retrieve_documents, with parameters: \n",
+ "{'query': 'helmet with a bike'}\n",
+ "== tool results:\n",
+ "{'top_matched_document': 1 Title: Bicycle helmet requirement\\nBody: Curre...\n",
+ "Name: combined, dtype: object}\n",
+ "There is no state law requiring helmet use, however, some cities and counties do require helmet use with bicycles. I cannot find any information about Orting specifically, but you should check with your local authority.\n"
+ ]
+ }
+ ],
+ "source": [
+ "preamble = \"\"\"\n",
+ "You are an expert assistant that helps users answers question about legal documents and policies.\n",
+ "Use the provided documents to answer questions about an employee's specific situation.\n",
+ "\"\"\"\n",
+ "\n",
+ "output = cohere_agent(question2, preamble, tools, verbose=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Agentic RAG - New Tools \n",
+ "\n",
+ "In order for the model to retrieve correct documents, we do two things: \n",
+ "1. New reference_extractor() function is added. This function finds the references to other documents when given query and documents. \n",
+ "2. We update the instruction that directs the agent to keep retrieving relevant documents. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 247,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def reference_extractor(query: str, documents: list[str]) -> str:\n",
+ " \"\"\"\n",
+ " Given a query and document, find references to other documents.\n",
+ " \"\"\"\n",
+ " prompt = f\"\"\"\n",
+ " # instruction\n",
+ " Does the reference document mention any other documents? If so, list them.\n",
+ " If not, return empty string.\n",
+ "\n",
+ " # user query\n",
+ " {query}\n",
+ "\n",
+ " # retrieved documents\n",
+ " {documents}\n",
+ " \"\"\"\n",
+ "\n",
+ " return co.chat(message=prompt, model=COHERE_MODEL, preamble=None).text\n",
+ "\n",
+ "\n",
+ "def retrieve_documents(query: str, n=1) -> dict:\n",
+ " \"\"\"\n",
+ " Function to retrieve most relevant documents a given query.\n",
+ " It also returns other references mentioned in the top matched documents.\n",
+ " \"\"\"\n",
+ " query_emb = co.embed(\n",
+ " texts=[query], model=\"embed-english-v3.0\", input_type=\"search_query\"\n",
+ " )\n",
+ "\n",
+ " similarity_scores = cosine_similarity(\n",
+ " [query_emb.embeddings[0]], db.embeddings.tolist()\n",
+ " )\n",
+ " similarity_scores = similarity_scores[0]\n",
+ "\n",
+ " top_indices = similarity_scores.argsort()[::-1][:n]\n",
+ " top_matches = db.iloc[top_indices]\n",
+ " other_references = reference_extractor(query, top_matches.combined.tolist())\n",
+ "\n",
+ " return {\n",
+ " \"top_matched_document\": top_matches.combined,\n",
+ " \"other_references_to_query\": other_references,\n",
+ " }\n",
+ "\n",
+ "\n",
+ "functions_map = {\n",
+ " \"retrieve_documents\": retrieve_documents,\n",
+ "}\n",
+ "\n",
+ "tools = [\n",
+ " {\n",
+ " \"name\": \"retrieve_documents\",\n",
+ " \"description\": \"given a query, retrieve documents from a database to answer user's question. It also finds references to other documents that should be leveraged to retrieve more documents\",\n",
+ " \"parameter_definitions\": {\n",
+ " \"query\": {\n",
+ " \"description\": \"user's question or question or name of other document sections or references.\",\n",
+ " \"type\": \"str\",\n",
+ " \"required\": True,\n",
+ " }\n",
+ " },\n",
+ " }\n",
+ "]\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 249,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "running 0th step.\n",
+ "I will search for 'Orting' and 'bike helmet' to find the relevant information.\n",
+ "\n",
+ "running 1th step.\n",
+ "= running tool retrieve_documents, with parameters: \n",
+ "{'query': 'Orting bike helmet'}\n",
+ "== tool results:\n",
+ "{'other_references_to_query': 'Section 21a, Section 3b',\n",
+ " 'top_matched_document': 0 Title: Bicycle law\\nBody: \\n Riding on ...\n",
+ "Name: combined, dtype: object}\n",
+ "I have found that there is no state law requiring helmet use, but some cities and counties do require helmets. I will now search for 'Section 21a' to find out if Orting is one of these cities or counties.\n",
+ "\n",
+ "running 2th step.\n",
+ "= running tool retrieve_documents, with parameters: \n",
+ "{'query': 'Section 21a'}\n",
+ "== tool results:\n",
+ "{'other_references_to_query': '- Section 3b',\n",
+ " 'top_matched_document': 2 Title: Section 21a\\nBody: helmet rules by loca...\n",
+ "Name: combined, dtype: object}\n",
+ "Yes, you do need to wear a helmet when riding a bike in Orting if you are under 17.\n"
+ ]
+ }
+ ],
+ "source": [
+ "preamble2 = \"\"\"# Instruction\n",
+ "You are an expert assistant that helps users answers question about legal documents and policies.\n",
+ "\n",
+ "Please follow these steps:\n",
+ "1. Using user's query, use `retrieve_documents` tool to retrieve the most relevant document from the database.\n",
+ "2. If you see `other_references_to_query` in the tool result, search the mentioned referenced using `retrieve_documents()` tool to retrieve more documents.\n",
+ "3. Keep trying until you find the answer.\n",
+ "4. Answer with yes or no as much as you can to answer the question directly.\n",
+ "\"\"\"\n",
+ "\n",
+ "output = cohere_agent(question2, preamble2, tools, verbose=True)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "base",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/notebooks/agents/agentic-RAG/agentic_rag_publication.ipynb b/notebooks/agents/agentic-RAG/agentic_rag_langchain.ipynb
similarity index 100%
rename from notebooks/agents/agentic-RAG/agentic_rag_publication.ipynb
rename to notebooks/agents/agentic-RAG/agentic_rag_langchain.ipynb