Commit 4d69ae1: test

truskovskiyk committed Dec 1, 2024
1 parent 5074193 commit 4d69ae1

Showing 3 changed files with 39 additions and 36 deletions.

ai-search-demo/README.md (2 changes: 1 addition & 1 deletion)

@@ -66,7 +66,7 @@ docker run -p 6333:6333 -p 6334:6334 -v $(pwd)/qdrant_storage:/qdrant/storage:z
 
 ```
 docker build -t smart-hr-ai-search:latest .
-docker run -p 6333:6333 -p 6334:6334 -v $(pwd)/qdrant_storage:/qdrant/storage:z qdrant/qdrant
+docker run -p 8000:8000 -v $(pwd)/app_storage:/storage smart-hr-ai-search:latest
 ```
 
 ## References:
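
The run instructions now start the application container on port 8000 next to the Qdrant container from the earlier command. A quick sanity check that both services came up, as a minimal Python sketch (the localhost ports come from the commands above; Qdrant's /healthz endpoint is an assumption about the deployed version):

```
import requests

# Assumed endpoints: Qdrant's HTTP health check and the app's root page.
qdrant_resp = requests.get("http://localhost:6333/healthz", timeout=5)
app_resp = requests.get("http://localhost:8000", timeout=5)
print(f"qdrant: {qdrant_resp.status_code}, app: {app_resp.status_code}")
```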

ai-search-demo/ai_search_demo/qdrant_inexing.py (36 changes: 19 additions & 17 deletions)

@@ -4,16 +4,20 @@
 from pdf2image import convert_from_path
 from pypdf import PdfReader
 import io
-
-
-
-
-
-
 import requests
 
+# Constants
+COLPALI_BASE_URL = "https://truskovskiyk--colpali-embedding-serve.modal.run"
+COLPALI_TOKEN = "super-secret-token"
+QDRANT_URI = "https://qdrant.up.railway.app"
+QDRANT_PORT = 443
+VECTOR_SIZE = 128
+INDEXING_THRESHOLD = 100
+QUANTILE = 0.99
+TOP_K = 5
+
 class ColPaliClient:
-    def __init__(self, base_url: str = "https://truskovskiyk--colpali-embedding-serve.modal.run", token: str = "super-secret-token"):
+    def __init__(self, base_url: str = COLPALI_BASE_URL, token: str = COLPALI_TOKEN):
         self.base_url = base_url
         self.headers = {"Authorization": f"Bearer {token}"}
 
@@ -51,29 +55,27 @@ def process_pil_image(self, pil_image):
 
 
 class IngestClient:
-    def __init__(self, qdrant_uri: str = "https://qdrant.up.railway.app"):
-        self.qdrant_client = QdrantClient(qdrant_uri, port=443, https=True)
+    def __init__(self, qdrant_uri: str = QDRANT_URI):
+        self.qdrant_client = QdrantClient(qdrant_uri, port=QDRANT_PORT, https=True)
         self.colpali_client = ColPaliClient()
 
     def ingest(self, collection_name, dataset):
-        vector_size = 128
-
         self.qdrant_client.create_collection(
             collection_name=collection_name,
             on_disk_payload=True,
             optimizers_config=models.OptimizersConfigDiff(
-                indexing_threshold=100
+                indexing_threshold=INDEXING_THRESHOLD
             ),
             vectors_config=models.VectorParams(
-                size=vector_size,
+                size=VECTOR_SIZE,
                 distance=models.Distance.COSINE,
                 multivector_config=models.MultiVectorConfig(
                     comparator=models.MultiVectorComparator.MAX_SIM
                 ),
                 quantization_config=models.ScalarQuantization(
                     scalar=models.ScalarQuantizationConfig(
                         type=models.ScalarType.INT8,
-                        quantile=0.99,
+                        quantile=QUANTILE,
                         always_ram=True,
                     ),
                 ),
@@ -119,11 +121,11 @@ def ingest(self, collection_name, dataset):
         print("Indexing complete!")
 
 class SearchClient:
-    def __init__(self, qdrant_uri: str = "https://qdrant.up.railway.app/"):
-        self.qdrant_client = QdrantClient(qdrant_uri, port=443, https=True)
+    def __init__(self, qdrant_uri: str = QDRANT_URI):
+        self.qdrant_client = QdrantClient(qdrant_uri, port=QDRANT_PORT, https=True)
         self.colpali_client = ColPaliClient()
 
-    def search_images_by_text(self, query_text, collection_name: str, top_k=5):
+    def search_images_by_text(self, query_text, collection_name: str, top_k=TOP_K):
         # Use ColPaliClient to query text and get the embedding
         query_embedding = self.colpali_client.query_text(query_text)
 
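
The changes in this file gather the endpoint, credential, and tuning literals into one constants block, so IngestClient and SearchClient no longer embed their own copies. A minimal usage sketch of the refactored clients under the new defaults (the collection name and query text below are hypothetical):

```
from ai_search_demo.qdrant_inexing import SearchClient

# SearchClient() now falls back to QDRANT_URI / QDRANT_PORT, and
# search_images_by_text() falls back to TOP_K = 5.
search_client = SearchClient()
results = search_client.search_images_by_text(
    "What is the vacation policy?",   # hypothetical query
    collection_name="hr-documents",   # hypothetical collection
)
for point in results.points:          # same shape ui.py iterates over
    print(point.score, point.payload)
```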

ai-search-demo/ai_search_demo/ui.py (37 changes: 19 additions & 18 deletions)

@@ -2,12 +2,17 @@
 import os
 import json
 import pandas as pd
-from PIL import Image
 import threading
 from ai_search_demo.qdrant_inexing import IngestClient, pdfs_to_hf_dataset
 from ai_search_demo.qdrant_inexing import SearchClient, IngestClient
 from datasets import load_from_disk
 
+STORAGE_DIR = "storage"
+COLLECTION_INFO_FILENAME = "collection_info.json"
+HF_DATASET_DIRNAME = "hf_dataset"
+README_FILENAME = "README.md"
+SEARCH_TOP_K = 5
+
 search_client = SearchClient()
 ingest_client = IngestClient()
 
@@ -18,16 +23,16 @@ def ai_search():
     # Input form for user query and collection name
     with st.form("search_form"):
         user_query = st.text_input("Enter your search query")
-        if os.path.exists("storage"):
-            collections = os.listdir("storage")
+        if os.path.exists(STORAGE_DIR):
+            collections = os.listdir(STORAGE_DIR)
             collection_name = st.selectbox("Select a collection", collections)
         else:
             st.error("No collections found.")
             collection_name = None
         search_button = st.form_submit_button("Search")
 
     if search_button and user_query and collection_name:
-        collection_info_path = os.path.join("storage", collection_name, "collection_info.json")
+        collection_info_path = os.path.join(STORAGE_DIR, collection_name, COLLECTION_INFO_FILENAME)
 
         if os.path.exists(collection_info_path):
             with open(collection_info_path, "r") as json_file:
@@ -36,22 +41,18 @@ def ai_search():
             # Here you would implement the actual search logic
             # For now, we just display a placeholder message
             st.write(f"Results for query '{user_query}' in collection '{collection_name}':")
-            search_results = search_client.search_images_by_text(user_query, collection_name=collection_name, top_k=5)
+            search_results = search_client.search_images_by_text(user_query, collection_name=collection_name, top_k=SEARCH_TOP_K)
             if search_results:
-                dataset_path = os.path.join("storage", collection_name, "hf_dataset")
+                dataset_path = os.path.join(STORAGE_DIR, collection_name, HF_DATASET_DIRNAME)
                 dataset = load_from_disk(dataset_path)
                 for result in search_results.points:
                     payload = result.payload
                     score = result.score
                     image_data = dataset[payload['index']]['image']
                     pdf_name = dataset[payload['index']]['pdf_name']
                     pdf_page = dataset[payload['index']]['pdf_page']
                     page_text = dataset[payload['index']]['page_text']
-
                     # Display the extracted information in the UI
                     st.image(image_data, caption=f"Score: {score}, PDF Name: {pdf_name}, Page: {pdf_page}")
                     st.write(f"Page Text: {page_text}")
-
             else:
                 st.write("No results found.")
 
@@ -66,7 +67,7 @@ def create_new_collection():
 
     if submit_button and uploaded_files and collection_name:
         # Create a directory for the collection
-        collection_dir = f"storage/{collection_name}"
+        collection_dir = f"{STORAGE_DIR}/{collection_name}"
         os.makedirs(collection_dir, exist_ok=True)
 
         # Save PDFs to the collection directory
@@ -80,7 +81,7 @@
             "status": "processing",
             "number_of_PDFs": len(uploaded_files)
         }
-        with open(os.path.join(collection_dir, "collection_info.json"), "w") as json_file:
+        with open(os.path.join(collection_dir, COLLECTION_INFO_FILENAME), "w") as json_file:
             json.dump(collection_info, json_file)
 
         st.success(f"Uploaded {len(uploaded_files)} PDFs to collection '{collection_name}'")
@@ -89,15 +90,15 @@ def process_and_ingest():
         def process_and_ingest():
             # Transform PDFs to HF dataset
             dataset = pdfs_to_hf_dataset(collection_dir)
-            dataset.save_to_disk(os.path.join(collection_dir, "hf_dataset"))
+            dataset.save_to_disk(os.path.join(collection_dir, HF_DATASET_DIRNAME))
 
             # Ingest collection with IngestClient
             ingest_client = IngestClient()
             ingest_client.ingest(collection_name, dataset)
 
             # Update JSON status to 'done'
             collection_info['status'] = 'done'
-            with open(os.path.join(collection_dir, "collection_info.json"), "w") as json_file:
+            with open(os.path.join(collection_dir, COLLECTION_INFO_FILENAME), "w") as json_file:
                 json.dump(collection_info, json_file)
 
         # Run the processing and ingestion in a separate thread
@@ -106,12 +107,12 @@ def process_and_ingest():
 def display_all_collections():
     st.header("Previously Uploaded Collections")
 
-    if os.path.exists("storage"):
-        collections = os.listdir("storage")
+    if os.path.exists(STORAGE_DIR):
+        collections = os.listdir(STORAGE_DIR)
         collection_data = []
 
         for collection in collections:
-            collection_info_path = os.path.join("storage", collection, "collection_info.json")
+            collection_info_path = os.path.join(STORAGE_DIR, collection, COLLECTION_INFO_FILENAME)
             if os.path.exists(collection_info_path):
                 with open(collection_info_path, "r") as json_file:
                     collection_info = json.load(json_file)
@@ -126,7 +127,7 @@ def display_all_collections():
         st.write("No collections found.")
 
 def about():
-    with open("README.md", "r") as readme_file:
+    with open(README_FILENAME, "r") as readme_file:
         readme_content = readme_file.read()
     st.markdown(readme_content)
 
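
The ui.py changes route every filesystem path through the new constants, which keeps the upload flow and the background worker pointing at the same files: create_new_collection() writes collection_info.json with status "processing", starts process_and_ingest() on a thread, and that thread flips the status to "done" when ingestion finishes. A condensed, runnable sketch of that status-file handshake (the collection name and PDF count are hypothetical; the real worker also builds the HF dataset and calls IngestClient.ingest):

```
import json
import os
import threading

STORAGE_DIR = "storage"
COLLECTION_INFO_FILENAME = "collection_info.json"

def process_and_ingest(collection_dir, collection_info):
    # Stand-in for the real work: pdfs_to_hf_dataset() + IngestClient.ingest().
    collection_info["status"] = "done"
    with open(os.path.join(collection_dir, COLLECTION_INFO_FILENAME), "w") as f:
        json.dump(collection_info, f)

collection_dir = os.path.join(STORAGE_DIR, "hr-documents")  # hypothetical
os.makedirs(collection_dir, exist_ok=True)
info = {"status": "processing", "number_of_PDFs": 3}

# Started without join(), as in ui.py, so the Streamlit handler returns
# immediately while ingestion continues in the background.
threading.Thread(target=process_and_ingest, args=(collection_dir, info)).start()
```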
