Skip to content

Commit

Permalink
Semantic search plugin support
Browse files Browse the repository at this point in the history
  • Loading branch information
toluaina committed Feb 24, 2024
1 parent a74c19f commit 513a296
Show file tree
Hide file tree
Showing 14 changed files with 203 additions and 62 deletions.
10 changes: 5 additions & 5 deletions bin/es_mapping
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import click
from elasticsearch import Elasticsearch, helpers

from pgsync.settings import ELASTICSEARCH_TIMEOUT, ELASTICSEARCH_VERIFY_CERTS
from pgsync.urls import get_elasticsearch_url
from pgsync.urls import get_search_url
from pgsync.utils import config_loader, get_config, timeit

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -53,7 +53,7 @@ def apply_analyzer_to_mapping(mapping: dict, analyzer: dict) -> dict:


def get_configuration(es, index: str) -> dict:
configuration: dict = es.indices.get_settings(index)[index]
configuration: dict = es.indices.get_settings(index=index)[index]
# skip these attributes
for key in [
"uuid",
Expand All @@ -62,7 +62,7 @@ def get_configuration(es, index: str) -> dict:
"creation_date",
]:
configuration["settings"]["index"].pop(key)
mapping: dict = es.indices.get_mapping(index)
mapping: dict = es.indices.get_mapping(index=index)
analyzer_mapping: dict = apply_analyzer_to_mapping(
mapping,
{
Expand All @@ -80,10 +80,10 @@ def get_configuration(es, index: str) -> dict:
@timeit
def create_es_mapping(index: str) -> None:
logger.debug(f"Create Elasticsearch mapping for index {index}")
url: str = get_elasticsearch_url()
url: str = get_search_url()
es: Elasticsearch = Elasticsearch(
hosts=[url],
timeout=ELASTICSEARCH_TIMEOUT,
request_timeout=ELASTICSEARCH_TIMEOUT,
verify_certs=ELASTICSEARCH_VERIFY_CERTS,
)

Expand Down
12 changes: 12 additions & 0 deletions demo/README
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
To run the PGSync demo:

- make sure you have bootstrapped your project and populated your Elasticsearch/Opensearch index
- `cd pgsync`
- `source <path/to/virtualenv>/bin/activate`
- `source .pythonpath`
- `./bin/es_mapping -c /path/to/your/schema.json`
- `cd demo`
- `pip install -r requirements.txt`
- `./runserver.sh`

Then open your browser and go to `http://127.0.0.1:5000/`
5 changes: 3 additions & 2 deletions demo/app/main.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""PGSync Demo application."""

import aiohttp_cors
from aiohttp import web
from app import settings
from app.views import TypeAheadView
from app.views import TypeAheadHandler, TypeAheadView


async def create_app():
Expand All @@ -14,10 +15,10 @@ async def create_app():
)
app.add_routes(
[
web.get("/typeahead", TypeAheadHandler),
web.get("/", TypeAheadView),
]
)

cors = aiohttp_cors.setup(
app,
defaults={
Expand Down
2 changes: 2 additions & 0 deletions demo/app/settings.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""PGSync Demo settings."""

from environs import Env

env = Env()

env.read_env()

MAX_RESULTS = env.int("MAX_RESULTS", default=100)
VECTOR_SEARCH = env.bool("VECTOR_SEARCH", default=False)

ELASTICSEARCH_URL = env.str("ELASTICSEARCH_URL")
ELASTICSEARCH_INDEX = env.str("ELASTICSEARCH_INDEX")
Expand Down
9 changes: 9 additions & 0 deletions demo/app/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from functools import lru_cache

from plugins.openai_plugin import OpenAIPlugin


@lru_cache(maxsize=1024)
def get_embedding(text: str) -> list:
    """Return the OpenAI embedding vector for *text*, memoised across calls.

    The cache is bounded: the inputs are arbitrary user search queries, so a
    bare ``@lru_cache`` (unbounded) would grow without limit over the
    process lifetime.
    """
    # A fresh plugin per cache miss keeps this helper stateless; the cost is
    # presumably dominated by the embeddings API call — confirm if hot.
    plugin: OpenAIPlugin = OpenAIPlugin()
    return plugin.get_embedding(text)
56 changes: 40 additions & 16 deletions demo/app/views.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
"""PGSync Demo views."""

import logging
from copy import deepcopy

from aiohttp import web
from app.settings import (
ELASTICSEARCH_INDEX,
ELASTICSEARCH_TIMEOUT,
ELASTICSEARCH_URL,
ELASTICSEARCH_VERIFY_CERTS,
MAX_RESULTS,
VECTOR_SEARCH,
)
from elasticsearch import Elasticsearch
from elasticsearch_dsl import Search
from elasticsearch_dsl.query import Bool, Match
from elasticsearch_dsl import Q, Search
from elasticsearch_dsl.query import Bool, Match, ScriptScore

from .utils import get_embedding

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -62,8 +67,8 @@ def nested_update(self, obj, key, value):
self.nested_update(o, key, value)


class TypeAheadView(web.View):
"""TypeAheadView view."""
class TypeAheadHandler(web.View):
"""TypeAheadHandler view."""

def _build_queries(self, key, value, search_param, parents=[]):
queries = []
Expand Down Expand Up @@ -137,7 +142,7 @@ def build_highlight(self, mapping):
def queries(self, mapping, search_param):
"""Return matching query by search_param."""
should = self.build_queries(
mapping[ELASTICSEARCH_INDEX]["mappings"]["_doc"]["properties"],
mapping[ELASTICSEARCH_INDEX]["mappings"]["properties"],
search_param.lower(),
)
return Bool(
Expand All @@ -147,27 +152,38 @@ def queries(self, mapping, search_param):

async def get(self):
"""Get the results from Elasticsearch."""
q = self.request.query.get("q")
if not q:
qs: str = self.request.query.get("q")
if not qs:
return web.json_response([])

es = Elasticsearch(
hosts=[self.request.app["settings"].ELASTICSEARCH_URL],
es: Elasticsearch = Elasticsearch(
hosts=ELASTICSEARCH_URL,
timeout=ELASTICSEARCH_TIMEOUT,
verify_certs=ELASTICSEARCH_VERIFY_CERTS,
)
mapping = es.indices.get_mapping(
ELASTICSEARCH_INDEX, include_type_name=True
)
search = Search(index=ELASTICSEARCH_INDEX, using=es)
mapping: dict = es.indices.get_mapping(index=ELASTICSEARCH_INDEX)
search: Search = Search(index=ELASTICSEARCH_INDEX, using=es)
search = search.highlight_options(
pre_tags=[PRE_HIGHLIGHT_TAG],
post_tags=[POST_HIGHLIGHT_TAG],
)
query = self.queries(mapping, q)
search = search.query(query)

if VECTOR_SEARCH:
search = search.query(
ScriptScore(
query=Q("match_all"),
script={
"source": "cosineSimilarity(params.query_vector, \u0027embedding\u0027) + 1.0",
"params": {"query_vector": get_embedding(qs)},
},
)
)
else:
query = self.queries(mapping, qs)
search = search.query(query)

highlights = self.build_highlight(
mapping[ELASTICSEARCH_INDEX]["mappings"]["_doc"]["properties"]
mapping[ELASTICSEARCH_INDEX]["mappings"]["properties"]
)

for highlight in highlights:
Expand All @@ -176,6 +192,7 @@ async def get(self):
search = search.extra(
from_=0,
size=MAX_RESULTS,
_source={"excludes": ["embedding"]},
)

values = []
Expand All @@ -192,3 +209,10 @@ async def get(self):
else:
values.append(hit._d_)
return web.json_response(values)


class TypeAheadView(web.View):
    """Serve the demo's static front-end page."""

    async def get(self):
        """Respond to GET / with the bundled search page."""
        page = "index.html"
        return web.FileResponse(page)
57 changes: 34 additions & 23 deletions demo/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,20 @@
font-family: helvetica, sans-serif;
line-height: 1.4;
}

.card {
background: #fff;
border-radius: 2px;
display: inline-block;
height: 300px;
margin: 1rem;
position: relative;
width: 300px;
background: #fff;
border-radius: 2px;
display: inline-block;
height: 300px;
margin: 1rem;
position: relative;
width: 300px;
}

h2 {
color: white;
text-align:center;
text-align: center;
display: block;
font-size: 1.5em;
margin-top: 1.83em;
Expand All @@ -30,30 +32,39 @@
font-weight: bold;
}
</style>
<link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Roboto:300,400,500,700|Material+Icons">
<link rel="stylesheet" href="https://unpkg.com/[email protected]/dist/css/bootstrap-material-design.min.css" integrity="sha384-wXznGJNEXNG1NFsbm0ugrLFMQPWswR3lds2VeinahP8N0zJw9VWSopbjv2x7WCvX" crossorigin="anonymous">
<link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Roboto:300,400,500,700|Material+Icons">
<link rel="stylesheet"
href="https://unpkg.com/[email protected]/dist/css/bootstrap-material-design.min.css"
integrity="sha384-wXznGJNEXNG1NFsbm0ugrLFMQPWswR3lds2VeinahP8N0zJw9VWSopbjv2x7WCvX" crossorigin="anonymous">
</head>

<body>
<h1>PGSync typeahead demo</h1>

<div class="bmd-form-group bmd-collapse-inline pull-xs-right" style="top: 10%; left: 10%; position: absolute;">
<button class="btn bmd-btn-icon" for="search" data-toggle="collapse" data-target="#collapse-search" aria-expanded="false" aria-controls="collapse-search">
<i class="material-icons">search</i>
</button>
<span id="collapse-search" class="collapse">
<input class="form-control" type="search" id="searchbox" placeholder="Enter your query...">
</span>
<button class="btn bmd-btn-icon" for="search" data-toggle="collapse" data-target="#collapse-search"
aria-expanded="false" aria-controls="collapse-search">
<i class="material-icons">search</i>
</button>
<span id="collapse-search" class="collapse">
<input class="form-control" type="search" id="searchbox" placeholder="Enter your query...">
</span>
</div>

<div id="results" style="top: 20%; left: 10%; position: absolute;"></div>
<script src="https://code.jquery.com/jquery-3.2.1.slim.min.js" integrity="sha384-KJ3o2DKtIkvYIK3UENzmM7KCkRr/rE9/Qpg6aAZGJwFDMVNA/GpGFF93hXpG5KkN" crossorigin="anonymous"></script>
<script src="https://unpkg.com/[email protected]/dist/umd/popper.js" integrity="sha384-fA23ZRQ3G/J53mElWqVJEGJzU0sTs+SvzG8fXVWP+kJQ1lwFAOkcUOysnlKJC33U" crossorigin="anonymous"></script>
<script src="https://unpkg.com/[email protected]/dist/js/bootstrap-material-design.js" integrity="sha384-CauSuKpEqAFajSpkdjv3z9t8E7RlpJ1UP0lKM/+NdtSarroVKu069AlsRPKkFBz9" crossorigin="anonymous"></script>
<script src="https://code.jquery.com/jquery-3.2.1.slim.min.js"
integrity="sha384-KJ3o2DKtIkvYIK3UENzmM7KCkRr/rE9/Qpg6aAZGJwFDMVNA/GpGFF93hXpG5KkN"
crossorigin="anonymous"></script>
<script src="https://unpkg.com/[email protected]/dist/umd/popper.js"
integrity="sha384-fA23ZRQ3G/J53mElWqVJEGJzU0sTs+SvzG8fXVWP+kJQ1lwFAOkcUOysnlKJC33U"
crossorigin="anonymous"></script>
<script src="https://unpkg.com/[email protected]/dist/js/bootstrap-material-design.js"
integrity="sha384-CauSuKpEqAFajSpkdjv3z9t8E7RlpJ1UP0lKM/+NdtSarroVKu069AlsRPKkFBz9"
crossorigin="anonymous"></script>

<script>

const base_url = 'http://localhost:8000/';
const base_url = 'http://localhost:5000/typeahead';
const searchbox = document.getElementById("searchbox");

let requestInFlight = null;
Expand All @@ -71,8 +82,8 @@ <h1>PGSync typeahead demo</h1>
// Avoid race conditions where a slow request returns after a faster one.
return;
}
let results = '<div class="container-fluid">' +
d.map(result => `
let results = '<div class="container-fluid">' +
d.map(result => `
<pre style= "color:gray; font-size: 15px">
${JSON.stringify(result, null, 4)}
</pre><hr>
Expand All @@ -83,7 +94,7 @@ <h1>PGSync typeahead demo</h1>

function debounce(func, wait, immediate) {
let timeout;
return function() {
return function () {
let context = this,
args = arguments;
let later = () => {
Expand Down
2 changes: 1 addition & 1 deletion demo/runserver.sh
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
#! /bin/sh
adev runserver -p 8000 server.py
adev runserver -p 5000 server.py
6 changes: 5 additions & 1 deletion examples/book/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
"mapping": {
"authors": {
"type": "nested"
},
"embedding": {
"type": "dense_vector",
"dims": 1536
}
},
"setting": {
Expand Down Expand Up @@ -33,7 +37,7 @@
}
}
},
"plugins": ["Groot", "Hero", "Villain", "Geometry", "Infinity"],
"plugins": ["Groot", "Hero", "Villain", "Geometry", "Infinity", "TextEmbedding3Small"],
"nodes": {
"table": "book",
"columns": [
Expand Down
1 change: 1 addition & 0 deletions pgsync/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
"constant_keyword",
"date",
"date_range",
"dense_vector",
"double",
"double_range",
"flattened",
Expand Down
41 changes: 41 additions & 0 deletions plugins/openai_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from functools import lru_cache

from openai import OpenAI

from pgsync import plugin


class OpenAIPlugin(plugin.Plugin):
    """OpenAI embedding plugin.

    Generates document embeddings with OpenAI's ``text-embedding-3-small``
    model. Requires ``pip install openai`` and an ``OPENAI_API_KEY`` in the
    environment (read by the OpenAI client).
    """

    # Plugin name referenced from the schema's "plugins" list.
    name: str = "TextEmbedding3Small"

    def __init__(self) -> None:
        super().__init__()
        self.client: OpenAI = OpenAI()
        self.model: str = "text-embedding-3-small"
        # Must match the "dims" of the dense_vector mapping in the schema.
        self.vector_dims: int = 1536
        # Per-instance memo of text -> embedding. A functools.lru_cache on
        # the bound method would key on `self` and keep this instance (and
        # its client) alive for the cache's lifetime (ruff B019), and would
        # also be unbounded; an explicit dict scoped to the instance avoids
        # both problems.
        self._cache: dict = {}

    def get_embedding(self, text: str) -> list:
        """Return the embedding vector for *text*, memoised per instance."""
        # Newlines can degrade embedding quality; normalise to spaces.
        text = text.replace("\n", " ")
        if text not in self._cache:
            self._cache[text] = (
                self.client.embeddings.create(input=[text], model=self.model)
                .data[0]
                .embedding
            )
        return self._cache[text]

    def transform(self, doc: dict, **kwargs) -> dict:
        """Add an OpenAI embedding of the book title to the document.

        Demonstrates how to generate OpenAI embeddings and attach them to a
        document; raises ValueError if the model returns a vector whose
        length does not match the configured mapping dims.
        """
        fields = doc["book_title"]
        embedding: list = self.get_embedding(fields)
        if len(embedding) != self.vector_dims:
            raise ValueError(f"Embedding dims != {self.vector_dims}.")

        doc["embedding"] = embedding
        return doc
Loading

0 comments on commit 513a296

Please sign in to comment.