-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
75f48cf
commit 5b4ef62
Showing
18 changed files
with
737 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
FROM pytorch/pytorch:2.5.1-cuda12.1-cudnn9-devel | ||
ARG DEBIAN_FRONTEND=noninteractive | ||
|
||
ENV TZ=America/Los_Angeles | ||
|
||
RUN apt-get update && apt-get install -y --no-install-recommends \ | ||
build-essential \ | ||
ffmpeg \ | ||
git \ | ||
git-lfs | ||
|
||
RUN pip install --upgrade pip | ||
COPY requirements.txt requirements.txt | ||
RUN pip install -r requirements.txt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# ML Stages | ||
|
||
![alt text](./docs/end2end.jpg) | ||
Based on this design document: [SmartHR LLM Workshop, December 2nd, 2024](https://docs.google.com/document/d/1xvJiWK8nvtC8Ek34X5Uk1BCCa_GvYqej-8tjk8o8obE/edit?tab=t.0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# 機械学習のステージ | ||
|
||
![alt text](./docs/end2end.jpg) | ||
この設計文書に基づいています: [SmartHR LLM ワークショップ、2024年12月2日](https://docs.google.com/document/d/1xvJiWK8nvtC8Ek34X5Uk1BCCa_GvYqej-8tjk8o8obE/edit?tab=t.0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import streamlit as st | ||
import duckdb | ||
from openai import OpenAI | ||
|
||
|
||
from pydantic import BaseModel | ||
from traceloop.sdk import Traceloop | ||
from traceloop.sdk.decorators import workflow | ||
|
||
Traceloop.init(app_name="duckdb-project", disable_batch=True) | ||
|
||
class Text2SQLSample(BaseModel): | ||
query: str | ||
|
||
# @observe | ||
@workflow(name="text2sql") | ||
def text2sql(user_prompt: str, schema: str, table_name: str) -> str: | ||
model: str = "gpt-4o" | ||
client = OpenAI() | ||
|
||
prompt = f""" | ||
Write the corresponding SQL query based on user prompt and database schema: | ||
- user prompt: {user_prompt} | ||
- database schema: {schema} | ||
Return only JSON. | ||
Table name is {table_name} | ||
""" | ||
|
||
chat_completion = client.beta.chat.completions.parse( | ||
messages=[ | ||
{ | ||
"role": "system", | ||
"content": "You are DuckDB and SQL expert.", | ||
}, | ||
{ | ||
"role": "user", | ||
"content": prompt, | ||
}, | ||
], | ||
model=model, | ||
response_format=Text2SQLSample, | ||
temperature=1, | ||
) | ||
query = chat_completion.choices[0].message.parsed.query | ||
return query | ||
|
||
# Initialize DuckDB connection | ||
conn = duckdb.connect() | ||
|
||
# Load the httpfs extension for hf:// paths | ||
conn.execute("INSTALL httpfs;") | ||
conn.execute("LOAD httpfs;") | ||
|
||
# Title | ||
st.title("Hugging Face Dataset Query with DuckDB") | ||
|
||
# Input for Hugging Face dataset link | ||
hf_link = st.text_input( | ||
"Enter the Hugging Face dataset link (hf://...):", | ||
"hf://datasets/UCSC-VLAA/Recap-DataComp-1B/data/train_data/train-00000-of-02719.parquet", | ||
) | ||
|
||
# Input for NLP query | ||
nlp_query = st.text_area("Enter your query in natural language:", "Show me the first 10 rows.") | ||
|
||
# Execute query when button is clicked | ||
if st.button("Run Query"): | ||
|
||
try: | ||
# Use DESCRIBE to get schema | ||
schema_query = f"DESCRIBE SELECT * FROM '{hf_link}'" | ||
schema_df = conn.execute(schema_query).df() | ||
st.write("### Dataset Schema:") | ||
st.dataframe(schema_df) | ||
|
||
# Convert schema_df to string format suitable for text2sql | ||
schema_str = "\n".join([f"{row['column_name']} ({row['column_type']})" for index, row in schema_df.iterrows()]) | ||
|
||
except Exception as e: | ||
st.error(f"An error occurred while fetching schema: {e}") | ||
schema_str = "" # Ensure schema_str is defined | ||
|
||
try: | ||
# Convert NLP query to SQL using text2sql | ||
sql_query = text2sql(nlp_query, schema_str, table_name=hf_link) | ||
st.write("### Generated SQL Query:") | ||
st.code(sql_query, language='sql') | ||
|
||
# sql_query = sql_query.replace("table_name", hf_link) | ||
# Execute the SQL query | ||
df = conn.execute(sql_query).df() | ||
# Display the results | ||
st.write("### Query Results:") | ||
st.dataframe(df) | ||
except Exception as e: | ||
st.error(f"An error occurred: {e}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
from typing import Dict, List | ||
|
||
import argilla as rg | ||
import typer | ||
from datasets import Dataset, load_dataset | ||
from openai import OpenAI | ||
from pydantic import BaseModel | ||
from retry import retry | ||
from tqdm import tqdm | ||
|
||
|
||
class Text2SQLSample(BaseModel): | ||
prompt: str | ||
schema: str | ||
query: str | ||
|
||
client = rg.Argilla(api_url="http://0.0.0.0:6900", api_key="argilla.apikey") | ||
# argilla / 12345678 / argilla.apikey | ||
|
||
def create_text2sql_dataset(dataset_name: str, data_samples: Dataset | List[Dict]): | ||
guidelines = """ | ||
Please examine the given DuckDB SQL question and context. | ||
Write the correct DuckDB SQL query that accurately answers the question based on the context provided. | ||
Ensure the query follows DuckDB SQL syntax and logic correctly. | ||
""" | ||
|
||
settings = rg.Settings( | ||
guidelines=guidelines, | ||
fields=[ | ||
rg.TextField( | ||
name="prompt", | ||
title="Prompt", | ||
use_markdown=False, | ||
), | ||
rg.TextField( | ||
name="schema", | ||
title="Schema", | ||
use_markdown=True, | ||
), | ||
rg.TextField( | ||
name="query", | ||
title="Query", | ||
use_markdown=True, | ||
), | ||
], | ||
questions=[ | ||
rg.TextQuestion( | ||
name="sql", | ||
title="Please write SQL for this query", | ||
description="Please write SQL for this query", | ||
required=True, | ||
use_markdown=True, | ||
) | ||
], | ||
) | ||
|
||
if 'admin' not in [x.name for x in client.workspaces.list()]: | ||
workspace = rg.Workspace(name="admin") | ||
workspace.create() | ||
|
||
dataset = rg.Dataset( | ||
name=dataset_name, | ||
workspace="admin", | ||
settings=settings, | ||
client=client, | ||
) | ||
dataset.create() | ||
records = [] | ||
for idx in range(len(data_samples)): | ||
x = rg.Record( | ||
fields={ | ||
"prompt": data_samples[idx]["prompt"], | ||
"schema": data_samples[idx]["schema"], | ||
"query": data_samples[idx]["query"], | ||
}, | ||
) | ||
records.append(x) | ||
dataset.records.log(records, batch_size=1000) | ||
|
||
def upload_duckdb_text2sql(): | ||
dataset_name = "motherduckdb/duckdb-text2sql-25k" | ||
raw_dataset = load_dataset(dataset_name, split="train") | ||
raw_datasets = raw_dataset.train_test_split(test_size=0.05, seed=42) | ||
|
||
raw_dataset.to_json(path_or_buf='./data/duckdb-text2sql-25k.json') | ||
raw_datasets['train'].to_json(path_or_buf='./data/train.json') | ||
raw_datasets['test'].to_json(path_or_buf='./data/test.json') | ||
|
||
create_text2sql_dataset(dataset_name='duckdb-text2sql-train', data_samples=raw_datasets['train'].to_list()) | ||
create_text2sql_dataset(dataset_name='duckdb-text2sql-test', data_samples=raw_datasets['test'].to_list()) | ||
|
||
|
||
@retry(tries=3, delay=1) | ||
def generate_synthetic_example() -> Dict[str, str]: | ||
client = OpenAI() | ||
|
||
prompt = """ | ||
Generate a example for text2sql task for DuckDB database: | ||
The example should include | ||
- schema: a valid database schema | ||
- prompt: a typical user question related to this table prompt | ||
- query: the corresponding SQL query to answer user prompt. | ||
Return only JSON. Use Japanese language for prompt (user question). | ||
""" | ||
|
||
chat_completion = client.beta.chat.completions.parse( | ||
messages=[ | ||
{ | ||
"role": "system", | ||
"content": "You are DuckDB and SQL expert.", | ||
}, | ||
{ | ||
"role": "user", | ||
"content": prompt, | ||
}, | ||
], | ||
model="gpt-4o", | ||
response_format=Text2SQLSample, | ||
temperature=1, | ||
) | ||
sample = chat_completion.choices[0].message.parsed | ||
return sample.model_dump() | ||
|
||
|
||
def create_text2sql_dataset_synthetic(num_samples: int = 10): | ||
|
||
|
||
samples = [] | ||
for _ in tqdm(range(num_samples)): | ||
sample = generate_synthetic_example() | ||
samples.append(sample) | ||
|
||
dataset_name = "duckdb-text2sql-synthetic-jp" | ||
create_text2sql_dataset(dataset_name=dataset_name, data_samples=samples) | ||
|
||
if __name__ == "__main__": | ||
app = typer.Typer() | ||
app.command()(upload_duckdb_text2sql) | ||
app.command()(create_text2sql_dataset_synthetic) | ||
app() |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Oops, something went wrong.