
Add finetuning component #502

Merged (15 commits, Aug 21, 2024)
56 changes: 56 additions & 0 deletions comps/finetuning/README.md
@@ -0,0 +1,56 @@
# LLM Fine-tuning Microservice

The LLM fine-tuning microservice adapts a base model to a specific task or dataset to improve its performance on that task.

# 🚀1. Start Microservice with Python

## 1.1 Install Requirements

```bash
pip install -r requirements.txt
```

## 1.2 Start Finetuning Service with Python Script

### 1.2.1 Start Ray Cluster

The OneCCL and Intel MPI libraries must be dynamically linked on every node before Ray starts. Source their environment with:

```bash
source $(python -c "import oneccl_bindings_for_pytorch as torch_ccl; print(torch_ccl.cwd)")/env/setvars.sh
```

Start Ray locally using the following command.

```bash
ray start --head
```

For a multi-node cluster, start additional Ray worker nodes with the command below.

```bash
ray start --address="${head_node_ip}:6379"
```
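To confirm that all worker nodes have joined the cluster, the Ray CLI can be queried from any node (output format may vary across Ray versions):

```bash
# List the nodes and resources currently registered in the Ray cluster
ray status
```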

### 1.2.2 Start Finetuning Service

```bash
export RAY_ADDRESS="ray://${ray_head_ip}:10001"
python finetuning/finetuning_service.py
```

# 🚀2. Consume Finetuning Service

## 2.1 Create fine-tuning job

Assuming a training file `alpaca_data.json` has been uploaded, the following command launches a fine-tuning job using `meta-llama/Llama-2-7b-chat-hf` as the base model:

```bash
curl http://${your_ip}:8000/v1/fine_tuning/jobs \
-X POST \
-H "Content-Type: application/json" \
-d '{
"training_file": "alpaca_data.json",
"model": "meta-llama/Llama-2-7b-chat-hf"
}'
```
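Once a job is created, the remaining endpoints defined by the service can be used to track and manage it. The commands below are sketches: `${fine_tuning_job_id}` is a placeholder for the ID returned at creation time, and the host and port match the service defaults:

```bash
# List all fine-tuning jobs
curl http://${your_ip}:8000/v1/fine_tuning/jobs

# Retrieve a single job by ID
curl http://${your_ip}:8000/v1/fine_tuning/jobs/${fine_tuning_job_id}

# Cancel a running job
curl http://${your_ip}:8000/v1/fine_tuning/jobs/${fine_tuning_job_id}/cancel \
    -X POST
```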
22 changes: 22 additions & 0 deletions comps/finetuning/docker/Dockerfile.finetune
@@ -0,0 +1,22 @@
# Use the same Python version as Ray
FROM python:3.10.14

WORKDIR /root/opea-finetune

RUN --mount=type=cache,target=/var/cache/apt apt-get update -y \
&& apt-get install -y vim htop net-tools dnsutils \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*

# COPY ./install-llm-on-ray.sh /tmp/install-llm-on-ray.sh
# RUN --mount=type=cache,target=/root/.cache/pip /tmp/install-llm-on-ray.sh

COPY ./ .

RUN --mount=type=cache,target=/root/.cache/pip cd ./llm-on-ray && pip install -v -e .[cpu] --extra-index-url https://download.pytorch.org/whl/cpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/

RUN --mount=type=cache,target=/root/.cache/pip pip install --no-cache-dir --upgrade -r requirements.txt

RUN echo 'source $(python -c "import oneccl_bindings_for_pytorch as torch_ccl; print(torch_ccl.cwd)")/env/setvars.sh' >> ~/.bashrc

CMD ["bash", "-c", "./run.sh"]
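A build command for this Dockerfile might look like the following, run from the repository root (the image tag is an assumption, not one mandated by the PR):

```bash
docker build \
    -f comps/finetuning/docker/Dockerfile.finetune \
    -t opea/finetuning:latest \
    .
```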
37 changes: 37 additions & 0 deletions comps/finetuning/finetune_runner.py
@@ -0,0 +1,37 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import argparse

from llm_on_ray.finetune.finetune_config import FinetuneConfig
from pydantic_yaml import parse_yaml_raw_as
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments


class FineTuneCallback(TrainerCallback):
def __init__(self) -> None:
super().__init__()

def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
print("FineTuneCallback:", args, state)


def main():
parser = argparse.ArgumentParser(description="Runner for llm_on_ray-finetune")
parser.add_argument("--config_file", type=str, required=True)
args = parser.parse_args()
model_config_file = args.config_file

with open(model_config_file) as f:
finetune_config = parse_yaml_raw_as(FinetuneConfig, f).model_dump()

callback = FineTuneCallback()
finetune_config["Training"]["callbacks"] = [callback]

from llm_on_ray.finetune.finetune import main as llm_on_ray_finetune_main

llm_on_ray_finetune_main(finetune_config)


if __name__ == "__main__":
main()
40 changes: 40 additions & 0 deletions comps/finetuning/finetuning_service.py
@@ -0,0 +1,40 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import uvicorn
from fastapi import BackgroundTasks, FastAPI
from handlers import (
handle_cancel_finetuning_job,
handle_create_finetuning_jobs,
handle_list_finetuning_jobs,
handle_retrieve_finetuning_job,
)
from models import FineTuningJob, FineTuningJobList, FineTuningJobsRequest

app = FastAPI()


@app.post("/v1/fine_tuning/jobs", response_model=FineTuningJob)
def create_finetuning_jobs(request: FineTuningJobsRequest, background_tasks: BackgroundTasks):
return handle_create_finetuning_jobs(request, background_tasks)


@app.get("/v1/fine_tuning/jobs", response_model=FineTuningJobList)
def list_finetuning_jobs():
return handle_list_finetuning_jobs()


@app.get("/v1/fine_tuning/jobs/{fine_tuning_job_id}", response_model=FineTuningJob)
def retrieve_finetuning_job(fine_tuning_job_id):
job = handle_retrieve_finetuning_job(fine_tuning_job_id)
return job


@app.post("/v1/fine_tuning/jobs/{fine_tuning_job_id}/cancel", response_model=FineTuningJob)
def cancel_finetuning_job(fine_tuning_job_id):
job = handle_cancel_finetuning_job(fine_tuning_job_id)
return job


if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)
140 changes: 140 additions & 0 deletions comps/finetuning/handlers.py
@@ -0,0 +1,140 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import os
import random
import time
import uuid
from typing import Dict

from fastapi import BackgroundTasks, HTTPException
from llm_on_ray.finetune.finetune import main
from llm_on_ray.finetune.finetune_config import FinetuneConfig
from models import FineTuningJob, FineTuningJobEvent, FineTuningJobList, FineTuningJobsRequest
from pydantic_yaml import parse_yaml_raw_as, to_yaml_file
from ray.job_submission import JobSubmissionClient
from ray.train.base_trainer import TrainingFailedError
from ray.tune.logger import LoggerCallback

MODEL_CONFIG_FILE_MAP = {
"meta-llama/Llama-2-7b-chat-hf": "./models/llama-2-7b-chat-hf.yaml",
"mistralai/Mistral-7B-v0.1": "./models/mistral-7b-v0.1.yaml",
}

DATASET_BASE_PATH = "datasets"

FineTuningJobID = str
CHECK_JOB_STATUS_INTERVAL = 5 # Check every 5 secs

ray_client: JobSubmissionClient | None = None

running_finetuning_jobs: Dict[FineTuningJobID, FineTuningJob] = {}
finetuning_job_to_ray_job: Dict[FineTuningJobID, str] = {}


# Background task to periodically update job status
def update_job_status(job_id: FineTuningJobID):
while True:
job_status = ray_client.get_job_status(finetuning_job_to_ray_job[job_id])
status = str(job_status).lower()
# Ray status "stopped" is OpenAI status "cancelled"
status = "cancelled" if status == "stopped" else status
print(f"Status of job {job_id} is '{status}'")
running_finetuning_jobs[job_id].status = status
if status in ("finished", "cancelled", "failed"):
break
time.sleep(CHECK_JOB_STATUS_INTERVAL)


def handle_create_finetuning_jobs(request: FineTuningJobsRequest, background_tasks: BackgroundTasks):
base_model = request.model
train_file = request.training_file
train_file_path = os.path.join(DATASET_BASE_PATH, train_file)

model_config_file = MODEL_CONFIG_FILE_MAP.get(base_model)
if not model_config_file:
raise HTTPException(status_code=404, detail=f"Base model '{base_model}' not supported!")

if not os.path.exists(train_file_path):
raise HTTPException(status_code=404, detail=f"Training file '{train_file}' not found!")

with open(model_config_file) as f:
finetune_config = parse_yaml_raw_as(FinetuneConfig, f)

finetune_config.Dataset.train_file = train_file_path

job = FineTuningJob(
id=f"ft-job-{uuid.uuid4()}",
model=base_model,
created_at=int(time.time()),
training_file=train_file,
hyperparameters={
"n_epochs": finetune_config.Training.epochs,
"batch_size": finetune_config.Training.batch_size,
"learning_rate_multiplier": finetune_config.Training.learning_rate,
},
status="running",
# TODO: Add seed in finetune config
seed=random.randint(0, 1000),
)

finetune_config_file = f"jobs/{job.id}.yaml"
to_yaml_file(finetune_config_file, finetune_config)

global ray_client
ray_client = JobSubmissionClient() if ray_client is None else ray_client

ray_job_id = ray_client.submit_job(
# Entrypoint shell command to execute
entrypoint=f"python finetune_runner.py --config_file {finetune_config_file}",
# Working directory that contains finetune_runner.py
runtime_env={"working_dir": "./"},
)
print(f"Submitted Ray job: {ray_job_id} ...")

running_finetuning_jobs[job.id] = job
finetuning_job_to_ray_job[job.id] = ray_job_id

background_tasks.add_task(update_job_status, job.id)

return job


def handle_list_finetuning_jobs():
finetuning_jobs_list = FineTuningJobList(data=list(running_finetuning_jobs.values()), has_more=False)

return finetuning_jobs_list


def handle_retrieve_finetuning_job(fine_tuning_job_id):
job = running_finetuning_jobs.get(fine_tuning_job_id)
if job is None:
raise HTTPException(status_code=404, detail=f"Fine-tuning job '{fine_tuning_job_id}' not found!")
return job


def handle_cancel_finetuning_job(fine_tuning_job_id):
ray_job_id = finetuning_job_to_ray_job.get(fine_tuning_job_id)
if ray_job_id is None:
raise HTTPException(status_code=404, detail=f"Fine-tuning job '{fine_tuning_job_id}' not found!")

global ray_client
ray_client = JobSubmissionClient() if ray_client is None else ray_client
ray_client.stop_job(ray_job_id)

job = running_finetuning_jobs.get(fine_tuning_job_id)
job.status = "cancelled"
return job


# def cancel_all_jobs():
# global ray_client
# ray_client = JobSubmissionClient() if ray_client is None else ray_client
# # stop all jobs
# for job_id in finetuning_job_to_ray_job.values():
# ray_client.stop_job(job_id)

# for job_id in running_finetuning_jobs:
# running_finetuning_jobs[job_id].status = "cancelled"
# return running_finetuning_jobs
18 changes: 18 additions & 0 deletions comps/finetuning/llm_on_ray/common/__init__.py
@@ -0,0 +1,18 @@
#
# Copyright 2023 The LLM-on-Ray Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from llm_on_ray.common.logging import logger
from llm_on_ray.common.torch_config import TorchConfig
38 changes: 38 additions & 0 deletions comps/finetuning/llm_on_ray/common/common.py
@@ -0,0 +1,38 @@
#
# Copyright 2023 The LLM-on-Ray Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import glob
import importlib
import os

from llm_on_ray.common.logging import logger


def import_all_modules(basedir, prefix=None):
all_py_files = glob.glob(basedir + "/*.py")
modules = [os.path.basename(f) for f in all_py_files]

for module in modules:
if not module.startswith("_"):
module = module.removesuffix(".py")  # rstrip(".py") would strip characters, not the suffix
if prefix is None:
module_name = module
else:
module_name = f"{prefix}.{module}"
try:
importlib.import_module(module_name)
except Exception:
logger.warning(f"import {module_name} error", exc_info=True)