Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
truskovskiyk committed Apr 21, 2024
1 parent cc34fac commit 118e3ec
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 56 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ docker run -it --gpus '"device=1"' --ipc=host --net=host -v $PWD:/app fine-tune-

modal token set --token-id <your-modal-token-id> --token-secret <your-modal-token-secret>  # NOTE: real credentials were previously committed here — rotate them

https://modal.com/docs/reference/modal.config

export MODAL_TOKEN_ID=<your-modal-token-id>
export MODAL_TOKEN_SECRET=<your-modal-token-secret>

modal run
modal run text2sql_training/modal_training.py
67 changes: 20 additions & 47 deletions text2sql_training/llm_stf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,16 @@
from datasets import load_dataset
from random import randint
from huggingface_hub import hf_hub_download

from collections import defaultdict

HF_TOKEN_READ = os.getenv("HF_TOKEN_READ")
HF_TOKEN_WRITE = os.getenv("HF_TOKEN_WRITE")


class DataConfig(Config):
# dataset_name: str = "motherduckdb/duckdb-text2sql-25k"
dataset_name: str = "b-mc2/sql-create-context"

train_data_path: str = "train_dataset-sql.json"
test_data_path: str = "test_dataset-sql.json"

test_size: float = 0.1
sample_training: int = 5000

Expand Down Expand Up @@ -113,7 +110,7 @@ def test_data(context: AssetExecutionContext, create_text_to_sql_dataset):

return dataset

def run_training(pretrained_model_id: str, peft_model_id: str, train_data):
def run_training(pretrained_model_id: str, peft_model_id: str, train_data) -> str:
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
Expand Down Expand Up @@ -215,36 +212,29 @@ def run_training(pretrained_model_id: str, peft_model_id: str, train_data):

return hub_model_id

@asset(group_name="model", compute_kind="modal-labs")
def trained_model(context: AssetExecutionContext, config: ModelTrainingConfig, train_data):
    """Fine-tune the base model on `train_data` and return the Hub model id.

    Delegates the actual training to `run_training` and records a link to the
    pushed model as Dagster output metadata.
    """
    hub_model_id = run_training(
        pretrained_model_id=config.pretrained_model_id,
        peft_model_id=config.peft_model_id,
        train_data=train_data,
    )
    # Surface a clickable link to the pushed model in the Dagster UI.
    context.add_output_metadata(
        {"model_url": MetadataValue.url(f"https://huggingface.co/{hub_model_id}")}
    )
    return hub_model_id





@asset(group_name="model", compute_kind="python")
def model_card(context: AssetExecutionContext, trained_model):
    """Fetch the trained model's README from the Hugging Face Hub.

    Downloads the model card for `trained_model` (a Hub repo id), publishes it
    as rendered-markdown Dagster metadata, and returns the raw text.
    """
    model_card_path = hf_hub_download(repo_id=trained_model, filename="README.md")
    with open(model_card_path, "r") as f:
        content = f.read()
    # Render the card inline in the Dagster asset view.
    context.add_output_metadata({"content": MetadataValue.md(content)})
    return content


