diff --git a/gemini/reasoning-engine/tutorial_graph_rag.ipynb b/gemini/reasoning-engine/tutorial_graph_rag.ipynb
new file mode 100644
index 0000000000..13b733568e
--- /dev/null
+++ b/gemini/reasoning-engine/tutorial_graph_rag.ipynb
@@ -0,0 +1,1636 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "TQ1nQmSbn9co"
+ },
+ "outputs": [],
+ "source": [
+ "# Copyright 2023 Google LLC\n",
+ "#\n",
+ "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+ "# you may not use this file except in compliance with the License.\n",
+ "# You may obtain a copy of the License at\n",
+ "#\n",
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing, software\n",
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+ "# See the License for the specific language governing permissions and\n",
+ "# limitations under the License."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "skeDTD9moBK5"
+ },
+ "source": [
+ "# GraphRAG on Google Cloud\n",
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " Run in Colab\n",
+ " \n",
+ " | \n",
+ " \n",
+ " \n",
+ " Run in Colab Enterprise\n",
+ " \n",
+ " | \n",
+ " \n",
+ " \n",
+ " View on GitHub\n",
+ " \n",
+ " | \n",
+ " \n",
+ " \n",
+ " Open in Vertex AI Workbench\n",
+ " \n",
+ " | \n",
+ "
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "714f683ef292"
+ },
+ "source": [
+ "| | |\n",
+ "|-|-|\n",
+ "|Author(s) | [Tristan Li](https://github.com/codingphun), Ashish Chauhan, [Smitha Venkat](https://github.com/smitha-google)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "tC8nObmLw9P0"
+ },
+ "source": [
+ "## Overview\n",
+ "\n",
+ "[LangChain on Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/reasoning-engine/overview)\n",
+ "is a managed service that helps you to build and deploy LangChain apps to a managed Reasoning Engine runtime.\n",
+ "\n",
+ "Instead of simply retrieving relevant text snippets based on keyword similarity, GraphRAG takes a more sophisticated, structured approach to Retrieval Augmented Generation. It involves creating a knowledge graph from the text, organizing it hierarchically, summarizing key concepts, and then using this structured information to enhance the accuracy and depth of responses.\n",
+ "\n",
+ "## Objectives\n",
+ "\n",
+ "In this tutorial, you will see a complete walkthrough of building a question-answering system using the GraphRAG method. You'll learn how to create a knowledge graph from scratch, store it efficiently in Spanner Graph, enhance search accuracy with embedding vectors in Spanner Vector Database, and finally, deploy a functional FAQ system with LangChain and Vertex AI Reasoning Engine."
+ ]
+ },
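The pipeline above can be previewed with a minimal, library-free sketch (all names and facts here are illustrative, not the APIs used later): facts are stored as (subject, relation, object) triples, and retrieval starts by collecting every triple attached to the entity a question mentions.

```python
# Minimal sketch of the GraphRAG retrieval idea: facts live as
# (subject, relation, object) triples; answering a question starts by
# collecting every triple attached to the entity it mentions.
TRIPLES = [
    ("Lawrence Edward Page", "WORKED_AT", "Google"),
    ("Lawrence Edward Page", "INFLUENCED_BY", "Maria Montessori"),
    ("Sergey Brin", "WORKED_AT", "Google"),
]


def facts_about(entity: str) -> list[str]:
    """Return every stored fact whose subject matches the entity."""
    return [f"{s} {r} {o}" for (s, r, o) in TRIPLES if s == entity]


print(facts_about("Lawrence Edward Page"))
```

The rest of the tutorial replaces each piece of this sketch with a managed service: Gemini extracts the triples, Spanner Graph stores them, and embeddings resolve the entity name.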
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Wvrm1CcB2DVv"
+ },
+ "source": [
+ "![Architecture](https://storage.googleapis.com/github-repo/generative-ai/gemini/reasoning-engine/images/graphrag.png)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "de74e872cec0"
+ },
+ "source": [
+ "## Before you begin\n",
+ "\n",
+ "1. In the Google Cloud console, on the project selector page, select or [create a Google Cloud project](https://cloud.google.com/resource-manager/docs/creating-managing-projects).\n",
+ "1. [Make sure that billing is enabled for your Google Cloud project](https://cloud.google.com/billing/docs/how-to/verify-billing-enabled#console).\n",
+ "\n",
+ "### Required roles\n",
+ "\n",
+ "To get the permissions that you need to complete the tutorial, ask your administrator to grant you the [Owner](https://cloud.google.com/iam/docs/understanding-roles#owner) (`roles/owner`) IAM role on your project. For more information about granting roles, see [Manage access](https://cloud.google.com/iam/docs/granting-changing-revoking-access).\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "S4irrymMr2LX"
+ },
+ "source": [
+ "## Getting Started"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "arJM4CK4r6cj"
+ },
+ "source": [
+ "### Install Python Libraries"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "yhWJlYAmXSUq"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install --quiet --force-reinstall langchain==0.3.0\n",
+ "!pip install --upgrade --quiet json-repair networkx==3.3 langchain-core==0.3.2 langchain-google-vertexai==2.0.1 langchain-experimental==0.3.0 langchain-community==0.3.0 langchain-text-splitters==0.3.0\n",
+ "\n",
+ "!pip install --quiet google-cloud-aiplatform==1.67.0\n",
+ "!pip install --quiet google-cloud-resource-manager==1.12.5\n",
+ "!pip install --quiet google-cloud-spanner==3.48.0"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "zTa3RFWesfsL"
+ },
+ "source": [
+ "### Restart the Kernel"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "id": "BXJ7SzTYXr_y"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'status': 'ok', 'restart': True}"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import IPython\n",
+ "\n",
+ "app = IPython.Application.instance()\n",
+ "app.kernel.do_shutdown(True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "DoK1gW9dsRxB"
+ },
+ "source": [
+ "### Authenticating your notebook environment\n",
+ "* If you are using **Colab** to run this notebook, uncomment the cell below and continue.\n",
+ "* If you are using **Vertex AI Workbench**, check out the setup instructions [here](https://github.com/GoogleCloudPlatform/generative-ai/tree/main/setup-env)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "id": "os3H39sGXugN"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]\n"
+ ]
+ }
+ ],
+ "source": [
+ "import sys\n",
+ "\n",
+ "if \"google.colab\" in sys.modules:\n",
+ " from google.colab import auth as google_auth\n",
+ "\n",
+ " google_auth.authenticate_user()\n",
+ "print(sys.version)\n",
+ "# If using local jupyter instance, uncomment and run:\n",
+ "# !gcloud auth login"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "CzaWjgsqsuuu"
+ },
+ "source": [
+ "### CHANGE the following settings"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "id": "zjz6i4vAXvZg"
+ },
+ "outputs": [],
+ "source": [
+ "GCP_PROJECT_ID = \"\"\n",
+ "GCP_PROJECT_NUMBER = \"\"\n",
+ "REGION = \"us-central1\"\n",
+ "STAGING_BUCKET = \"gs://\" # must be at root bucket level and not a subfolder\n",
+ "MODEL_NAME = \"gemini-1.5-pro-002\"\n",
+ "EMBEDDING_MODEL_NAME = \"text-embedding-004\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "GHyZn3Lns9YM"
+ },
+ "source": [
+ "### Import Packages"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "id": "mCQuCvfrXxUM"
+ },
+ "outputs": [],
+ "source": [
+ "from langchain.chains import GraphQAChain\n",
+ "from langchain_community.graphs.networkx_graph import NetworkxEntityGraph\n",
+ "from langchain_core.documents import Document\n",
+ "from langchain_experimental.graph_transformers import LLMGraphTransformer\n",
+ "from langchain_google_vertexai import VertexAI\n",
+ "import matplotlib.pyplot as plt\n",
+ "import networkx as nx"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "xSuWYhb62UJS"
+ },
+ "source": [
+ "### Sample Texts\n",
+ "\n",
+ "These texts extracted from Wikipedia are about Larry Page, co-founder of Google. These texts will be used to create a knowledge graph about Larry Page as well as embedding vectors for semantic search."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "id": "NIokGkZii43z"
+ },
+ "outputs": [],
+ "source": [
+ "TEXT_1 = \"Lawrence Edward Page (born March 26, 1973) is an American businessman and computer scientist best known for co-founding Google with Sergey Brin. \"\n",
+ "TEXT_2 = \"Lawrence Edward Page was chief executive officer of Google from 1997 until August 2001 when he stepped down in favor of Eric Schmidt, and then again from April 2011 until July 2015 when he became CEO of its newly formed parent organization Alphabet Inc.[6] He held that post until December 4, 2019, when he and Brin stepped down from all executive positions and day-to-day roles within the company. He remains an Alphabet board member, employee, and controlling shareholder.\"\n",
+ "TEXT_3 = \"Lawrence Edward Page has an estimated net worth of $156 billion as of June 2024, according to the Bloomberg Billionaires Index, and $145.2 billion according to Forbes, making him the fifth-richest person in the world. He has also invested in flying car startups Kitty Hawk and Opener.\"\n",
+ "TEXT_4 = \"Like his Google co-founder, Sergey Brin, Page attended Montessori schools until he entered high school. They both cite the educational method of Maria Montessori as the major influence in how they designed Google's work systems. Maria Montessori believed that the liberty of the child was of utmost importance. In some sense, I feel like music training led to the high-speed legacy of Google for me\"\n",
+ "\n",
+ "text = TEXT_1 + TEXT_2 + TEXT_3 + TEXT_4"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "YDpwl48k27HU"
+ },
+ "source": [
+ "### Create Knowledge Graph\n",
+ "\n",
+ "We will use Gemini and LangChain LLMGraphTransformer to parse the texts and generate a knowledge graph."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "id": "ZIcYPheujsmc"
+ },
+ "outputs": [],
+ "source": [
+ "llm = VertexAI(\n",
+ " max_output_tokens=4000,\n",
+ " model_name=MODEL_NAME,\n",
+ " project=GCP_PROJECT_ID,\n",
+ " location=REGION,\n",
+ ")\n",
+ "\n",
+ "documents = [Document(page_content=text)]\n",
+ "llm_transformer = LLMGraphTransformer(llm=llm)\n",
+ "graph_documents = llm_transformer.convert_to_graph_documents(documents)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "old4BF_L3ZKP"
+ },
+ "source": [
+ "Leveraging Gemini's capabilities, LangChain will use them to identify and extract key information from the text, such as people, countries, and their nationalities, to construct a comprehensive knowledge graph from the texts based on the nodes and relationships we define."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "id": "ppkSR-_8jzPO"
+ },
+ "outputs": [],
+ "source": [
+ "llm_transformer_filtered = LLMGraphTransformer(\n",
+ " llm=llm,\n",
+ " allowed_nodes=[\"Person\", \"Country\", \"Organization\", \"Asset\"],\n",
+ " allowed_relationships=[\n",
+ " \"NATIONALITY\",\n",
+ " \"LOCATED_IN\",\n",
+ " \"WORKED_AT\",\n",
+ " \"SPOUSE\",\n",
+ " \"NET_WORTH\",\n",
+ " \"INVESTMENT\",\n",
+ " \"INFLUENCED_BY\",\n",
+ " ],\n",
+ ")\n",
+ "graph_documents_filtered = llm_transformer_filtered.convert_to_graph_documents(\n",
+ " documents\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "7LIsSsRa38NO"
+ },
+ "source": [
+ "Create a knowledge graph from the nodes and relationships extracted by Gemini"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "id": "1SCw5-iLj53q"
+ },
+ "outputs": [],
+ "source": [
+ "graph = NetworkxEntityGraph()\n",
+ "\n",
+ "# Add nodes to the graph\n",
+ "for node in graph_documents_filtered[0].nodes:\n",
+ " graph.add_node(node.id)\n",
+ "\n",
+ "# Add edges to the graph\n",
+ "for edge in graph_documents_filtered[0].relationships:\n",
+ " graph._graph.add_edge(\n",
+ " edge.source.id,\n",
+ " edge.target.id,\n",
+ " relation=edge.type,\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "1Lsl-dTbrT3Z"
+ },
+ "source": [
+ "Let's visualize the Generated Graph\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "id": "LkcxVs5irXW_"
+ },
+ "outputs": [],
+ "source": [
+ "def visualize(graph):\n",
+ " G = graph._graph\n",
+ " # G.add_edges_from(self.visual)\n",
+ " nx.draw_networkx(G, with_labels=True)\n",
+ " plt.show(block=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "id": "1JEUbP8rrZpx"
+ },
+ "outputs": [],
+ "source": [
+ "visualize(graph)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "5YBi-qaV4Hfp"
+ },
+ "source": [
+ "Now let's build a simple Graph QA chain to ask some questions based on the knowledge graph created."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {
+ "id": "JrGjHvf0j7vu"
+ },
+ "outputs": [],
+ "source": [
+ "chain = GraphQAChain.from_llm(llm=llm, graph=graph, verbose=True)"
+ ]
+ },
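Note that `Chain.run` is deprecated in LangChain 0.1+ in favor of `invoke`, which takes a dict keyed by the chain's input variable (for `GraphQAChain` this is `query`) and returns a dict containing the answer under `result`. A hedged sketch, with the actual LLM call commented out since it needs GCP credentials:

```python
# `invoke` replaces the deprecated `run`: pass a dict keyed by the
# chain's input variable ("query" for GraphQAChain) and read the answer
# from the "result" key of the returned dict. The live call is
# commented out because it requires a configured Vertex AI project.
payload = {"query": "Who influenced Lawrence Edward Page?"}
# result = chain.invoke(payload)
# answer = result["result"]
print(payload["query"])
```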
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "9ko0lPr14SAS"
+ },
+ "source": [
+ "Notice you will not get an expected response back, this is because we asked for Larry Page, not Lawrance Edward Page which was extracted from the texts."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {
+ "id": "wwDND4b_j8cg"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/var/tmp/ipykernel_3467/2623240478.py:2: LangChainDeprecationWarning: The method `Chain.run` was deprecated in langchain 0.1.0 and will be removed in 1.0. Use :meth:`~invoke` instead.\n",
+ " chain.run(question)\n",
+ "Error in StdOutCallbackHandler.on_chain_start callback: AttributeError(\"'NoneType' object has no attribute 'get'\")\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Entities Extracted:\n",
+ "\u001b[32;1m\u001b[1;3mLarry Page\n",
+ "\u001b[0m\n",
+ "Full Context:\n",
+ "\u001b[32;1m\u001b[1;3m\u001b[0m\n",
+ "\n",
+ "\u001b[1m> Finished chain.\u001b[0m\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "\"I don't know. The provided triplets don't mention any influences on Larry Page.\\n\""
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "question = \"\"\"\"Who influenced Larry Page?\"\"\"\n",
+ "chain.run(question)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "xisGWih84sCK"
+ },
+ "source": [
+ "Now if we rephrase the question and ask Lawrence Edward Page, which was extracted from the texts, it will work. Normally a typical user will not ask Larry's full legal name, so how can we solve this issue? The answer is semantic search through embedding and vector search."
+ ]
+ },
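The intuition behind that fix can be shown without any cloud services: embeddings of semantically close phrases point in similar directions, so the cosine distance between them is small. A toy 3-dimensional sketch with made-up values (real `text-embedding-004` vectors have 768 dimensions):

```python
import math


def cosine_distance(a: list[float], b: list[float]) -> float:
    """1 - cosine similarity; smaller means more semantically similar."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


# Made-up toy vectors: the two name variants nearly coincide,
# while the unrelated phrase points elsewhere.
larry = [0.90, 0.10, 0.20]
lawrence = [0.88, 0.12, 0.21]
unrelated = [0.10, 0.90, 0.30]

assert cosine_distance(larry, lawrence) < cosine_distance(larry, unrelated)
```

This is the same comparison Spanner's `COSINE_DISTANCE` function performs later in the tutorial, just over real embedding vectors.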
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {
+ "id": "3586Aq-a4tTK"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Error in StdOutCallbackHandler.on_chain_start callback: AttributeError(\"'NoneType' object has no attribute 'get'\")\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Entities Extracted:\n",
+ "\u001b[32;1m\u001b[1;3mLawrence Edward Page\n",
+ "\u001b[0m\n",
+ "Full Context:\n",
+ "\u001b[32;1m\u001b[1;3mLawrence Edward Page NET_WORTH $156 billion\n",
+ "Lawrence Edward Page NET_WORTH $145.2 billion\n",
+ "Lawrence Edward Page WORKED_AT Google\n",
+ "Lawrence Edward Page WORKED_AT Alphabet Inc.\n",
+ "Lawrence Edward Page INVESTMENT Kitty Hawk\n",
+ "Lawrence Edward Page INVESTMENT Opener\n",
+ "Lawrence Edward Page INFLUENCED_BY Maria Montessori\u001b[0m\n",
+ "\n",
+ "\u001b[1m> Finished chain.\u001b[0m\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "'Maria Montessori\\n'"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "question = \"\"\"\"Who influenced Lawrence Edward Page?\"\"\"\n",
+ "chain.run(question)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Axb7c8Y1YmQ8"
+ },
+ "source": [
+ "### Create Spanner Instance and Database\n",
+ "\n",
+ "To prepare for future queries, we'll now store our newly created knowledge graph in a Google Cloud Spanner database. We'll also store the accompanying embeddings in Spanner's Vector Database to enable efficient semantic search.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {
+ "id": "EO1vfftwqs4Z"
+ },
+ "outputs": [],
+ "source": [
+ "SPANNER_INSTANCE_ID = \"graphrag-instance\"\n",
+ "SPANNER_DATABASE_ID = \"graphrag\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {
+ "id": "27gTtXr4m2n2"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Updated property [core/project].\n",
+ "Creating instance...done. \n"
+ ]
+ }
+ ],
+ "source": [
+ "!gcloud config set project {GCP_PROJECT_ID}\n",
+ "!gcloud services enable spanner.googleapis.com\n",
+ "!gcloud spanner instances create {SPANNER_INSTANCE_ID} --config=regional-us-central1 --description=\"Graph RAG Instance\" --nodes=1 --edition=ENTERPRISE"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {
+ "id": "jsoqBPlDthc7"
+ },
+ "outputs": [],
+ "source": [
+ "# prompt: create a spanner database and table to store the graph with nodes and edges created in graph\n",
+ "\n",
+ "\n",
+ "def create_database(project_id, instance_id, database_id):\n",
+ " \"\"\"Creates a database and tables for sample data.\"\"\"\n",
+ " from google.cloud import spanner\n",
+ " from google.cloud.spanner_admin_database_v1.types import spanner_database_admin\n",
+ "\n",
+ " spanner_client = spanner.Client(project_id)\n",
+ " database_admin_api = spanner_client.database_admin_api\n",
+ "\n",
+ " request = spanner_database_admin.CreateDatabaseRequest(\n",
+ " parent=database_admin_api.instance_path(spanner_client.project, instance_id),\n",
+ " create_statement=f\"CREATE DATABASE `{database_id}`\",\n",
+ " extra_statements=[\n",
+ " \"\"\"CREATE TABLE Person (\n",
+ " person STRING(1024)\n",
+ " ) PRIMARY KEY (person)\"\"\",\n",
+ " \"\"\"CREATE TABLE Country (\n",
+ " country STRING(1024))\n",
+ " PRIMARY KEY (country)\"\"\",\n",
+ " \"\"\"CREATE TABLE Organization (\n",
+ " organization STRING(1024)\n",
+ " ) PRIMARY KEY (organization)\"\"\",\n",
+ " \"\"\"CREATE TABLE Asset (\n",
+ " asset STRING(MAX)\n",
+ " ) PRIMARY KEY (asset)\"\"\",\n",
+ " \"\"\"CREATE TABLE KgNode (\n",
+ " DocId INT64 NOT NULL,\n",
+ " Name STRING(1024),\n",
+ " DOC STRING(1024),\n",
+ " DocEmbedding ARRAY\n",
+ " ) PRIMARY KEY (DocId)\"\"\",\n",
+ " \"\"\"CREATE TABLE NATIONALITY (\n",
+ " P_Name STRING(1024) ,\n",
+ " C_Name STRING(1024) ,\n",
+ " FOREIGN KEY (P_Name) REFERENCES Person (person),\n",
+ " FOREIGN KEY (C_Name) REFERENCES Country (country)\n",
+ " ) PRIMARY KEY (P_Name, C_Name)\"\"\",\n",
+ " \"\"\"CREATE TABLE LOCATED_IN (\n",
+ " O1_Name STRING(1024) ,\n",
+ " O2_Name STRING(1024) ,\n",
+ " FOREIGN KEY (O1_Name) REFERENCES Organization (organization),\n",
+ " FOREIGN KEY (O2_Name) REFERENCES Organization (organization)\n",
+ " ) PRIMARY KEY (O1_Name, O2_Name)\"\"\",\n",
+ " \"\"\"CREATE TABLE WORKED_AT(\n",
+ " P_Name STRING(1024) ,\n",
+ " O_Name STRING(1024) ,\n",
+ " FOREIGN KEY (P_Name) REFERENCES Person (person),\n",
+ " FOREIGN KEY (O_Name) REFERENCES Organization (organization)\n",
+ " ) PRIMARY KEY (P_Name, O_Name)\"\"\",\n",
+ " \"\"\"CREATE TABLE SPOUSE (\n",
+ " P1_Name STRING(1024) ,\n",
+ " P2_Name STRING(1024) ,\n",
+ " FOREIGN KEY (P1_Name) REFERENCES Person (person),\n",
+ " FOREIGN KEY (P2_Name) REFERENCES Person (person)\n",
+ " ) PRIMARY KEY (P1_Name, P2_Name)\"\"\",\n",
+ " \"\"\"CREATE TABLE NET_WORTH(\n",
+ " P_Name STRING(1024) ,\n",
+ " A_Name STRING(1024) ,\n",
+ " FOREIGN KEY (P_Name) REFERENCES Person (person),\n",
+ " FOREIGN KEY (A_Name) REFERENCES Asset (asset)\n",
+ " ) PRIMARY KEY (P_Name, A_Name)\"\"\",\n",
+ " \"\"\"CREATE TABLE INVESTMENT(\n",
+ " P_Name STRING(1024) ,\n",
+ " O_Name STRING(1024) ,\n",
+ " FOREIGN KEY (P_Name) REFERENCES Person (person),\n",
+ " FOREIGN KEY (O_Name) REFERENCES Organization (organization)\n",
+ " ) PRIMARY KEY (P_Name, O_Name)\"\"\",\n",
+ " \"\"\"CREATE TABLE INFLUENCED_BY(\n",
+ " P1_Name STRING(1024) ,\n",
+ " P2_Name STRING(1024) ,\n",
+ " FOREIGN KEY (P1_Name) REFERENCES Person (person),\n",
+ " FOREIGN KEY (P2_Name) REFERENCES Person (person)\n",
+ " ) PRIMARY KEY (P1_Name, P2_Name)\"\"\",\n",
+ " \"\"\"CREATE OR REPLACE PROPERTY GRAPH User\n",
+ " NODE TABLES (Person, Country, Organization, Asset)\n",
+ " EDGE TABLES (\n",
+ " NATIONALITY\n",
+ " SOURCE KEY (P_Name) REFERENCES Person (person)\n",
+ " DESTINATION KEY (C_Name) REFERENCES Country (country)\n",
+ " LABEL nationality,\n",
+ " LOCATED_IN\n",
+ " SOURCE KEY (O1_Name) REFERENCES Organization (organization)\n",
+ " DESTINATION KEY (O2_Name) REFERENCES Organization (organization)\n",
+ " LABEL located,\n",
+ " WORKED_AT\n",
+ " SOURCE KEY (P_Name) REFERENCES Person (person)\n",
+ " DESTINATION KEY (O_Name) REFERENCES Organization (organization)\n",
+ " LABEL worked,\n",
+ " SPOUSE\n",
+ " SOURCE KEY (P1_Name) REFERENCES Person (person)\n",
+ " DESTINATION KEY (P2_Name) REFERENCES Person (person)\n",
+ " LABEL spouse,\n",
+ " NET_WORTH\n",
+ " SOURCE KEY (P_Name) REFERENCES Person (person)\n",
+ " DESTINATION KEY (A_Name) REFERENCES Asset (asset)\n",
+ " LABEL net_worth,\n",
+ " INVESTMENT\n",
+ " SOURCE KEY (P_Name) REFERENCES Person (person)\n",
+ " DESTINATION KEY (O_Name) REFERENCES Organization (organization)\n",
+ " LABEL invested,\n",
+ " INFLUENCED_BY\n",
+ " SOURCE KEY (P1_Name) REFERENCES Person (person)\n",
+ " DESTINATION KEY (P2_Name) REFERENCES Person (person)\n",
+ " LABEL influenced)\"\"\",\n",
+ " ],\n",
+ " )\n",
+ "\n",
+ " operation = database_admin_api.create_database(request=request)\n",
+ "\n",
+ " print(\"Waiting for operation to complete...\")\n",
+ " OPERATION_TIMEOUT_SECONDS = 60\n",
+ " database = operation.result(OPERATION_TIMEOUT_SECONDS)\n",
+ "\n",
+ " print(\n",
+ " \"Created database {} on instance {}\".format(\n",
+ " database.name,\n",
+ " database_admin_api.instance_path(spanner_client.project, instance_id),\n",
+ " )\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "7mVK78UWt2OJ"
+ },
+ "outputs": [],
+ "source": [
+ "from google.cloud import spanner\n",
+ "\n",
+ "create_database(GCP_PROJECT_ID, SPANNER_INSTANCE_ID, SPANNER_DATABASE_ID)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "wmZK7O3x56iN"
+ },
+ "source": [
+ "Now that the Spanner instance, database and tables are created. We will loop through the graph nodes and relationship in the knowledge graph and store them in temperary arrays to insert them into the Spanner database."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {
+ "id": "28RVIpzndL3c"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[['Person', 'Lawrence Edward Page'], ['Organization', 'Google'], ['Organization', 'Opener'], ['Person', 'Sergey Brin'], ['Organization', 'Alphabet Inc.'], ['Person', 'Maria Montessori'], ['Asset', '$156 billion'], ['Organization', 'Kitty Hawk'], ['Asset', '$145.2 billion']]\n",
+ "[['NET_WORTH', 'Person', 'Lawrence Edward Page', 'Asset', '$156 billion'], ['NET_WORTH', 'Person', 'Lawrence Edward Page', 'Asset', '$145.2 billion'], ['WORKED_AT', 'Person', 'Lawrence Edward Page', 'Organization', 'Google'], ['WORKED_AT', 'Person', 'Lawrence Edward Page', 'Organization', 'Alphabet Inc.'], ['INVESTMENT', 'Person', 'Lawrence Edward Page', 'Organization', 'Kitty Hawk'], ['INVESTMENT', 'Person', 'Lawrence Edward Page', 'Organization', 'Opener'], ['INFLUENCED_BY', 'Person', 'Lawrence Edward Page', 'Person', 'Maria Montessori'], ['WORKED_AT', 'Person', 'Sergey Brin', 'Organization', 'Google'], ['INFLUENCED_BY', 'Person', 'Sergey Brin', 'Person', 'Maria Montessori']]\n"
+ ]
+ }
+ ],
+ "source": [
+ "for graph_d in graph_documents_filtered:\n",
+ " graph = graph_d.nodes\n",
+ " temp = []\n",
+ " # Access nodes\n",
+ " for node in graph_documents_filtered[0].nodes:\n",
+ " temp.append(\n",
+ " [node.type, node.id]\n",
+ " ) # Access node properties (e.g., entity type, labels)\n",
+ " temp2 = []\n",
+ " # Access relationships\n",
+ " for relationship in graph_documents_filtered[0].relationships:\n",
+ " temp2.append(\n",
+ " [\n",
+ " relationship.type,\n",
+ " relationship.source.type,\n",
+ " relationship.source.id,\n",
+ " relationship.target.type,\n",
+ " relationship.target.id,\n",
+ " ]\n",
+ " )\n",
+ "\n",
+ "print(temp)\n",
+ "print(temp2)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "HFwPWatO6PaJ"
+ },
+ "source": [
+ "We will also embedd the same texts and store the embedding vector in Spanner Vector database for semantic search later."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "BEviClp6M8Cg"
+ },
+ "outputs": [],
+ "source": [
+ "import vertexai\n",
+ "from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel\n",
+ "\n",
+ "# init the vertexai package\n",
+ "vertexai.init(project=GCP_PROJECT_ID, location=REGION)\n",
+ "\n",
+ "\n",
+ "def get_embedding(text, task_type, model):\n",
+ " try:\n",
+ " text_embedding_input = TextEmbeddingInput(task_type=task_type, text=text)\n",
+ " embeddings = model.get_embeddings([text_embedding_input])\n",
+ " return embeddings[0].values\n",
+ " except:\n",
+ " return []\n",
+ "\n",
+ "\n",
+ "TASK_TYPE = \"QUESTION_ANSWERING\"\n",
+ "ANSWER_TASK_TYPE = \"RETRIEVAL_DOCUMENT\"\n",
+ "EMBEDDING_MODEL = TextEmbeddingModel.from_pretrained(EMBEDDING_MODEL_NAME)\n",
+ "\n",
+ "a1_emb = get_embedding(TEXT_1, ANSWER_TASK_TYPE, EMBEDDING_MODEL)\n",
+ "a2_emb = get_embedding(TEXT_2, ANSWER_TASK_TYPE, EMBEDDING_MODEL)\n",
+ "a3_emb = get_embedding(TEXT_3, ANSWER_TASK_TYPE, EMBEDDING_MODEL)\n",
+ "a4_emb = get_embedding(TEXT_4, ANSWER_TASK_TYPE, EMBEDDING_MODEL)\n",
+ "print(a1_emb)\n",
+ "print(a2_emb)\n",
+ "print(a3_emb)\n",
+ "print(a4_emb)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {
+ "id": "vNZhxFz5N1oW"
+ },
+ "outputs": [],
+ "source": [
+ "TextEmbeddings = [a1_emb, a2_emb, a3_emb, a4_emb]\n",
+ "Texts = [TEXT_1, TEXT_2, TEXT_3, TEXT_4]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "kkuZY9-U6YXs"
+ },
+ "source": [
+ "Now we will insert the graph nodes, relationship as well as embedding vectors into the Spanner database."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "i0ADOTSgN6FA"
+ },
+ "outputs": [],
+ "source": [
+ "from google.cloud import spanner\n",
+ "\n",
+ "spanner_client = spanner.Client(GCP_PROJECT_ID)\n",
+ "instance = spanner_client.instance(SPANNER_INSTANCE_ID)\n",
+ "database = instance.database(SPANNER_DATABASE_ID)\n",
+ "\n",
+ "\n",
+ "def insert_values(transaction):\n",
+ " for sub_list in temp:\n",
+ " table_name = sub_list[0]\n",
+ " col_name = (sub_list[0]).lower()\n",
+ " value = sub_list[1]\n",
+ " row_ct = transaction.execute_update(\n",
+ " f\"INSERT INTO {table_name} ({col_name}) VALUES (@value)\",\n",
+ " params={\"value\": value},\n",
+ " param_types={\"value\": spanner.param_types.STRING}, # Adjust type if needed\n",
+ " )\n",
+ " print(f\"{row_ct} record(s) inserted.\")\n",
+ "\n",
+ " for sub_list2 in temp2:\n",
+ " table_name = sub_list2[0]\n",
+ " col_name1 = (sub_list2[1])[0:1] + (\n",
+ " \"1_Name\"\n",
+ " if table_name == \"INFLUENCED_BY\" or table_name == \"LOCATED_IN\"\n",
+ " else \"_Name\"\n",
+ " )\n",
+ " value1 = sub_list2[2]\n",
+ " col_name2 = (sub_list2[3])[0:1] + (\n",
+ " \"2_Name\"\n",
+ " if table_name == \"INFLUENCED_BY\" or table_name == \"LOCATED_IN\"\n",
+ " else \"_Name\"\n",
+ " )\n",
+ " value2 = sub_list2[4]\n",
+ " print(table_name, col_name1, value1, col_name2, value2)\n",
+ "\n",
+ " row_ct1 = transaction.execute_update(\n",
+ " f\"INSERT INTO {table_name} ({col_name1}, {col_name2}) VALUES (@value1, @value2)\",\n",
+ " params={\"value1\": value1, \"value2\": value2},\n",
+ " param_types={\n",
+ " \"value1\": spanner.param_types.STRING,\n",
+ " \"value2\": spanner.param_types.STRING,\n",
+ " }, # Adjust types if needed\n",
+ " )\n",
+ " print(f\"{row_ct1} record(s) inserted.\") # Print the result here\n",
+ " i = 0\n",
+ " for sub_list in TextEmbeddings:\n",
+ " table_name = \"KgNode\"\n",
+ " col_name1 = \"docID\"\n",
+ " col_name2 = \"Name\"\n",
+ " col_name3 = \"Doc\"\n",
+ " col_name4 = \"DocEmbedding\"\n",
+ " d = str(Texts[i])\n",
+ " value1 = i + 1\n",
+ " i = i + 1\n",
+ " value2 = \"Lawrence Edward Page\"\n",
+ " value3 = d\n",
+ " value4 = sub_list\n",
+ " print(\n",
+ " col_name1, col_name2, col_name3, col_name4, value1, value2, value3, value4\n",
+ " )\n",
+ " row_ct1 = transaction.execute_update(\n",
+ " f\"INSERT INTO {table_name} ({col_name1}, {col_name2}, {col_name3}, {col_name4}) VALUES (@value1, @value2, @value3, @value4)\",\n",
+ " params={\n",
+ " \"value1\": value1,\n",
+ " \"value2\": value2,\n",
+ " \"value3\": value3,\n",
+ " \"value4\": value4,\n",
+ " },\n",
+ " param_types={\n",
+ " \"value1\": spanner.param_types.INT64,\n",
+ " \"value2\": spanner.param_types.STRING,\n",
+ " \"value3\": spanner.param_types.STRING,\n",
+ " \"value4\": spanner.param_types.Array(spanner.param_types.FLOAT64),\n",
+ " },\n",
+ " ) # Adjust types if needed\n",
+ "\n",
+ " print(f\"{row_ct1} record(s) inserted.\")\n",
+ "\n",
+ "\n",
+ "# print(insert_values) # This just prints the function object, remove this line\n",
+ "database.run_in_transaction(insert_values)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "3wQLk1Zy6rDr"
+ },
+ "source": [
+ "Now the graph nodes, relationship and embedding vectors are stored in the Spanner database. Let's try to query the database. We will ask the original question \"Who influced Larry Page?\" in the Vector database to get a semantic search result. Then we will get the node \"Lawrence Edward Page\" from the result."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {
+ "id": "HtTrpMASOfcR"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[1, 'Lawrence Edward Page', 'Lawrence Edward Page (born March 26, 1973) is an American businessman and computer scientist best known for co-founding Google with Sergey Brin. ']\n"
+ ]
+ }
+ ],
+ "source": [
+ "QUESTION = \"Who influenced Larry Page?\"\n",
+ "q_emb = get_embedding(QUESTION, TASK_TYPE, EMBEDDING_MODEL)\n",
+ "\n",
+ "spanner_client = spanner.Client(GCP_PROJECT_ID)\n",
+ "instance = spanner_client.instance(SPANNER_INSTANCE_ID)\n",
+ "database = instance.database(SPANNER_DATABASE_ID)\n",
+ "kgnodename = \"\"\n",
+ "with database.snapshot() as snapshot:\n",
+ " results = snapshot.execute_sql(\n",
+ " \"\"\"SELECT DocId, NAME, Doc FROM KgNode ORDER BY COSINE_DISTANCE(DocEmbedding, @q_emb) limit 1\"\"\",\n",
+ " params={\"q_emb\": q_emb},\n",
+ " param_types={\n",
+ " \"q_emb\": spanner.param_types.Array(spanner.param_types.FLOAT64)\n",
+ " }, # Adjust FLOAT64 if needed\n",
+ " )\n",
+ " for row in results:\n",
+ " kgnodename = str(row[1])\n",
+ " print(row)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "VUbuSkFS7OY3"
+ },
+ "source": [
+ "Then we will traverse through the knowledge graph with the correct node name \"Lawrence Edward Page\". Now we can uncover all the connections and relationships associated with Larry Page."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {
+ "id": "O-lCGcLP1loV"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Lawrence Edward Page WORKED_AT Alphabet Inc.\n",
+ "Lawrence Edward Page WORKED_AT Google\n",
+ "Lawrence Edward Page INVESTMENT Kitty Hawk\n",
+ "Lawrence Edward Page INVESTMENT Opener\n",
+ "Lawrence Edward Page INFLUENCED_BY Maria Montessori\n"
+ ]
+ }
+ ],
+ "source": [
+ "from google.cloud import spanner\n",
+ "\n",
+ "spanner_client = spanner.Client(GCP_PROJECT_ID)\n",
+ "instance = spanner_client.instance(SPANNER_INSTANCE_ID)\n",
+ "database = instance.database(SPANNER_DATABASE_ID)\n",
+ "\n",
+ "\n",
+ "tables = [\"NATIONALITY\", \"WORKED_AT\", \"INVESTMENT\", \"INFLUENCED_BY\"]\n",
+ "for table in tables:\n",
+ " with database.snapshot() as snapshot:\n",
+ " if table in (\"INFLUENCED_BY\", \"LOCATED_IN\"):\n",
+ " column_name = \"P1_Name\"\n",
+ " else:\n",
+ " column_name = \"P_Name\"\n",
+ "\n",
+ "        # Table and column names cannot be parameterized, but the node name can be\n",
+ "        results = snapshot.execute_sql(\n",
+ "            f\"\"\"\n",
+ "            SELECT {column_name} AS Name, *\n",
+ "            FROM {table}\n",
+ "            WHERE {column_name} = @node_name\n",
+ "            \"\"\",\n",
+ "            params={\"node_name\": kgnodename},\n",
+ "            param_types={\"node_name\": spanner.param_types.STRING},\n",
+ "        )\n",
+ " for row in results:\n",
+ " # Dynamically determine the index of the relevant column\n",
+ " name_index = row.index(kgnodename)\n",
+ " # Construct the output string using f-string\n",
+ " output_string = f\"{row[name_index]} {table} {row[2]}\"\n",
+ " print(output_string)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "gTdtp14I7P5g"
+ },
+ "source": [
+ "### LangChain Custom Retriever and Chain\n",
+ "\n",
+ "We will create a custom LangChain retriever that encapsulates the code blocks above, then chain it with a prompt and Gemini to make a functional Q&A bot."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {
+ "id": "qzOZ5NUv55vU"
+ },
+ "outputs": [],
+ "source": [
+ "from google.cloud import spanner\n",
+ "from langchain_core.callbacks import CallbackManagerForRetrieverRun\n",
+ "from langchain_core.documents import Document\n",
+ "from langchain_core.retrievers import BaseRetriever\n",
+ "from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel\n",
+ "\n",
+ "\n",
+ "class SpannerGraphRagRetriever(BaseRetriever):\n",
+ " gcp_project_id: str\n",
+ " spanner_instance_id: str\n",
+ " spanner_database_id: str\n",
+ " embedding_model_name: str\n",
+ "\n",
+ "    @staticmethod\n",
+ "    def get_embedding(text, task_type, model):\n",
+ "        # Embed the query text; on failure return an empty list\n",
+ "        try:\n",
+ "            text_embedding_input = TextEmbeddingInput(task_type=task_type, text=text)\n",
+ "            embeddings = model.get_embeddings([text_embedding_input])\n",
+ "            return embeddings[0].values\n",
+ "        except Exception:\n",
+ "            return []\n",
+ "\n",
+ "    def _get_relevant_documents(\n",
+ "        self, query: str, *, run_manager: CallbackManagerForRetrieverRun\n",
+ "    ) -> list[Document]:\n",
+ "        # Query the Spanner database and construct the LangChain Documents\n",
+ "        spanner_client = spanner.Client(self.gcp_project_id)\n",
+ "        instance = spanner_client.instance(self.spanner_instance_id)\n",
+ "        database = instance.database(self.spanner_database_id)\n",
+ "        embedding_model = TextEmbeddingModel.from_pretrained(self.embedding_model_name)\n",
+ "        task_type = \"QUESTION_ANSWERING\"\n",
+ "        kgnode_name = \"\"\n",
+ "\n",
+ "        q_emb = self.get_embedding(query, task_type, embedding_model)\n",
+ "\n",
+ " with database.snapshot() as snapshot:\n",
+ " results = snapshot.execute_sql(\n",
+ "                \"\"\"SELECT DocId, NAME, Doc FROM KgNode ORDER BY COSINE_DISTANCE(DocEmbedding, @q_emb) LIMIT 1\"\"\",\n",
+ " params={\"q_emb\": q_emb},\n",
+ " param_types={\n",
+ " \"q_emb\": spanner.param_types.Array(spanner.param_types.FLOAT64)\n",
+ " }, # Adjust FLOAT64 if needed\n",
+ " )\n",
+ " for row in results:\n",
+ " kgnode_name = row[1]\n",
+ "\n",
+ " tables = [\"NATIONALITY\", \"WORKED_AT\", \"INVESTMENT\", \"INFLUENCED_BY\"]\n",
+ " documents = []\n",
+ " for table in tables:\n",
+ " with database.snapshot() as snapshot:\n",
+ " if table in (\"INFLUENCED_BY\", \"LOCATED_IN\"):\n",
+ " column_name = \"P1_Name\"\n",
+ " else:\n",
+ " column_name = \"P_Name\"\n",
+ "\n",
+ "                # Table and column names cannot be parameterized, but the node name can be\n",
+ "                results = snapshot.execute_sql(\n",
+ "                    f\"\"\"\n",
+ "                    SELECT {column_name} AS Name, *\n",
+ "                    FROM {table}\n",
+ "                    WHERE {column_name} = @node_name\n",
+ "                    \"\"\",\n",
+ "                    params={\"node_name\": kgnode_name},\n",
+ "                    param_types={\"node_name\": spanner.param_types.STRING},\n",
+ "                )\n",
+ " for row in results:\n",
+ " # Dynamically determine the index of the relevant column\n",
+ " name_index = row.index(kgnode_name)\n",
+ "\n",
+ " # Construct the output string using f-string\n",
+ " output_string = f\"{row[name_index]} {table} {row[2]}\"\n",
+ " doc = Document(page_content=output_string)\n",
+ " documents.append(doc)\n",
+ " # print (documents)\n",
+ " return documents"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "sv_06SAJ8OQH"
+ },
+ "source": [
+ "This chain uses the custom retriever to fetch graph facts from the Spanner database, then has Gemini answer the question grounded in that context."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {
+ "id": "nJOQ4d9q0JSS"
+ },
+ "outputs": [],
+ "source": [
+ "from langchain.prompts import (\n",
+ " ChatPromptTemplate,\n",
+ " HumanMessagePromptTemplate,\n",
+ " SystemMessagePromptTemplate,\n",
+ ")\n",
+ "from langchain_core.output_parsers import StrOutputParser\n",
+ "from langchain_core.runnables import RunnablePassthrough\n",
+ "from langchain_google_vertexai import ChatVertexAI\n",
+ "import vertexai\n",
+ "\n",
+ "\n",
+ "class LangchainGraphRag:\n",
+ "\n",
+ " def __init__(\n",
+ " self,\n",
+ " model_name,\n",
+ " embedding_model_name,\n",
+ " gcp_project_id,\n",
+ " region,\n",
+ " spanner_instance_id,\n",
+ " spanner_database_id,\n",
+ " ):\n",
+ "\n",
+ " self.model_name = model_name\n",
+ " self.embedding_model_name = embedding_model_name\n",
+ " self.spanner_instance_id = spanner_instance_id\n",
+ " self.spanner_database_id = spanner_database_id\n",
+ " self.max_output_tokens = 1024\n",
+ " self.temperature = 0.1\n",
+ " self.top_p = 0.8\n",
+ " self.top_k = 10\n",
+ " self.gcp_project_id = gcp_project_id\n",
+ " self.region = region\n",
+ " self.prompt_template = \"\"\"\n",
+ "        You are a trivia expert specializing in famous people and their achievements, able to provide insightful answers based solely on the information provided in the context.\n",
+ " Context: {context}\"\"\"\n",
+ "\n",
+ " def configure_qa_rag_chain(self, llm):\n",
+ " qa_prompt = ChatPromptTemplate.from_messages(\n",
+ " [\n",
+ " SystemMessagePromptTemplate.from_template(self.prompt_template),\n",
+ " HumanMessagePromptTemplate.from_template(\"Question: {question}\"),\n",
+ " ]\n",
+ " )\n",
+ "\n",
+ " retriever = SpannerGraphRagRetriever(\n",
+ " gcp_project_id=self.gcp_project_id,\n",
+ " spanner_instance_id=self.spanner_instance_id,\n",
+ " spanner_database_id=self.spanner_database_id,\n",
+ " embedding_model_name=self.embedding_model_name,\n",
+ " )\n",
+ "\n",
+ " def format_docs(docs):\n",
+ " return \"\\n\\n\".join(doc.page_content for doc in docs)\n",
+ "\n",
+ " chain = (\n",
+ " {\"context\": retriever | format_docs, \"question\": RunnablePassthrough()}\n",
+ " | qa_prompt\n",
+ " | llm\n",
+ " | StrOutputParser()\n",
+ " )\n",
+ " return chain\n",
+ "\n",
+ " def set_up(self):\n",
+ "\n",
+ " # init the vertexai package\n",
+ " vertexai.init(project=self.gcp_project_id, location=self.region)\n",
+ " llm = ChatVertexAI(\n",
+ " model_name=self.model_name,\n",
+ " max_output_tokens=self.max_output_tokens,\n",
+ " max_input_tokens=32000,\n",
+ " temperature=self.temperature,\n",
+ " top_p=self.top_p,\n",
+ " top_k=self.top_k,\n",
+ " project=self.gcp_project_id,\n",
+ " location=self.region,\n",
+ " # convert_system_message_to_human=True,\n",
+ " response_validation=False,\n",
+ " verbose=True,\n",
+ " )\n",
+ "\n",
+ " self.qa_chain = self.configure_qa_rag_chain(llm)\n",
+ "\n",
+ " def query(self, query):\n",
+ " return self.qa_chain.invoke(query)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "B3o2czmN8Goo"
+ },
+ "source": [
+ "We will create the chain and test it out."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {
+ "id": "QhxJy4omzP9J"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Maria Montessori.\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "lc = LangchainGraphRag(\n",
+ " MODEL_NAME,\n",
+ " EMBEDDING_MODEL_NAME,\n",
+ " GCP_PROJECT_ID,\n",
+ " REGION,\n",
+ " SPANNER_INSTANCE_ID,\n",
+ " SPANNER_DATABASE_ID,\n",
+ ")\n",
+ "lc.set_up()\n",
+ "\n",
+ "response = lc.query(\"Who influenced Larry Page?\")\n",
+ "print(response)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "50jswr0X8cOK"
+ },
+ "source": [
+ "### Deploy to Vertex AI Reasoning Engine\n",
+ "\n",
+ "Great! The chain works and returns the answer to the original question. Now let's host this LangChain code on Vertex AI Reasoning Engine so we can call it as an API. First, we reorganize the code blocks above to work with Vertex AI Reasoning Engine. Be sure to update the code below with your project ID and region."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 39,
+ "metadata": {
+ "id": "e-NgQrA76MRg"
+ },
+ "outputs": [],
+ "source": [
+ "from google.cloud import spanner\n",
+ "from langchain.prompts import (\n",
+ " ChatPromptTemplate,\n",
+ " HumanMessagePromptTemplate,\n",
+ " SystemMessagePromptTemplate,\n",
+ ")\n",
+ "from langchain_core.callbacks import CallbackManagerForRetrieverRun\n",
+ "from langchain_core.documents import Document\n",
+ "from langchain_core.output_parsers import StrOutputParser\n",
+ "from langchain_core.retrievers import BaseRetriever\n",
+ "from langchain_core.runnables import RunnablePassthrough\n",
+ "from langchain_google_vertexai import ChatVertexAI\n",
+ "import vertexai\n",
+ "from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel\n",
+ "\n",
+ "\n",
+ "class LangchainGraphRag:\n",
+ "\n",
+ " def __init__(self) -> None:\n",
+ "\n",
+ " self.model_name = \"gemini-1.5-pro-002\"\n",
+ " self.gcp_project_id = \"\"\n",
+ " self.region = \"us-central1\"\n",
+ " self.max_output_tokens = 1024\n",
+ " self.temperature = 0.1\n",
+ " self.top_p = 0.8\n",
+ " self.top_k = 10\n",
+ " self.prompt_template = \"\"\"\n",
+ "        You are a trivia expert specializing in famous people and their achievements, able to provide insightful answers based solely on the information provided in the context.\n",
+ " Context: {context}\"\"\"\n",
+ "\n",
+ " def configure_qa_rag_chain(self, llm):\n",
+ " qa_prompt = ChatPromptTemplate.from_messages(\n",
+ " [\n",
+ " SystemMessagePromptTemplate.from_template(self.prompt_template),\n",
+ " HumanMessagePromptTemplate.from_template(\"Question: {question}\"),\n",
+ " ]\n",
+ " )\n",
+ "\n",
+ " retriever = self.SpannerGraphRagRetriever()\n",
+ "\n",
+ " def format_docs(docs):\n",
+ " return \"\\n\\n\".join(doc.page_content for doc in docs)\n",
+ "\n",
+ " chain = (\n",
+ " {\"context\": retriever | format_docs, \"question\": RunnablePassthrough()}\n",
+ " | qa_prompt\n",
+ " | llm\n",
+ " | StrOutputParser()\n",
+ " )\n",
+ " return chain\n",
+ "\n",
+ " def set_up(self) -> None:\n",
+ " print(\"set_up\")\n",
+ " # init the vertexai package\n",
+ " vertexai.init(project=self.gcp_project_id, location=self.region)\n",
+ " llm = ChatVertexAI(\n",
+ " model_name=self.model_name,\n",
+ " max_output_tokens=self.max_output_tokens,\n",
+ " max_input_tokens=32000,\n",
+ " temperature=self.temperature,\n",
+ " top_p=self.top_p,\n",
+ " top_k=self.top_k,\n",
+ " project=self.gcp_project_id,\n",
+ " location=self.region,\n",
+ " # convert_system_message_to_human=True,\n",
+ " response_validation=False,\n",
+ " verbose=True,\n",
+ " )\n",
+ "\n",
+ " self.qa_chain = self.configure_qa_rag_chain(llm)\n",
+ "\n",
+ " def query(self, query):\n",
+ " return self.qa_chain.invoke(query)\n",
+ "\n",
+ " class SpannerGraphRagRetriever(BaseRetriever):\n",
+ "\n",
+ "        @staticmethod\n",
+ "        def get_embedding(text, task_type, model):\n",
+ "            # Embed the query text; on failure return an empty list\n",
+ "            try:\n",
+ "                text_embedding_input = TextEmbeddingInput(\n",
+ "                    task_type=task_type, text=text\n",
+ "                )\n",
+ "                embeddings = model.get_embeddings([text_embedding_input])\n",
+ "                return embeddings[0].values\n",
+ "            except Exception:\n",
+ "                return []\n",
+ "\n",
+ "        def _get_relevant_documents(\n",
+ "            self, query: str, *, run_manager: CallbackManagerForRetrieverRun\n",
+ "        ) -> list[Document]:\n",
+ "            embedding_model_name = \"text-embedding-004\"\n",
+ "            spanner_instance_id = \"graphrag-instance\"\n",
+ "            spanner_database_id = \"graphrag\"\n",
+ "            gcp_project_id = \"\"\n",
+ "            region = \"us-central1\"\n",
+ "            # init the vertexai package\n",
+ "            vertexai.init(project=gcp_project_id, location=region)\n",
+ "            # Query the Spanner database and construct the LangChain Documents\n",
+ "            spanner_client = spanner.Client(gcp_project_id)\n",
+ "            instance = spanner_client.instance(spanner_instance_id)\n",
+ "            database = instance.database(spanner_database_id)\n",
+ "            embedding_model = TextEmbeddingModel.from_pretrained(embedding_model_name)\n",
+ "            task_type = \"QUESTION_ANSWERING\"\n",
+ "            kgnode_name = \"\"\n",
+ "\n",
+ "            q_emb = self.get_embedding(query, task_type, embedding_model)\n",
+ "\n",
+ " with database.snapshot() as snapshot:\n",
+ " results = snapshot.execute_sql(\n",
+ "                    \"\"\"SELECT DocId, NAME, Doc FROM KgNode ORDER BY COSINE_DISTANCE(DocEmbedding, @q_emb) LIMIT 1\"\"\",\n",
+ " params={\"q_emb\": q_emb},\n",
+ " param_types={\n",
+ " \"q_emb\": spanner.param_types.Array(spanner.param_types.FLOAT64)\n",
+ " }, # Adjust FLOAT64 if needed\n",
+ " )\n",
+ " for row in results:\n",
+ " kgnode_name = row[1]\n",
+ "                print(f\"Matched graph node: {kgnode_name}\")\n",
+ " tables = [\"NATIONALITY\", \"WORKED_AT\", \"INVESTMENT\", \"INFLUENCED_BY\"]\n",
+ " documents = []\n",
+ " for table in tables:\n",
+ " with database.snapshot() as snapshot:\n",
+ " if table in (\"INFLUENCED_BY\", \"LOCATED_IN\"):\n",
+ " column_name = \"P1_Name\"\n",
+ " else:\n",
+ " column_name = \"P_Name\"\n",
+ "\n",
+ "                    # Table and column names cannot be parameterized, but the node name can be\n",
+ "                    results = snapshot.execute_sql(\n",
+ "                        f\"\"\"\n",
+ "                        SELECT {column_name} AS Name, *\n",
+ "                        FROM {table}\n",
+ "                        WHERE {column_name} = @node_name\n",
+ "                        \"\"\",\n",
+ "                        params={\"node_name\": kgnode_name},\n",
+ "                        param_types={\"node_name\": spanner.param_types.STRING},\n",
+ "                    )\n",
+ " for row in results:\n",
+ " # Dynamically determine the index of the relevant column\n",
+ " name_index = row.index(kgnode_name)\n",
+ "\n",
+ " # Construct the output string using f-string\n",
+ " output_string = f\"{row[name_index]} {table} {row[2]}\"\n",
+ " print(output_string)\n",
+ " doc = Document(page_content=output_string)\n",
+ " documents.append(doc)\n",
+ " return documents"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "metadata": {
+ "id": "AAItAzMKbDmt"
+ },
+ "outputs": [],
+ "source": [
+ "import vertexai\n",
+ "from vertexai.preview import reasoning_engines\n",
+ "\n",
+ "vertexai.init(\n",
+ " project=GCP_PROJECT_ID,\n",
+ " location=REGION,\n",
+ " staging_bucket=STAGING_BUCKET,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "WtIW59QW9F-z"
+ },
+ "source": [
+ "Grant the `service-PROJECT_NUMBER@gcp-sa-aiplatform-re.iam.gserviceaccount.com` service account the Cloud Spanner Database User role (`roles/spanner.databaseUser`), or Vertex AI Reasoning Engine will not have access to the Spanner database."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "rJlLpI09d9Kd"
+ },
+ "outputs": [],
+ "source": [
+ "!gcloud projects add-iam-policy-binding {GCP_PROJECT_ID} \\\n",
+ " --member='serviceAccount:service-{GCP_PROJECT_NUMBER}@gcp-sa-aiplatform-re.iam.gserviceaccount.com' \\\n",
+ " --role='roles/spanner.databaseUser'"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "y4BBUoAK9Oz6"
+ },
+ "source": [
+ "Deploy the package"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "DO8Ts9iFcJrj"
+ },
+ "outputs": [],
+ "source": [
+ "remote_app = reasoning_engines.ReasoningEngine.create(\n",
+ " LangchainGraphRag(),\n",
+ " requirements=[\n",
+ " \"google-cloud-aiplatform==1.67.0\",\n",
+ " \"google-cloud-resource-manager==1.12.5\",\n",
+ " \"google-cloud-spanner==3.48.0\",\n",
+ " \"langchain==0.3.0\",\n",
+ " \"langchain-community==0.3.0\",\n",
+ " \"langchain-core==0.3.2\",\n",
+ " \"langchain-experimental==0.3.0\",\n",
+ " \"langchain-google-vertexai==2.0.1\",\n",
+ " \"langchain-text-splitters==0.3.0\",\n",
+ " \"cloudpickle==3.1.0\",\n",
+ " ],\n",
+ " display_name=\"GraphRAG Demo\",\n",
+ " description=\"GraphRAG Demo\",\n",
+ " extra_packages=[],\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "IcE2ferJ9RyB"
+ },
+ "source": [
+ "Now we can query the deployed app on Vertex AI Reasoning Engine with the original question."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 42,
+ "metadata": {
+ "id": "_iBOGHGz7XlA"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'Maria Montessori.\\n'"
+ ]
+ },
+ "execution_count": 42,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "remote_app.query(query=\"Who influenced Larry Page?\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "vaY6vhFE9egF"
+ },
+ "source": [
+ "### Clean Up\n",
+ "\n",
+ "* Delete the Vertex AI Reasoning Engine Endpoint\n",
+ "* Delete the Spanner instance\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "DA2nh3QW7YcK"
+ },
+ "outputs": [],
+ "source": [
+ "remote_app.delete()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "metadata": {
+ "id": "Qb4FcKf9dS-d"
+ },
+ "outputs": [],
+ "source": [
+ "!gcloud spanner instances delete {SPANNER_INSTANCE_ID} --quiet"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "1120dd06a83f"
+ },
+ "source": [
+ "## What's next\n",
+ "\n",
+ "* Dive deeper into [LangChain on Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/reasoning-engine/overview).\n",
+ "* Learn more about [Google Cloud Spanner](https://cloud.google.com/spanner/docs/getting-started/python).\n",
+ "* Explore other [Reasoning Engine samples](https://github.com/GoogleCloudPlatform/generative-ai/tree/main/gemini/reasoning-engine)."
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "name": "tutorial_graph_rag.ipynb",
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}