Commit: update

truskovskiyk committed Dec 2, 2024
1 parent 75f48cf commit 5b4ef62
Showing 18 changed files with 737 additions and 5 deletions.
7 changes: 7 additions & 0 deletions .gitignore
@@ -160,3 +160,10 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.DS_Store
README.p.md
ai-search-demo/.DS_Store
ai-search-demo/storage/
ai-search-demo/example_data/
ml-stages/data/
ml-stages/script/
15 changes: 14 additions & 1 deletion ai-search-demo/README.en.md
@@ -63,6 +63,11 @@ python ai_search_demo/evaluate_synthetic_data.py create-synthetic-dataset ./exam
python ai_search_demo/evaluate_synthetic_data.py evaluate-on-synthetic-dataset koml/smart-hr-synthetic-data-single-image-multiple-queries --collection-name smart-hr-synthetic-data-single-image-multiple-queries
```

## Demo data

- [SmartHR](https://smarthr.jp/know-how/ebook/tv-campaign/)
- VC Reports: [InfraRed](https://www.redpoint.com/infrared/report/) & [2024: The State of Generative AI in the Enterprise](https://menlovc.com/2024-the-state-of-generative-ai-in-the-enterprise/)

## Architecture

High-level diagram of the system.
@@ -135,6 +140,14 @@ sequenceDiagram

## LLM inference


Setup

```
pip install modal
modal setup
```

Download models

```
@@ -147,6 +160,6 @@ modal run llm-inference/llm_serving_load_models.py --model-name vidore/colqwen2-
Deploy models

```
modal deploy llm-inference/llm_serving.py
modal deploy llm-inference/llm_serving_colpali.py
```
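Once deployed, Modal prints a public URL for each served endpoint. Assuming `llm_serving.py` exposes a vLLM OpenAI-compatible server (the exact URL depends on your Modal workspace and app name, so treat this as an illustrative smoke test):

```
curl https://<your-workspace>--<app-name>-serve.modal.run/v1/models
```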
10 changes: 9 additions & 1 deletion ai-search-demo/README.md
@@ -134,6 +134,14 @@ sequenceDiagram

## LLM inference

Setup

```
pip install modal
modal setup
```


Download models

```
@@ -146,6 +154,6 @@ modal run llm-inference/llm_serving_load_models.py --model-name vidore/colqwen2-
Deploy models

```
modal deploy llm-inference/llm_serving.py
modal deploy llm-inference/llm_serving_colpali.py
```
2 changes: 1 addition & 1 deletion ai-search-demo/ai_search_demo/ui.py
@@ -107,7 +107,7 @@ def ai_search():

    # Display VLLM output in the second column
    with col2:
-       st.markdown("<h3 style='color:green;'>Relevant Images</h3>", unsafe_allow_html=True)
+       st.markdown("<h3 style='color:green;'>Interpretation with LLM</h3>", unsafe_allow_html=True)
        for image_data, _, _, _ in search_results_data:
            with st.spinner("Processing with VLLM..."):
                vllm_output = call_vllm(image_data)
1 change: 1 addition & 0 deletions ai-search-demo/llm-inference/llm_serving.py
@@ -28,6 +28,7 @@
@app.function(
    image=vllm_image,
    gpu=modal.gpu.H100(count=N_GPU),
    keep_warm=1,
    container_idle_timeout=5 * MINUTES,
    timeout=24 * HOURS,
    allow_concurrent_inputs=1000,
1 change: 1 addition & 0 deletions ai-search-demo/llm-inference/llm_serving_colpali.py
@@ -27,6 +27,7 @@
@app.function(
    image=vllm_image,
    gpu=modal.gpu.H100(count=N_GPU),
    keep_warm=1,
    container_idle_timeout=5 * MINUTES,
    timeout=24 * HOURS,
    allow_concurrent_inputs=1000,
2 changes: 0 additions & 2 deletions ai-search-demo/llm-inference/llm_serving_load_models.py
@@ -5,8 +5,6 @@
DEFAULT_NAME = "Qwen/Qwen2.5-7B-Instruct"
DEFAULT_REVISION = "bb46c15ee4bb56c5b63245ef50fd7637234d6f75"

# Qwen/Qwen2-VL-7B-Instruct


volume = modal.Volume.from_name("models", create_if_missing=True)

14 changes: 14 additions & 0 deletions ml-stages/Dockerfile
@@ -0,0 +1,14 @@
FROM pytorch/pytorch:2.5.1-cuda12.1-cudnn9-devel
ARG DEBIAN_FRONTEND=noninteractive

ENV TZ=America/Los_Angeles

RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    ffmpeg \
    git \
    git-lfs

RUN pip install --upgrade pip
COPY requirements.txt requirements.txt
RUN pip install -r requirements.txt
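A typical build-and-run sequence for this image might look like the following (the `ml-stages` tag and the bind mount are illustrative, not part of the repo):

```
docker build -t ml-stages ml-stages/
docker run --gpus all --rm -it -v "$(pwd)/ml-stages:/app" -w /app ml-stages bash
```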
4 changes: 4 additions & 0 deletions ml-stages/README.en.md
@@ -0,0 +1,4 @@
# ML Stages

![alt text](./docs/end2end.jpg)
Based on this design document: [SmartHR LLM Workshop, December 2nd, 2024](https://docs.google.com/document/d/1xvJiWK8nvtC8Ek34X5Uk1BCCa_GvYqej-8tjk8o8obE/edit?tab=t.0)
4 changes: 4 additions & 0 deletions ml-stages/README.md
@@ -0,0 +1,4 @@
# ML Stages

![alt text](./docs/end2end.jpg)
Based on this design document: [SmartHR LLM Workshop, December 2nd, 2024](https://docs.google.com/document/d/1xvJiWK8nvtC8Ek34X5Uk1BCCa_GvYqej-8tjk8o8obE/edit?tab=t.0)
98 changes: 98 additions & 0 deletions ml-stages/app.py
@@ -0,0 +1,98 @@
import streamlit as st
import duckdb
from openai import OpenAI

from pydantic import BaseModel
from traceloop.sdk import Traceloop
from traceloop.sdk.decorators import workflow

Traceloop.init(app_name="duckdb-project", disable_batch=True)


class Text2SQLSample(BaseModel):
    query: str


# @observe
@workflow(name="text2sql")
def text2sql(user_prompt: str, schema: str, table_name: str) -> str:
    model: str = "gpt-4o"
    client = OpenAI()

    prompt = f"""
    Write the corresponding SQL query based on the user prompt and database schema:
    - user prompt: {user_prompt}
    - database schema: {schema}
    Return only JSON.
    The table name is {table_name}
    """

    chat_completion = client.beta.chat.completions.parse(
        messages=[
            {
                "role": "system",
                "content": "You are a DuckDB and SQL expert.",
            },
            {
                "role": "user",
                "content": prompt,
            },
        ],
        model=model,
        response_format=Text2SQLSample,
        temperature=1,
    )
    query = chat_completion.choices[0].message.parsed.query
    return query


# Initialize DuckDB connection
conn = duckdb.connect()

# Load the httpfs extension for hf:// paths
conn.execute("INSTALL httpfs;")
conn.execute("LOAD httpfs;")

# Title
st.title("Hugging Face Dataset Query with DuckDB")

# Input for Hugging Face dataset link
hf_link = st.text_input(
    "Enter the Hugging Face dataset link (hf://...):",
    "hf://datasets/UCSC-VLAA/Recap-DataComp-1B/data/train_data/train-00000-of-02719.parquet",
)

# Input for NLP query
nlp_query = st.text_area("Enter your query in natural language:", "Show me the first 10 rows.")

# Execute query when button is clicked
if st.button("Run Query"):
    try:
        # Use DESCRIBE to get the schema
        schema_query = f"DESCRIBE SELECT * FROM '{hf_link}'"
        schema_df = conn.execute(schema_query).df()
        st.write("### Dataset Schema:")
        st.dataframe(schema_df)

        # Convert schema_df to a string format suitable for text2sql
        schema_str = "\n".join([f"{row['column_name']} ({row['column_type']})" for index, row in schema_df.iterrows()])
    except Exception as e:
        st.error(f"An error occurred while fetching schema: {e}")
        schema_str = ""  # Ensure schema_str is defined

    try:
        # Convert the NLP query to SQL using text2sql
        sql_query = text2sql(nlp_query, schema_str, table_name=hf_link)
        st.write("### Generated SQL Query:")
        st.code(sql_query, language='sql')

        # sql_query = sql_query.replace("table_name", hf_link)
        # Execute the SQL query
        df = conn.execute(sql_query).df()
        # Display the results
        st.write("### Query Results:")
        st.dataframe(df)
    except Exception as e:
        st.error(f"An error occurred: {e}")
140 changes: 140 additions & 0 deletions ml-stages/data.py
@@ -0,0 +1,140 @@
from typing import Dict, List

import argilla as rg
import typer
from datasets import Dataset, load_dataset
from openai import OpenAI
from pydantic import BaseModel
from retry import retry
from tqdm import tqdm


class Text2SQLSample(BaseModel):
    prompt: str
    schema: str
    query: str


client = rg.Argilla(api_url="http://0.0.0.0:6900", api_key="argilla.apikey")
# argilla / 12345678 / argilla.apikey


def create_text2sql_dataset(dataset_name: str, data_samples: Dataset | List[Dict]):
    guidelines = """
    Please examine the given DuckDB SQL question and context.
    Write the correct DuckDB SQL query that accurately answers the question based on the context provided.
    Ensure the query follows DuckDB SQL syntax and logic correctly.
    """

    settings = rg.Settings(
        guidelines=guidelines,
        fields=[
            rg.TextField(
                name="prompt",
                title="Prompt",
                use_markdown=False,
            ),
            rg.TextField(
                name="schema",
                title="Schema",
                use_markdown=True,
            ),
            rg.TextField(
                name="query",
                title="Query",
                use_markdown=True,
            ),
        ],
        questions=[
            rg.TextQuestion(
                name="sql",
                title="Please write SQL for this query",
                description="Please write SQL for this query",
                required=True,
                use_markdown=True,
            )
        ],
    )

    if 'admin' not in [x.name for x in client.workspaces.list()]:
        workspace = rg.Workspace(name="admin")
        workspace.create()

    dataset = rg.Dataset(
        name=dataset_name,
        workspace="admin",
        settings=settings,
        client=client,
    )
    dataset.create()
    records = []
    for idx in range(len(data_samples)):
        x = rg.Record(
            fields={
                "prompt": data_samples[idx]["prompt"],
                "schema": data_samples[idx]["schema"],
                "query": data_samples[idx]["query"],
            },
        )
        records.append(x)
    dataset.records.log(records, batch_size=1000)


def upload_duckdb_text2sql():
    dataset_name = "motherduckdb/duckdb-text2sql-25k"
    raw_dataset = load_dataset(dataset_name, split="train")
    raw_datasets = raw_dataset.train_test_split(test_size=0.05, seed=42)

    raw_dataset.to_json(path_or_buf='./data/duckdb-text2sql-25k.json')
    raw_datasets['train'].to_json(path_or_buf='./data/train.json')
    raw_datasets['test'].to_json(path_or_buf='./data/test.json')

    create_text2sql_dataset(dataset_name='duckdb-text2sql-train', data_samples=raw_datasets['train'].to_list())
    create_text2sql_dataset(dataset_name='duckdb-text2sql-test', data_samples=raw_datasets['test'].to_list())


@retry(tries=3, delay=1)
def generate_synthetic_example() -> Dict[str, str]:
    client = OpenAI()

    prompt = """
    Generate an example for the text2sql task for a DuckDB database:
    The example should include
    - schema: a valid database schema
    - prompt: a typical user question related to this table
    - query: the corresponding SQL query to answer the user prompt.
    Return only JSON. Use Japanese for the prompt (user question).
    """

    chat_completion = client.beta.chat.completions.parse(
        messages=[
            {
                "role": "system",
                "content": "You are a DuckDB and SQL expert.",
            },
            {
                "role": "user",
                "content": prompt,
            },
        ],
        model="gpt-4o",
        response_format=Text2SQLSample,
        temperature=1,
    )
    sample = chat_completion.choices[0].message.parsed
    return sample.model_dump()


def create_text2sql_dataset_synthetic(num_samples: int = 10):
    samples = []
    for _ in tqdm(range(num_samples)):
        sample = generate_synthetic_example()
        samples.append(sample)

    dataset_name = "duckdb-text2sql-synthetic-jp"
    create_text2sql_dataset(dataset_name=dataset_name, data_samples=samples)


if __name__ == "__main__":
    app = typer.Typer()
    app.command()(upload_duckdb_text2sql)
    app.command()(create_text2sql_dataset_synthetic)
    app()
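Because the two functions are registered as Typer commands, they should be invokable from the CLI roughly as follows (Typer converts underscores in function names to hyphens; this assumes the file is `ml-stages/data.py` and an Argilla server is reachable at `http://0.0.0.0:6900`):

```
python ml-stages/data.py upload-duckdb-text2sql
python ml-stages/data.py create-text2sql-dataset-synthetic --num-samples 20
```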
Binary file added ml-stages/docs/end2end.jpg