@asset(group_name="model", compute_kind="modal")
@asset(group_name="model", compute_kind="python")
def test_results(context: AssetExecutionContext, test_data, trained_model, config: ModelTrainingConfig):
tokenizer = AutoTokenizer.from_pretrained(config.peft_model_id)
model = AutoPeftModelForCausalLM.from_pretrained(
Expand All @@ -262,9 +252,9 @@ def test_results(context: AssetExecutionContext, test_data, trained_model, confi
pipe.tokenizer.convert_tokens_to_ids("<|eot_id|>")
]


results = defaultdict(list)
number_of_eval_samples = 100
for s in test_data.shuffle().select(range(number_of_eval_samples)):
for s in test_data.select(range(number_of_eval_samples)):
query = s['messages'][1]['content']


Expand All @@ -277,33 +267,16 @@ def test_results(context: AssetExecutionContext, test_data, trained_model, confi
original_sql = s['messages'][2]['content'].lower()
generated_sql = outputs[0]['generated_text'][len(prompt):].strip().lower()

inference_samples.append({
'query': test_data[rand_idx]['messages'][1]['content'],
'original_sql': original_sql,
'generated_sql': generated_sql,
'match': original_sql == generated_sql
})
results['query'].append(query)
results['original_sql'].append(original_sql)
results['generated_sql'].append(generated_sql)
results['hard_match'].append(original_sql == generated_sql)


rouge = evaluate.load('rouge')
rouge.compute(predictions=original_sql, references=generated_sql)

# def evaluate_sample(sample):
# prompt = pipe.tokenizer.apply_chat_template(sample["messages"][:2], tokenize=False, add_generation_prompt=True)
# outputs = pipe(prompt, max_new_tokens=256, do_sample=False, temperature=0.1, top_k=50, top_p=0.1, eos_token_id=terminators, pad_token_id=pipe.tokenizer.pad_token_id)
# generated_sql = outputs[0]['generated_text'][len(prompt):].strip().lower()
# original_sql = sample["messages"][2]["content"].lower()
# return 1 if generated_sql == original_sql else 0

success_rate = []
number_of_eval_samples = 100
# iterate over eval dataset and predict
for s in test_data.shuffle().select(range(number_of_eval_samples)):
success_rate.append(evaluate_sample(s))

# compute accuracy
accuracy = sum(success_rate)/len(success_rate)
rouge.compute(predictions=results['generated_sql'], references=results['original_sql'])


context.add_output_metadata(
{
"inference_samples": MetadataValue.json(inference_samples),
Expand Down
21 changes: 13 additions & 8 deletions text2sql_training/modal_training.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,27 @@
import modal
from modal import Image
from datasets import load_dataset
from datasets import disable_caching
import pandas as pd
import os

app = modal.App("example-hello-world")

# SECURITY: real Hugging Face tokens were previously hardcoded here and committed
# to version control — they must be rotated. Read them from the local environment
# at deploy time instead of embedding literals in the repo.
custom_image = Image.from_registry(
    "ghcr.io/kyryl-opens-ml/fine-tune-llm-in-2024:main"
).env(
    {
        "HF_TOKEN": os.environ["HF_TOKEN"],
        "HF_TOKEN_WRITE": os.environ["HF_TOKEN_WRITE"],
    }
)

@app.function(image=custom_image, gpu="A100")
def foo():
from datasets import load_dataset
import text2sql_training
@app.function(image=custom_image, gpu="A100", mounts=[modal.Mount.from_local_python_packages("text2sql_training", "text2sql_training")], timeout=15 * 60)
def run_training_modal(train_data_pandas: pd.DataFrame):
    """Run LoRA fine-tuning on a remote Modal A100 worker.

    The training split travels over the Modal RPC boundary as a pandas
    DataFrame, is rebuilt into a `datasets.Dataset`, and the Hub model id
    returned by `run_training` is passed back to the caller.
    """
    # Imports deferred so they resolve inside the remote container image.
    from datasets import Dataset
    from text2sql_training.llm_stf import run_training

    model_url = run_training(
        pretrained_model_id='meta-llama/Meta-Llama-3-8B-Instruct',
        peft_model_id='modal-test',
        train_data=Dataset.from_pandas(train_data_pandas),
    )
    return model_url


@app.local_entrypoint()
def main():
    """Load the local training split and launch remote fine-tuning on Modal."""
    train_data = load_dataset("json", data_files='train_dataset-sql.json', split="train")
    # Convert to pandas so the dataset can be serialized across the Modal RPC boundary.
    result = run_training_modal.remote(train_data_pandas=train_data.to_pandas())
    print(result)

0 comments on commit 118e3ec

Please sign in to comment.