diff --git a/comps/cores/mega/micro_service.py b/comps/cores/mega/micro_service.py index e1276716c4..9d707fa688 100644 --- a/comps/cores/mega/micro_service.py +++ b/comps/cores/mega/micro_service.py @@ -3,7 +3,7 @@ import asyncio import multiprocessing -from typing import Any, Optional, Type +from typing import Any, List, Optional, Type from ..proto.docarray import TextDoc from .constants import ServiceRoleType, ServiceType @@ -154,6 +154,7 @@ def register_microservice( output_datatype: Type[Any] = TextDoc, provider: Optional[str] = None, provider_endpoint: Optional[str] = None, + methods: List[str] = ["POST"], ): def decorator(func): if name not in opea_microservices: @@ -173,7 +174,8 @@ def decorator(func): provider_endpoint=provider_endpoint, ) opea_microservices[name] = micro_service - opea_microservices[name].app.router.add_api_route(endpoint, func, methods=["POST"]) + opea_microservices[name].app.router.add_api_route(endpoint, func, methods=methods) + return func return decorator diff --git a/comps/cores/proto/api_protocol.py b/comps/cores/proto/api_protocol.py index 0b3094e1c6..0a8b2de005 100644 --- a/comps/cores/proto/api_protocol.py +++ b/comps/cores/proto/api_protocol.py @@ -539,3 +539,225 @@ def check_requests(request) -> Optional[JSONResponse]: ) return None + + +class Hyperparameters(BaseModel): + batch_size: Optional[Union[Literal["auto"], int]] = "auto" + """Number of examples in each batch. + + A larger batch size means that model parameters are updated less frequently, but with lower variance. + """ + + learning_rate_multiplier: Optional[Union[Literal["auto"], float]] = "auto" + """Scaling factor for the learning rate. + + A smaller learning rate may be useful to avoid overfitting. + """ + + n_epochs: Optional[Union[Literal["auto"], int]] = "auto" + """The number of epochs to train the model for. + + An epoch refers to one full cycle through the training dataset. "auto" decides + the optimal number of epochs based on the size of the dataset. If setting the + number manually, we support any number between 1 and 50 epochs. + """ + + +class FineTuningJobWandbIntegration(BaseModel): + project: str + """The name of the project that the new run will be created under.""" + + entity: Optional[str] = None + """The entity to use for the run. + + This allows you to set the team or username of the WandB user that you would + like associated with the run. If not set, the default entity for the registered + WandB API key is used. + """ + + name: Optional[str] = None + """A display name to set for the run. + + If not set, we will use the Job ID as the name. + """ + + tags: Optional[List[str]] = None + """A list of tags to be attached to the newly created run. + + These tags are passed through directly to WandB. Some default tags are generated + by OpenAI: "openai/finetune", "openai/{base-model}", "openai/{ftjob-abcdef}". + """ + + +class FineTuningJobWandbIntegrationObject(BaseModel): + type: Literal["wandb"] + """The type of the integration being enabled for the fine-tuning job.""" + + wandb: FineTuningJobWandbIntegration + """The settings for your integration with Weights and Biases. + + This payload specifies the project that metrics will be sent to. Optionally, you + can set an explicit display name for your run, add tags to your run, and set a + default entity (team, username, etc) to be associated with your run. 
+ """ + + +class FineTuningJobsRequest(BaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/fine-tuning/create + model: str + """The name of the model to fine-tune.""" + + training_file: str + """The ID of an uploaded file that contains training data.""" + + hyperparameters: Optional[Hyperparameters] = None + """The hyperparameters used for the fine-tuning job.""" + + suffix: Optional[str] = None + """A string of up to 64 characters that will be added to your fine-tuned model name.""" + + validation_file: Optional[str] = None + """The ID of an uploaded file that contains validation data.""" + + integrations: Optional[List[FineTuningJobWandbIntegrationObject]] = None + """A list of integrations to enable for your fine-tuning job.""" + + seed: Optional[str] = None + + +class Error(BaseModel): + code: str + """A machine-readable error code.""" + + message: str + """A human-readable error message.""" + + param: Optional[str] = None + """The parameter that was invalid, usually `training_file` or `validation_file`. + + This field will be null if the failure was not parameter-specific. + """ + + +class FineTuningJob(BaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/fine-tuning/object + id: str + """The object identifier, which can be referenced in the API endpoints.""" + + created_at: int + """The Unix timestamp (in seconds) for when the fine-tuning job was created.""" + + error: Optional[Error] = None + """For fine-tuning jobs that have `failed`, this will contain more information on + the cause of the failure.""" + + fine_tuned_model: Optional[str] = None + """The name of the fine-tuned model that is being created. + + The value will be null if the fine-tuning job is still running. + """ + + finished_at: Optional[int] = None + """The Unix timestamp (in seconds) for when the fine-tuning job was finished. + + The value will be null if the fine-tuning job is still running. + """ + + hyperparameters: Hyperparameters + """The hyperparameters used for the fine-tuning job. + + See the [fine-tuning guide](https://platform.openai.com/docs/guides/fine-tuning) + for more details. + """ + + model: str + """The base model that is being fine-tuned.""" + + object: Literal["fine_tuning.job"] = "fine_tuning.job" + """The object type, which is always "fine_tuning.job".""" + + organization_id: Optional[str] = None + """The organization that owns the fine-tuning job.""" + + result_files: List[str] = None + """The compiled results file ID(s) for the fine-tuning job. + + You can retrieve the results with the + [Files API](https://platform.openai.com/docs/api-reference/files/retrieve-contents). + """ + + status: Literal["validating_files", "queued", "running", "succeeded", "failed", "cancelled"] + """The current status of the fine-tuning job, which can be either + `validating_files`, `queued`, `running`, `succeeded`, `failed`, or `cancelled`.""" + + trained_tokens: Optional[int] = None + """The total number of billable tokens processed by this fine-tuning job. + + The value will be null if the fine-tuning job is still running. + """ + + training_file: str + """The file ID used for training. + + You can retrieve the training data with the + [Files API](https://platform.openai.com/docs/api-reference/files/retrieve-contents). + """ + + validation_file: Optional[str] = None + """The file ID used for validation. 
+
+    You can retrieve the validation results with the
+    [Files API](https://platform.openai.com/docs/api-reference/files/retrieve-contents).
+    """
+
+    integrations: Optional[List[FineTuningJobWandbIntegrationObject]] = None
+    """A list of integrations to enable for this fine-tuning job."""
+
+    seed: Optional[int] = None
+    """The seed used for the fine-tuning job."""
+
+    estimated_finish: Optional[int] = None
+    """The Unix timestamp (in seconds) for when the fine-tuning job is estimated to
+    finish.
+
+    The value will be null if the fine-tuning job is not running.
+    """
+
+
+class FineTuningJobIDRequest(BaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/fine-tuning/retrieve
+    # https://platform.openai.com/docs/api-reference/fine-tuning/cancel
+    fine_tuning_job_id: str
+    """The ID of the fine-tuning job."""
+
+
+class FineTuningJobListRequest(BaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/fine-tuning/list
+    after: Optional[str] = None
+    """Identifier for the last job from the previous pagination request."""
+
+    limit: Optional[int] = 20
+    """Number of fine-tuning jobs to retrieve."""
+
+
+class FineTuningJobList(BaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/fine-tuning/list
+    object: str = "list"
+    """The object type, which is always "list".
+
+    This indicates that the returned data is a list of fine-tuning jobs.
+    """
+
+    data: List[FineTuningJob]
+    """A list containing FineTuningJob objects."""
+
+    has_more: bool
+    """Indicates whether there are more fine-tuning jobs beyond the current list.
+
+    If true, additional requests can be made to retrieve more jobs.
+    """
diff --git a/comps/finetuning/README.md b/comps/finetuning/README.md
new file mode 100644
index 0000000000..411395ec95
--- /dev/null
+++ b/comps/finetuning/README.md
@@ -0,0 +1,121 @@
+# LLM Fine-tuning Microservice
+
+The LLM fine-tuning microservice adapts a base model to a specific task or dataset to improve its performance on that task.
+
+# 🚀1. Start Microservice with Python (Option 1)
+
+## 1.1 Install Requirements
+
+```bash
+python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+python -m pip install intel-extension-for-pytorch
+python -m pip install oneccl_bind_pt --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/
+pip install -r requirements.txt
+```
+
+## 1.2 Start Finetuning Service with Python Script
+
+### 1.2.1 Start Ray Cluster
+
+The OneCCL and Intel MPI libraries should be dynamically linked on every node before Ray starts:
+
+```bash
+source $(python -c "import oneccl_bindings_for_pytorch as torch_ccl; print(torch_ccl.cwd)")/env/setvars.sh
+```
+
+Start Ray locally using the following command.
+
+```bash
+ray start --head
+```
+
+For a multi-node cluster, start additional Ray worker nodes with the command below.
+
+```bash
+ray start --address='${head_node_ip}:6379'
+```
+
+### 1.2.2 Start Finetuning Service
+
+```bash
+export HF_TOKEN=${your_huggingface_token}
+python finetuning_service.py
+```
+
+# 🚀2. Start Microservice with Docker (Option 2)
+
+## 2.1 Setup on CPU
+
+### 2.1.1 Build Docker Image
+
+Build the Docker image with the command below:
+
+```bash
+export HF_TOKEN=${your_huggingface_token}
+cd ../../
+docker build -t opea/finetuning:latest --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy --build-arg HF_TOKEN=$HF_TOKEN -f comps/finetuning/docker/Dockerfile_cpu .
+```
+
+### 2.1.2 Run Docker with CLI
+
+Start the Docker container with the command below:
+
+```bash
+docker run -d --name="finetuning-server" -p 8005:8005 --runtime=runc --ipc=host -e http_proxy=$http_proxy -e https_proxy=$https_proxy opea/finetuning:latest
+```
+
+## 2.2 Setup on Gaudi2
+
+### 2.2.1 Build Docker Image
+
+Build the Docker image with the command below:
+
+```bash
+cd ../../
+docker build -t opea/finetuning-gaudi:latest --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy -f comps/finetuning/docker/Dockerfile_hpu .
+```
+
+### 2.2.2 Run Docker with CLI
+
+Start the Docker container with the command below:
+
+```bash
+export HF_TOKEN=${your_huggingface_token}
+docker run --runtime=habana -e HABANA_VISIBLE_DEVICES=all -p 8005:8005 -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host -e https_proxy=$https_proxy -e http_proxy=$http_proxy -e no_proxy=$no_proxy -e HF_TOKEN=$HF_TOKEN opea/finetuning-gaudi:latest
+```
+
+# 🚀3. Consume Finetuning Service
+
+## 3.1 Create fine-tuning job
+
+Assuming a training file `alpaca_data.json` (which can be downloaded from [here](https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json)) has been uploaded, the following commands launch a fine-tuning job using `meta-llama/Llama-2-7b-chat-hf` as the base model:
+
+```bash
+# upload a training file
+curl http://${your_ip}:8005/v1/finetune/upload_training_files -X POST -H "Content-Type: multipart/form-data" -F "files=@./alpaca_data.json"
+
+# create a finetuning job
+curl http://${your_ip}:8005/v1/fine_tuning/jobs \
+  -X POST \
+  -H "Content-Type: application/json" \
+  -d '{
+    "training_file": "alpaca_data.json",
+    "model": "meta-llama/Llama-2-7b-chat-hf"
+  }'
+
+# list finetuning jobs
+curl http://${your_ip}:8005/v1/fine_tuning/jobs -X GET
+
+# retrieve one finetuning job
+curl http://localhost:8005/v1/fine_tuning/jobs/retrieve -X POST -H "Content-Type: application/json" -d '{
+  "fine_tuning_job_id": ${fine_tuning_job_id}}'
+
+# cancel one finetuning job
+curl http://localhost:8005/v1/fine_tuning/jobs/cancel -X POST -H "Content-Type: application/json" -d '{
+  "fine_tuning_job_id": ${fine_tuning_job_id}}'
+
+# list checkpoints of a finetuning job
+curl http://${your_ip}:8005/v1/finetune/list_checkpoints -X POST -H "Content-Type: application/json" -d '{"fine_tuning_job_id": ${fine_tuning_job_id}}'
+
+```
diff --git a/comps/finetuning/datasets/.gitkeep b/comps/finetuning/datasets/.gitkeep
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/comps/finetuning/docker/Dockerfile_cpu b/comps/finetuning/docker/Dockerfile_cpu
new file mode 100644
index 0000000000..1cb391af8e
--- /dev/null
+++ b/comps/finetuning/docker/Dockerfile_cpu
@@ -0,0 +1,38 @@
+# Use the same Python version as Ray
+FROM python:3.10.14
+
+ARG HF_TOKEN
+
+ENV HF_TOKEN=$HF_TOKEN
+
+RUN apt-get update -y && apt-get install -y vim htop net-tools dnsutils
+
+RUN useradd -m -s /bin/bash user && \
+    mkdir -p /home/user && \
+    chown -R user /home/user/
+
+COPY comps /home/user/comps
+
+RUN chown -R user /home/user/comps/finetuning
+
+USER user
+
+ENV PATH=$PATH:/home/user/.local/bin
+
+RUN pip install
--no-cache-dir --upgrade pip && \ + python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu && \ + python -m pip install intel-extension-for-pytorch && \ + python -m pip install oneccl_bind_pt --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/ && \ + pip install --no-cache-dir -r /home/user/comps/finetuning/requirements.txt + +ENV PYTHONPATH=$PYTHONPATH:/home/user + +WORKDIR /home/user/comps/finetuning + +RUN echo PKGPATH=$(python3 -c "import pkg_resources; print(pkg_resources.get_distribution('oneccl-bind-pt').location)") >> run.sh && \ + echo 'export LD_LIBRARY_PATH=$PKGPATH/oneccl_bindings_for_pytorch/opt/mpi/lib/:$LD_LIBRARY_PATH' >> run.sh && \ + echo 'source $PKGPATH/oneccl_bindings_for_pytorch/env/setvars.sh' >> run.sh && \ + echo ray start --head >> run.sh && \ + echo python finetuning_service.py >> run.sh + +CMD bash run.sh \ No newline at end of file diff --git a/comps/finetuning/docker/Dockerfile_hpu b/comps/finetuning/docker/Dockerfile_hpu new file mode 100644 index 0000000000..1277d76c16 --- /dev/null +++ b/comps/finetuning/docker/Dockerfile_hpu @@ -0,0 +1,31 @@ +# Use the same python version with ray +FROM vault.habana.ai/gaudi-docker/1.16.1/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest + +ENV DEVICE="hpu" + +RUN apt-get update -y && apt-get install -y vim htop net-tools dnsutils + +RUN useradd -m -s /bin/bash user && \ + mkdir -p /home/user && \ + chown -R user /home/user/ + +COPY comps /home/user/comps + +RUN chown -R user /home/user/comps/finetuning + +USER user + +ENV PATH=$PATH:/home/user/.local/bin + +RUN pip install --no-cache-dir --upgrade pip && \ + pip install --no-cache-dir -r /home/user/comps/finetuning/requirements.txt && \ + pip install --no-cache-dir optimum-habana + +ENV PYTHONPATH=$PYTHONPATH:/home/user + +WORKDIR /home/user/comps/finetuning + +ENTRYPOINT ["/bin/bash", "launch.sh"] + +# CMD ["/bin/bash"] + diff --git a/comps/finetuning/finetune_runner.py b/comps/finetuning/finetune_runner.py new file mode 100644 index 0000000000..1ddfc46424 --- /dev/null +++ b/comps/finetuning/finetune_runner.py @@ -0,0 +1,38 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import argparse + +from pydantic_yaml import parse_yaml_raw_as +from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments + +from comps.finetuning.llm_on_ray.finetune.finetune_config import FinetuneConfig + + +class FineTuneCallback(TrainerCallback): + def __init__(self) -> None: + super().__init__() + + def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + print("FineTuneCallback:", args, state) + + +def main(): + parser = argparse.ArgumentParser(description="Runner for llm_on_ray-finetune") + parser.add_argument("--config_file", type=str, required=True, default=None) + args = parser.parse_args() + model_config_file = args.config_file + + with open(model_config_file) as f: + finetune_config = parse_yaml_raw_as(FinetuneConfig, f).model_dump() + + callback = FineTuneCallback() + finetune_config["Training"]["callbacks"] = [callback] + + from comps.finetuning.llm_on_ray.finetune.finetune import main as llm_on_ray_finetune_main + + llm_on_ray_finetune_main(finetune_config) + + +if __name__ == "__main__": + main() diff --git a/comps/finetuning/finetuning_service.py b/comps/finetuning/finetuning_service.py new file mode 100644 index 0000000000..fabb32bc40 --- /dev/null +++ b/comps/finetuning/finetuning_service.py @@ 
-0,0 +1,80 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import os +import urllib.parse +from typing import List, Optional, Union + +from fastapi import BackgroundTasks, File, UploadFile + +from comps import opea_microservices, register_microservice +from comps.cores.proto.api_protocol import FineTuningJobIDRequest, FineTuningJobsRequest +from comps.finetuning.handlers import ( + DATASET_BASE_PATH, + handle_cancel_finetuning_job, + handle_create_finetuning_jobs, + handle_list_finetuning_checkpoints, + handle_list_finetuning_jobs, + handle_retrieve_finetuning_job, + save_content_to_local_disk, +) + + +@register_microservice(name="opea_service@finetuning", endpoint="/v1/fine_tuning/jobs", host="0.0.0.0", port=8005) +def create_finetuning_jobs(request: FineTuningJobsRequest, background_tasks: BackgroundTasks): + return handle_create_finetuning_jobs(request, background_tasks) + + +@register_microservice( + name="opea_service@finetuning", endpoint="/v1/fine_tuning/jobs", host="0.0.0.0", port=8005, methods=["GET"] +) +def list_finetuning_jobs(): + return handle_list_finetuning_jobs() + + +@register_microservice( + name="opea_service@finetuning", endpoint="/v1/fine_tuning/jobs/retrieve", host="0.0.0.0", port=8005 +) +def retrieve_finetuning_job(request: FineTuningJobIDRequest): + job = handle_retrieve_finetuning_job(request) + return job + + +@register_microservice( + name="opea_service@finetuning", endpoint="/v1/fine_tuning/jobs/cancel", host="0.0.0.0", port=8005 +) +def cancel_finetuning_job(request: FineTuningJobIDRequest): + job = handle_cancel_finetuning_job(request) + return job + + +@register_microservice( + name="opea_service@finetuning", + endpoint="/v1/finetune/upload_training_files", + host="0.0.0.0", + port=8005, +) +async def upload_training_files( + files: Optional[Union[UploadFile, List[UploadFile]]] = File(None), +): + if files: + if not isinstance(files, list): + files = [files] + for file in files: + filename = urllib.parse.quote(file.filename, safe="") + save_path = os.path.join(DATASET_BASE_PATH, filename) + await save_content_to_local_disk(save_path, file) + + return {"status": 200, "message": "Training files uploaded."} + + +@register_microservice( + name="opea_service@finetuning", endpoint="/v1/finetune/list_checkpoints", host="0.0.0.0", port=8005 +) +def list_checkpoints(request: FineTuningJobIDRequest): + checkpoints = handle_list_finetuning_checkpoints(request) + return {"status": 200, "checkpoints": str(checkpoints)} + + +if __name__ == "__main__": + opea_microservices["opea_service@finetuning"].start() diff --git a/comps/finetuning/handlers.py b/comps/finetuning/handlers.py new file mode 100644 index 0000000000..2bdab42a92 --- /dev/null +++ b/comps/finetuning/handlers.py @@ -0,0 +1,186 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import os +import random +import time +import uuid +from pathlib import Path +from typing import Dict + +from fastapi import BackgroundTasks, HTTPException +from pydantic_yaml import parse_yaml_raw_as, to_yaml_file +from ray.job_submission import JobSubmissionClient + +from comps.cores.proto.api_protocol import ( + FineTuningJob, + FineTuningJobIDRequest, + FineTuningJobList, + FineTuningJobsRequest, +) +from comps.finetuning.llm_on_ray.finetune.finetune_config import FinetuneConfig + +MODEL_CONFIG_FILE_MAP = { + "meta-llama/Llama-2-7b-chat-hf": "./models/llama-2-7b-chat-hf.yaml", + "mistralai/Mistral-7B-v0.1": "./models/mistral-7b-v0.1.yaml", +} + 
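+# Local working directories used by this service: uploaded training files are
+# stored under DATASET_BASE_PATH, and per-job YAML configs plus fine-tuned
+# outputs are written under JOBS_PATH.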
+DATASET_BASE_PATH = "datasets"
+JOBS_PATH = "jobs"
+if not os.path.exists(DATASET_BASE_PATH):
+    os.mkdir(DATASET_BASE_PATH)
+
+if not os.path.exists(JOBS_PATH):
+    os.mkdir(JOBS_PATH)
+
+FineTuningJobID = str
+CHECK_JOB_STATUS_INTERVAL = 5  # Check every 5 secs
+
+global ray_client
+ray_client: JobSubmissionClient = None
+
+running_finetuning_jobs: Dict[FineTuningJobID, FineTuningJob] = {}
+finetuning_job_to_ray_job: Dict[FineTuningJobID, str] = {}
+
+
+# Add a background task to periodically update job status
+def update_job_status(job_id: FineTuningJobID):
+    while True:
+        job_status = ray_client.get_job_status(finetuning_job_to_ray_job[job_id])
+        status = str(job_status).lower()
+        # Ray status "stopped" is OpenAI status "cancelled"
+        status = "cancelled" if status == "stopped" else status
+        print(f"Status of job {job_id} is '{status}'")
+        running_finetuning_jobs[job_id].status = status
+        if status == "finished" or status == "cancelled" or status == "failed":
+            break
+        time.sleep(CHECK_JOB_STATUS_INTERVAL)
+
+
+def handle_create_finetuning_jobs(request: FineTuningJobsRequest, background_tasks: BackgroundTasks):
+    base_model = request.model
+    train_file = request.training_file
+    train_file_path = os.path.join(DATASET_BASE_PATH, train_file)
+
+    model_config_file = MODEL_CONFIG_FILE_MAP.get(base_model)
+    if not model_config_file:
+        raise HTTPException(status_code=404, detail=f"Base model '{base_model}' not supported!")
+
+    if not os.path.exists(train_file_path):
+        raise HTTPException(status_code=404, detail=f"Training file '{train_file}' not found!")
+
+    with open(model_config_file) as f:
+        finetune_config = parse_yaml_raw_as(FinetuneConfig, f)
+
+    finetune_config.Dataset.train_file = train_file_path
+
+    if request.hyperparameters is not None:
+        if request.hyperparameters.n_epochs != "auto":
+            finetune_config.Training.epochs = request.hyperparameters.n_epochs
+
+        if request.hyperparameters.batch_size != "auto":
+            finetune_config.Training.batch_size = request.hyperparameters.batch_size
+
+        if request.hyperparameters.learning_rate_multiplier != "auto":
+            finetune_config.Training.learning_rate = request.hyperparameters.learning_rate_multiplier
+
+    if os.getenv("HF_TOKEN", None):
+        finetune_config.General.config.use_auth_token = os.getenv("HF_TOKEN", None)
+
+    job = FineTuningJob(
+        id=f"ft-job-{uuid.uuid4()}",
+        model=base_model,
+        created_at=int(time.time()),
+        training_file=train_file,
+        hyperparameters={
+            "n_epochs": finetune_config.Training.epochs,
+            "batch_size": finetune_config.Training.batch_size,
+            "learning_rate_multiplier": finetune_config.Training.learning_rate,
+        },
+        status="running",
+        seed=random.randint(0, 1000) if request.seed is None else request.seed,
+    )
+    finetune_config.General.output_dir = os.path.join(JOBS_PATH, job.id)
+    if os.getenv("DEVICE", ""):
+        print(f"specific device: {os.getenv('DEVICE')}")
+        finetune_config.Training.device = os.getenv("DEVICE")
+
+    finetune_config_file = f"{JOBS_PATH}/{job.id}.yaml"
+    to_yaml_file(finetune_config_file, finetune_config)
+
+    global ray_client
+    ray_client = JobSubmissionClient() if ray_client is None else ray_client
+
+    ray_job_id = ray_client.submit_job(
+        # Entrypoint shell command to execute
+        entrypoint=f"python finetune_runner.py --config_file {finetune_config_file}",
+        # Path to the local directory that contains the script.py file
+        runtime_env={"working_dir": "./"},
+    )
+    print(f"Submitted Ray job: {ray_job_id} ...")
+
+    running_finetuning_jobs[job.id] = job
+    finetuning_job_to_ray_job[job.id] = ray_job_id
+
+    background_tasks.add_task(update_job_status, job.id)
+
+    return job
+
+
+def handle_list_finetuning_jobs():
+    finetuning_jobs_list = FineTuningJobList(data=list(running_finetuning_jobs.values()), has_more=False)
+
+    return finetuning_jobs_list
+
+
+def handle_retrieve_finetuning_job(request: FineTuningJobIDRequest):
+    fine_tuning_job_id = request.fine_tuning_job_id
+
+    job = running_finetuning_jobs.get(fine_tuning_job_id)
+    if job is None:
+        raise HTTPException(status_code=404, detail=f"Fine-tuning job '{fine_tuning_job_id}' not found!")
+    return job
+
+
+def handle_cancel_finetuning_job(request: FineTuningJobIDRequest):
+    fine_tuning_job_id = request.fine_tuning_job_id
+
+    ray_job_id = finetuning_job_to_ray_job.get(fine_tuning_job_id)
+    if ray_job_id is None:
+        raise HTTPException(status_code=404, detail=f"Fine-tuning job '{fine_tuning_job_id}' not found!")
+
+    global ray_client
+    ray_client = JobSubmissionClient() if ray_client is None else ray_client
+    ray_client.stop_job(ray_job_id)
+
+    job = running_finetuning_jobs.get(fine_tuning_job_id)
+    job.status = "cancelled"
+    return job
+
+
+async def save_content_to_local_disk(save_path: str, content):
+    save_path = Path(save_path)
+    try:
+        if isinstance(content, str):
+            with open(save_path, "w", encoding="utf-8") as file:
+                file.write(content)
+        else:
+            with save_path.open("wb") as fout:
+                content = await content.read()
+                fout.write(content)
+    except Exception as e:
+        print(f"Write file failed. Exception: {e}")
+        raise HTTPException(status_code=500, detail=f"Write file {save_path} failed. Exception: {e}")
+
+
+def handle_list_finetuning_checkpoints(request: FineTuningJobIDRequest):
+    fine_tuning_job_id = request.fine_tuning_job_id
+
+    job = running_finetuning_jobs.get(fine_tuning_job_id)
+    if job is None:
+        raise HTTPException(status_code=404, detail=f"Fine-tuning job '{fine_tuning_job_id}' not found!")
+    output_dir = os.path.join(JOBS_PATH, job.id)
+    checkpoints = []
+    if os.path.exists(output_dir):
+        checkpoints = os.listdir(output_dir)
+    return checkpoints
diff --git a/comps/finetuning/jobs/.gitkeep b/comps/finetuning/jobs/.gitkeep
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/comps/finetuning/launch.sh b/comps/finetuning/launch.sh
new file mode 100644
index 0000000000..a7e249b6f3
--- /dev/null
+++ b/comps/finetuning/launch.sh
@@ -0,0 +1,12 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+if [[ -n "$RAY_PORT" ]];then
+    export RAY_ADDRESS=http://127.0.0.1:$RAY_PORT
+    ray start --head --port $RAY_PORT
+else
+    export RAY_ADDRESS=http://127.0.0.1:8265
+    ray start --head
+fi
+
+python finetuning_service.py
diff --git a/comps/finetuning/llm_on_ray/common/__init__.py b/comps/finetuning/llm_on_ray/common/__init__.py
new file mode 100644
index 0000000000..a4ad1e878e
--- /dev/null
+++ b/comps/finetuning/llm_on_ray/common/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+# Copyright 2023 The LLM-on-Ray Authors.
+
+from .logging import logger
+from .torch_config import TorchConfig
diff --git a/comps/finetuning/llm_on_ray/common/common.py b/comps/finetuning/llm_on_ray/common/common.py
new file mode 100644
index 0000000000..136d2526f8
--- /dev/null
+++ b/comps/finetuning/llm_on_ray/common/common.py
@@ -0,0 +1,27 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+# Copyright 2023 The LLM-on-Ray Authors.
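+#
+# Shared helper module: import_all_modules() below dynamically imports every
+# Python module found in a directory, optionally under a package prefix.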
+ +import glob +import importlib +import os + +from .logging import logger + + +def import_all_modules(basedir, prefix=None): + all_py_files = glob.glob(basedir + "/*.py") + modules = [os.path.basename(f) for f in all_py_files] + + for module in modules: + if not module.startswith("_"): + module = module.rstrip(".py") + if prefix is None: + module_name = module + else: + module_name = f"{prefix}.{module}" + try: + importlib.import_module(module_name) + except Exception: + logger.warning(f"import {module_name} error", exc_info=True) diff --git a/comps/finetuning/llm_on_ray/common/logging.py b/comps/finetuning/llm_on_ray/common/logging.py new file mode 100644 index 0000000000..e2aec567a2 --- /dev/null +++ b/comps/finetuning/llm_on_ray/common/logging.py @@ -0,0 +1,56 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# Copyright 2023 The LLM-on-Ray Authors. + +import functools +import logging +import logging.config +import traceback + +__all__ = ["logger", "get_logger"] + +use_accelerate_log = False +logger_name = "common" + +logging_config = { + "version": 1, + "loggers": { + "root": {"level": "INFO", "handlers": ["consoleHandler"]}, + "common": { + "level": "INFO", + "handlers": ["consoleHandler"], + "qualname": "common", + "propagate": 0, + }, + }, + "handlers": { + "consoleHandler": { + "class": "logging.StreamHandler", + "level": "INFO", + "formatter": "standardFormatter", + }, + }, + "formatters": { + "standardFormatter": { + "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s", + "datefmt": "", + } + }, +} + +if logging_config is not None: + try: + logging.config.dictConfig(logging_config) + except Exception: + traceback.print_exc() + exit(1) + +if use_accelerate_log: + import accelerate + + get_logger = functools.partial(accelerate.logging.get_logger, name=logger_name) +else: + get_logger = functools.partial(logging.getLogger, name=logger_name) + +logger = get_logger() diff --git a/comps/finetuning/llm_on_ray/common/torch_config.py b/comps/finetuning/llm_on_ray/common/torch_config.py new file mode 100644 index 0000000000..9e3f48a7c3 --- /dev/null +++ b/comps/finetuning/llm_on_ray/common/torch_config.py @@ -0,0 +1,72 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# Copyright 2023 The LLM-on-Ray Authors. + +import os +import sys +from dataclasses import dataclass +from typing import Optional + +from ray.train._internal.worker_group import WorkerGroup +from ray.train.torch.config import TorchConfig as RayTorchConfig +from ray.train.torch.config import _TorchBackend + +# The package importlib_metadata is in a different place, depending on the Python version. 
+if sys.version_info < (3, 8): + import importlib_metadata +else: + import importlib.metadata as importlib_metadata + + +@dataclass +class TorchConfig(RayTorchConfig): + device: Optional[str] = None + + @property + def backend_cls(self): + EnableCCLBackend.device = self.device + return EnableCCLBackend + + +def xpu_libs_import(): + """Try to import IPEX and oneCCL.""" + try: + import intel_extension_for_pytorch + except ImportError: + raise ImportError("Please install intel_extension_for_pytorch") + try: + ccl_version = importlib_metadata.version("oneccl_bind_pt") + if ccl_version >= "1.12": + import oneccl_bindings_for_pytorch + else: + import torch_ccl + except ImportError as ccl_not_exist: + raise ImportError("Please install torch-ccl") from ccl_not_exist + + +def hpu_libs_import(): + """Try to import habana frameworkfs for torch.""" + try: + import habana_frameworks.torch # noqa: F401 + except ImportError as habana_not_exist: + raise ImportError("Please install habana_frameworks") from habana_not_exist + + +def _set_torch_distributed_env_vars(device): + if device is not None: + os.environ["ACCELERATE_TORCH_DEVICE"] = device + + +class EnableCCLBackend(_TorchBackend): + device: Optional[str] = None + + def on_start(self, worker_group: WorkerGroup, backend_config: RayTorchConfig): + libs_import = hpu_libs_import if self.device is not None and self.device.startswith("hpu") else xpu_libs_import + for i in range(len(worker_group)): + worker_group.execute_single_async(i, libs_import) + super().on_start(worker_group, backend_config) + + def on_training_start(self, worker_group: WorkerGroup, backend_config: RayTorchConfig): + super().on_training_start(worker_group, backend_config) + worker_group.execute(_set_torch_distributed_env_vars, self.device) diff --git a/comps/finetuning/llm_on_ray/finetune/__init__.py b/comps/finetuning/llm_on_ray/finetune/__init__.py new file mode 100644 index 0000000000..0262e494a9 --- /dev/null +++ b/comps/finetuning/llm_on_ray/finetune/__init__.py @@ -0,0 +1,4 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# Copyright 2023 The LLM-on-Ray Authors. diff --git a/comps/finetuning/llm_on_ray/finetune/data_process.py b/comps/finetuning/llm_on_ray/finetune/data_process.py new file mode 100644 index 0000000000..ab5efcc09d --- /dev/null +++ b/comps/finetuning/llm_on_ray/finetune/data_process.py @@ -0,0 +1,196 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# Copyright 2023 The LLM-on-Ray Authors. + +import copy +import re +from itertools import chain + +import torch + +IGNORE_INDEX = -100 + + +class DataProcessor: + # We used the following prompts for fine-tuning the Alpaca model. You can find reference doc form this URL(https://github.com/tatsu-lab/stanford_alpaca/blob/main/README.md#data-release) + def __init__(self, config, tokenizer): + self.tokenizer = tokenizer + self.end = tokenizer.eos_token + self.intro = ( + "Below is an instruction that describes a task. Write a response that appropriately completes the request." 
+ ) + self.instruction = "### Instruction:\n" + self.input = "### Input:\n" + self.response = "### Response:\n" + self.padding_side = config["Dataset"].get("padding_side", "right") + self.truncation_side = config["Dataset"].get("truncation_side", "right") + self.max_length = self.max_seq_length = config["Dataset"].get("max_length", 512) + self.max_source_length = config["Dataset"].get("max_source_length", 384) + self.truncation = config["Dataset"].get("truncation", True) + self.padding = config["Dataset"].get("padding", True) + self.mask_input = config["Dataset"].get("mask_input", True) + self.mask_response = config["Dataset"].get("mask_response", True) + + def make_prompt(self, examples): + prompts = {} + prompts["prompt_sources"] = [] + prompts["prompt_targets"] = [] + for rec in examples: + instruction = rec["instruction"] + response = rec["input"] + context = rec.get("output") + if not instruction: + raise ValueError(f"Expected an instruction in: {rec}") + # if not response: + # raise ValueError(f"Expected a response in: {rec}") + if context: + prompt = ( + self.intro + + self.end + + "\n" + + self.instruction + + instruction + + self.input + + context + + self.end + + "\n" + + self.response + ) + prompts["prompt_sources"].append(prompt) + else: + prompt = self.intro + self.end + "\n" + self.instruction + instruction + self.end + "\n" + self.response + prompts["prompt_sources"].append(prompt) + prompt_response = response + self.end + prompts["prompt_targets"].append(prompt_response) + return prompts + + def __truncate_sequences(self, sequences, max_length): + """ + Copied from https://github.com/intel/intel-extension-for-transformers/blob/ae54f698b73a66e5729427cb19f69c33e1a5c34d/intel_extension_for_transformers/transformers/llm/finetuning/data_utils.py#L40 + """ + words_to_cut = sum(list(map(len, sequences))) - max_length + if words_to_cut <= 0: + return sequences + + while words_to_cut > 0 and len(sequences) > 0: + words_to_cut -= len(sequences[0]) + sequences = sequences[1:] + return sequences + + def tokenize_by_neural_chat(self, examples): + """ + Copied from https://github.com/intel/intel-extension-for-transformers/blob/ae54f698b73a66e5729427cb19f69c33e1a5c34d/intel_extension_for_transformers/transformers/llm/finetuning/data_utils.py#L225 + The only differences are: + - using our own prompt style + - add left or right padding and truncation + - add mask_input and mask_response + """ + keys = list(examples.data.keys()) + if len(keys) != 2: + raise ValueError("Unsupported dataset format") + assistant_tokens = self.tokenizer.tokenize(self.response) + header = self.intro + self.end + "\n" + + examples["input_ids"] = [] + examples["labels"] = [] + examples["attention_mask"] = [] + for instruction, response in zip(examples[keys[0]], examples[keys[1]]): + convs = re.findall( + r"{0}.*?{2}|{1}.*?{2}".format(self.instruction, self.response, self.end), + instruction, + re.DOTALL, + ) + convs_tokens = [self.tokenizer.tokenize(conv) + self.tokenizer.tokenize("\n") for conv in convs] + header_tokens = self.tokenizer.tokenize(header) + self.tokenizer.tokenize("\n") + max_input = self.max_source_length - len(header_tokens) - len(assistant_tokens) + truncated_convs = self.__truncate_sequences(convs_tokens, max_input) + if len(truncated_convs) == 0: + truncated_convs = [convs_tokens[-1][: max_input - 3] + convs_tokens[-1][-3:]] + + prompt_tokens = [header_tokens] + truncated_convs + [assistant_tokens] + prompt_ids = [self.tokenizer.convert_tokens_to_ids(prompt_token) for prompt_token in 
prompt_tokens] + prompt_ids = list(chain(*prompt_ids)) + + resp_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(response.strip())) + # keep last and eos_id + max_resp = self.max_seq_length - len(prompt_ids) - 1 + + # truncating response + if len(resp_ids) > max_resp: + if self.truncation_side == "right": + resp_ids = resp_ids[: max_resp - 1] + resp_ids[-1:] + else: + resp_ids = resp_ids[-max_resp:] + + # masking + input_ids = prompt_ids + resp_ids + [self.tokenizer.eos_token_id] + if self.mask_input: + labels = [IGNORE_INDEX] * len(prompt_ids) + resp_ids + [self.tokenizer.eos_token_id] + elif self.mask_response: + labels = prompt_ids + [IGNORE_INDEX] * len(resp_ids) + [self.tokenizer.eos_token_id] + else: + labels = input_ids + + # padding + input_len = len(input_ids) + pad_len = self.max_seq_length - input_len + if self.padding_side == "right": + input_ids = input_ids + [self.tokenizer.eos_token_id] * pad_len + labels = labels + [IGNORE_INDEX] * pad_len + attention_mask = [1] * input_len + [0] * pad_len + else: + input_ids = [self.tokenizer.eos_token_id] * pad_len + input_ids + labels = [IGNORE_INDEX] * pad_len + labels + attention_mask = [0] * pad_len + [1] * input_len + + assert len(input_ids) == self.max_seq_length + assert len(prompt_ids) <= self.max_source_length + assert len(labels) == len(input_ids) == len(attention_mask) + + examples["input_ids"].append(torch.tensor(input_ids)) + examples["labels"].append(labels) + examples["attention_mask"].append(attention_mask) + + return examples + + def tokenize(self, examples): + keys = list(examples.data.keys()) + if len(keys) != 2: + raise ValueError("Unsupported dataset format") + + examples["input_ids"] = [] + examples["labels"] = [] + examples["attention_mask"] = [] + for s, t in zip(examples[keys[0]], examples[keys[1]]): + results = self.tokenizer( + s + t, + padding=self.padding, + truncation=self.truncation, + return_tensors=None, + max_length=self.max_length, + ) + + input_ids = results["input_ids"] + input_len = len(input_ids) + labels = copy.deepcopy(input_ids) + if self.mask_input or self.mask_response: + sources_tokenized = self.tokenizer( + s, + padding=False, + truncation=True, + return_tensors=None, + max_length=self.max_length, + ) + input_id_len = len(sources_tokenized["input_ids"]) + # mask input + if self.mask_input: + labels[:input_id_len] = [IGNORE_INDEX] * input_id_len + # mask response + if self.mask_response: + labels[input_id_len:input_len] = [IGNORE_INDEX] * (input_len - input_id_len) + + examples["input_ids"].append(results["input_ids"]) + examples["labels"].append(labels) + examples["attention_mask"].append(results["attention_mask"]) + return examples diff --git a/comps/finetuning/llm_on_ray/finetune/finetune.py b/comps/finetuning/llm_on_ray/finetune/finetune.py new file mode 100644 index 0000000000..f268800f23 --- /dev/null +++ b/comps/finetuning/llm_on_ray/finetune/finetune.py @@ -0,0 +1,459 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# Copyright 2023 The LLM-on-Ray Authors. 
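+#
+# Ray training entrypoint: loads a FinetuneConfig, prepares the tokenizer,
+# dataset and (optionally LoRA-wrapped) model, then runs the training loop
+# through a Ray TorchTrainer on CPU, GPU or Gaudi (HPU) workers.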
+ +#!/usr/bin/env python + +import argparse +import copy +import os +import re +import sys +from itertools import chain +from typing import Any, Dict, Optional, Union + +import datasets +import ray +import torch +import transformers +from peft import LoraConfig, get_peft_model +from pydantic_yaml import parse_yaml_raw_as +from ray.air import FailureConfig, RunConfig +from ray.air.config import ScalingConfig +from ray.train.torch import TorchTrainer + +from comps.finetuning.llm_on_ray import common +from comps.finetuning.llm_on_ray.finetune.data_process import DataProcessor +from comps.finetuning.llm_on_ray.finetune.finetune_config import FinetuneConfig + + +def adapt_transformers_to_device(config: Dict): + device = config["Training"]["device"] + if device in ["hpu"]: + from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi + + # adapt transformers to gaudi + adapt_transformers_to_gaudi() + + +def set_seed(config: Dict): + seed = config["Training"].get("seed", None) + if seed is None: + return + device = config["Training"]["device"] + if device in ["cpu", "gpu"]: + from accelerate.utils import set_seed as _set_seed + + _set_seed(seed) + elif device in ["hpu"]: + from optimum.habana.utils import set_seed as _set_seed + + _set_seed(seed) + + +def convert_to_training_args(cls, config: Dict): + device = config["Training"]["device"] + accelerate_mode = config["Training"]["accelerate_mode"] + save_strategy = config["General"]["save_strategy"] + + args = { + "output_dir": config["General"]["output_dir"], + "report_to": config["General"]["report_to"], + "resume_from_checkpoint": config["General"]["resume_from_checkpoint"], + "gradient_checkpointing": config["General"]["enable_gradient_checkpointing"], + "save_strategy": save_strategy if save_strategy != "False" else "no", + "bf16": config["Training"]["mixed_precision"] == "bf16", + "num_train_epochs": config["Training"]["epochs"], + "per_device_train_batch_size": config["Training"]["batch_size"], + "per_device_eval_batch_size": config["Training"]["batch_size"], + "optim": config["Training"]["optimizer"], + "learning_rate": config["Training"]["learning_rate"], + "logging_steps": config["Training"]["logging_steps"], + "lr_scheduler_type": config["Training"]["lr_scheduler"], + "weight_decay": config["Training"]["weight_decay"], + "gradient_accumulation_steps": config["Training"]["gradient_accumulation_steps"], + "do_train": True, + } + + # set attr do_eval + vf = config["Dataset"].get("validation_file", None) + vsp = config["Dataset"].get("validation_split_percentage", 0) + if vf is not None or (vsp / 100 > 0.0 and vsp / 100 < 1.0): + args.update({"do_eval": True}) + + # set attr max_steps + if config["Training"]["max_train_steps"] is not None: + args.update({"max_steps": config["Training"]["max_train_steps"]}) + + # set attr for device cpu + if device == "cpu": + if hasattr(cls, "use_cpu"): + args.update({"use_cpu": True}) + if hasattr(cls, "no_cuda"): + args.update({"no_cuda": True}) + args.update({"use_ipex": True}) + + # set attr 'deepspeed' + if accelerate_mode == "DEEPSPEED": + args.update({"deepspeed": config["Training"]["deepspeed_config_file"]}) + + # set attr for FSDP + # if accelerate_mode == "FSDP": + # args.updatwe({}) + + # set attr for Intel Gaudi + if device == "hpu": + args.update({"use_habana": True}) + args.update({"use_lazy_mode": config["Training"]["hpu_execution_mode"] == "lazy"}) + args.update({"pipelining_fwd_bwd": True}) + + return cls(**args) + + +def convert_dtype(dtype: str) -> 
Optional[torch.dtype]: + supported_dtypes = { + "fp16": torch.float16, + "bf16": torch.bfloat16, + "no": None, + } + return supported_dtypes[dtype] + + +def load_tokenizer(config: Dict): + if config["General"].get("tokenizer_name") is not None: + tokenizer_name = config["General"].get("tokenizer_name") + else: + tokenizer_name = config["General"]["base_model"] + load_config = config["General"].get("config", {}) + # default padding side is right + padding_side = config["Dataset"].get("padding_side", "right") + # default truncation side is right + truncation_side = config["Dataset"].get("truncation_side", "right") + tokenizer = transformers.AutoTokenizer.from_pretrained( + tokenizer_name, padding_side=padding_side, truncation_side=truncation_side, **load_config + ) + return tokenizer + + +def load_dataset(config: Dict): + dataset_file = config["Dataset"].get("train_file", None) + if dataset_file is None: + return + + if os.path.exists(dataset_file): + # load from local file + def local_load(name, **load_config): + if os.path.isfile(name): + file = os.path.basename(os.path.abspath(name)) + path = os.path.dirname(os.path.abspath(name)) + dataset = datasets.load_dataset(path, data_files=file, **load_config) + else: + dataset = datasets.load_dataset(name, **load_config) + return dataset["train"] + + train_dataset = local_load(dataset_file) + validation_file = config["Dataset"].get("validation_file", None) + if validation_file is not None: + validation_dataset = local_load(validation_file) + return datasets.DatasetDict({"train": train_dataset, "validation": validation_dataset}) + + validation_split_percentage = config["Dataset"].get("validation_split_percentage", 0) + if validation_split_percentage / 100 > 0.0 and validation_split_percentage / 100 < 1.0: + dataset_dict = train_dataset.train_test_split(test_size=validation_split_percentage / 100) + dataset_dict["validation"] = dataset_dict["test"] + return dataset_dict + + return datasets.DatasetDict({"train": train_dataset}) + else: + # try to download and load dataset from huggingface.co + load_config = config["General"].get("config", {}) + use_auth_token = load_config.get("use_auth_token", None) + raw_dataset = datasets.load_dataset(dataset_file, use_auth_token=use_auth_token) + + validation_split_percentage = config["Dataset"].get("validation_split_percentage", 0) + if "validation" not in raw_dataset.keys() and ( + validation_split_percentage / 100 > 0.0 and validation_split_percentage / 100 < 1.0 + ): + dataset_dict = raw_dataset["train"].train_test_split(test_size=validation_split_percentage / 100) + dataset_dict["validation"] = dataset_dict["test"] + return dataset_dict + + return raw_dataset + + +def tokenize_dataset(config: Dict, tokenizer, dataset): + group = config["Dataset"].get("group", True) + block_size = config["Dataset"].get("block_size", 512) + tokenizer.pad_token = tokenizer.eos_token + + processor = DataProcessor(config, tokenizer) + + for key in dataset: + prompts = processor.make_prompt(dataset[key]) + dataset[key] = datasets.Dataset.from_dict(prompts) + + column_names = list(dataset["train"].features) + tokenize_fn = ( + processor.tokenize_by_neural_chat + if config["Dataset"].get("data_preprocess_type", "") == "neural_chat" + else processor.tokenize + ) + + tokenized_dataset = dataset.map( + tokenize_fn, + remove_columns=column_names, + batched=True, + load_from_cache_file=False, + desc="Tokenize dataset", + ) + + if group: + + def group_texts(examples): + # Concatenate all texts. 
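+            # (Grouping packs the tokenized samples into fixed-length block_size chunks,
+            # which keeps causal-LM training batches dense and uniform.)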
+ concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can + # customize this part to your needs. + if total_length >= block_size: + total_length = (total_length // block_size) * block_size + # Split by chunks of max_len. + result = { + k: [t[i : i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() + } + return result + + tokenized_dataset = tokenized_dataset.map( + group_texts, + batched=True, + load_from_cache_file=False, + desc=f"Grouping texts in chunks of {block_size}", + ) + + return tokenized_dataset + + +def prepare_data_collator(config: Dict, tokenizer): + return transformers.DataCollatorForLanguageModeling( + tokenizer=tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8 + ) + + +def load_model(config: Dict): + model_name = config["General"]["base_model"] + model_dtype = convert_dtype(config["Training"].get("mixed_precision", "no")) + model_config = config["General"].get("config", {}) + model = transformers.AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype, **model_config) + + lora_config = config["General"].get("lora_config", None) + if lora_config: + peft_config = LoraConfig(**lora_config) + model = get_peft_model(model, peft_config) + + egc = config["General"].get("enable_gradient_checkpointing", False) + if egc: + model.enable_input_require_grads() + model.gradient_checkpointing_enable() + model.config.use_cache = False + + model.to(dtype=model_dtype, device=torch.device(config["Training"]["device"])) + + return model + + +def get_trainer(config: Dict, model, tokenizer, tokenized_dataset, data_collator): + device = config["Training"]["device"] + if device in ["cpu", "gpu"]: + from transformers import Trainer, TrainingArguments + + training_args = convert_to_training_args(TrainingArguments, config) + trainer = Trainer( + model=model, + args=training_args, + train_dataset=tokenized_dataset["train"], + eval_dataset=tokenized_dataset["validation"] if tokenized_dataset.get("validation") is not None else None, + tokenizer=tokenizer, + data_collator=data_collator, + ) + return training_args, trainer + elif device in ["hpu"]: + from optimum.habana import GaudiConfig + from optimum.habana.transformers import GaudiTrainer, GaudiTrainingArguments + + # If gaudi_config_name is provided, load gaudi_config from huggingface model hub(https://huggingface.co/Habana), otherwise use default gaudi_config + gaudi_config_name = config["General"].get("gaudi_config_name", None) + if gaudi_config_name is not None: + gaudi_config = GaudiConfig.from_pretrained(gaudi_config_name) + else: + gaudi_config = GaudiConfig() + gaudi_config.use_fused_adam = True + gaudi_config.use_fused_clip_norm = True + + training_args = convert_to_training_args(GaudiTrainingArguments, config) + trainer = GaudiTrainer( + model=model, + args=training_args, + gaudi_config=gaudi_config, + train_dataset=tokenized_dataset["train"], + eval_dataset=tokenized_dataset["validation"] if tokenized_dataset.get("validation") is not None else None, + tokenizer=tokenizer, + data_collator=data_collator, + ) + return training_args, trainer + return None + + +def train_func(config: Dict[str, Any]): + os.chdir(config["cwd"]) + + adapt_transformers_to_device(config) + + set_seed(config) + + tokenizer = load_tokenizer(config) + + dataset = load_dataset(config) + + 
max_train_samples = config["Dataset"].get("max_train_samples", 0) + if 0 < max_train_samples < len(dataset["train"]): + dataset["train"] = dataset["train"].select(range(max_train_samples)) + + max_eval_samples = config["Dataset"].get("max_eval_samples", 0) + if "validation" in dataset and 0 < max_eval_samples < len(dataset["validation"]): + dataset["validation"] = dataset["validation"].select(range(max_eval_samples)) + + tokenized_dataset = tokenize_dataset(config, tokenizer, dataset) + + data_collator = prepare_data_collator(config, tokenizer) + + model = load_model(config) + + training_args, trainer = get_trainer(config, model, tokenizer, tokenized_dataset, data_collator) + + common.logger.info("train start") + trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) + trainer.save_model() + common.logger.info("train finish") + + +def get_finetune_config(): + parser = argparse.ArgumentParser(description="Finetune a transformers model on a causal language modeling task") + parser.add_argument( + "--config_file", + type=str, + required=True, + default=None, + help="The name of the dataset to use (via the datasets library).", + ) + + # Print help if no arguments were provided + if len(sys.argv) == 1: + parser.print_help(sys.stderr) + sys.exit(1) + + args = parser.parse_args() + config_file = args.config_file + + with open(config_file) as f: + finetune_config = parse_yaml_raw_as(FinetuneConfig, f) + return finetune_config.dict() + + +def main(external_config=None): + if not external_config: + config = get_finetune_config() + else: + config = external_config + + config["cwd"] = os.getcwd() + + num_training_workers = config["Training"].get("num_training_workers") + resources_per_worker = config["Training"].get("resources_per_worker") + + if num_training_workers > 1 and config["Training"].get("accelerate_mode", None) is None: + config["Training"]["accelerate_mode"] = "DDP" # will use DDP to accelerate if no method specified + + ccl_worker_count = 1 + device = config["Training"]["device"] + if device != "cpu": + ccl_worker_count = num_training_workers + + if not ray.is_initialized(): + runtime_env = { + "env_vars": { + "OMP_NUM_THREADS": str(resources_per_worker["CPU"]), + "CCL_ZE_IPC_EXCHANGE": "sockets", + "CCL_WORKER_COUNT": str(ccl_worker_count), + "CCL_LOG_LEVEL": "info", + "FI_TCP_IFACE": "lo", + "FI_PROVIDER": "tcp", + } + } + + if config["General"]["gpt_base_model"] is True: + runtime_env["pip"] = ["transformers==4.26.0"] + + if device == "gpu": + num_cpus = resources_per_worker["CPU"] * num_training_workers + 1 # additional 1 for head worker + ray.init(num_cpus=num_cpus, runtime_env=runtime_env) + else: + ray.init(runtime_env=runtime_env) + + common.logger.info(f"ray available resources = {ray.available_resources()}") + use_gpu = True if device == "gpu" else False + scaling_config = ScalingConfig( + num_workers=num_training_workers, + use_gpu=use_gpu, + resources_per_worker=resources_per_worker, + placement_strategy="SPREAD", + ) + + # if try to use Intel GPU, convert device to 'xpu' + # due to accelerate internal use 'xpu' represent Intel GPU + if device == "gpu": + from accelerate.utils import is_xpu_available + + if is_xpu_available(): + device = "xpu" + + if config.get("torch_config", None) is None: + backend = None + if device == "cpu" or device == "xpu" or device == "gpu": + backend = "ccl" + elif device == "hpu": + backend = "hccl" + torch_config = common.TorchConfig(backend=backend, device=device) + else: + customer_torch_config = config.get("torch_config") + 
torch_config = common.TorchConfig(**customer_torch_config, device=device) + + if config.get("failure_config", None) is None: + failure_config = FailureConfig() + else: + customer_failure_config = config.get("failure_config") + failure_config = FailureConfig(**customer_failure_config) + + if config.get("run_config", None) is None: + run_config = RunConfig(failure_config=failure_config) + else: + customer_run_config = config.get("run_config") + if customer_run_config.get("failure_config", None) is None: + customer_run_config["failure_config"] = failure_config + run_config = RunConfig(**customer_run_config) + + trainer = TorchTrainer( + train_func, + train_loop_config=config, + scaling_config=scaling_config, + torch_config=torch_config, + run_config=run_config, + ) + results = trainer.fit() + if external_config is not None: + return results + + +if __name__ == "__main__": + main() diff --git a/comps/finetuning/llm_on_ray/finetune/finetune_config.py b/comps/finetuning/llm_on_ray/finetune/finetune_config.py new file mode 100644 index 0000000000..391c6e6c89 --- /dev/null +++ b/comps/finetuning/llm_on_ray/finetune/finetune_config.py @@ -0,0 +1,156 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# Copyright 2023 The LLM-on-Ray Authors. + +from typing import List, Optional, Union + +from pydantic import BaseModel, validator + +PRECISION_BF16 = "bf16" +PRECISION_FP16 = "fp16" +PRECISION_NO = "no" + +DEVICE_CPU = "cpu" +DEVICE_HPU = "hpu" +DEVICE_GPU = "gpu" + +ACCELERATE_STRATEGY_DDP = "DDP" +ACCELERATE_STRATEGY_FSDP = "FSDP" +ACCELERATE_STRATEGY_DEEPSPEED = "DEEPSPEED" + + +class GeneralConfig(BaseModel): + trust_remote_code: bool + use_auth_token: Optional[str] + + +class LoraConfig(BaseModel): + task_type: str + r: int + lora_alpha: int + lora_dropout: float + target_modules: Optional[List[str]] = None + + +class General(BaseModel): + base_model: str + tokenizer_name: Optional[str] = None + gaudi_config_name: Optional[str] = None + gpt_base_model: bool + output_dir: str + report_to: str = "none" + resume_from_checkpoint: Optional[str] = None + save_strategy: str = "no" + config: GeneralConfig + lora_config: Optional[LoraConfig] = None + enable_gradient_checkpointing: bool = False + + @validator("report_to") + def check_report_to(cls, v: str): + assert v in ["none", "tensorboard"] + return v + + +class Dataset(BaseModel): + train_file: str + validation_file: Optional[str] + validation_split_percentage: int + max_length: int = 512 + group: bool = True + block_size: int = 512 + shuffle: bool = False + max_source_length: int = 384 + padding_side: str = "right" + truncation_side: str = "right" + max_seq_length: int = 512 + truncation: bool = True + padding: bool = True + mask_input: bool = True + mask_response: bool = True + data_preprocess_type: str = "neural_chat" + max_train_samples: int = 0 + max_eval_samples: int = 0 + + +class RayResourceConfig(BaseModel): + CPU: int + GPU: int = 0 + HPU: int = 0 + + +class Training(BaseModel): + optimizer: str + batch_size: int + epochs: int + max_train_steps: Optional[int] = None + learning_rate: float + lr_scheduler: str + weight_decay: float + device: str = DEVICE_CPU + hpu_execution_mode: str = "lazy" + num_training_workers: int + resources_per_worker: RayResourceConfig + accelerate_mode: str = ACCELERATE_STRATEGY_DDP + mixed_precision: str = PRECISION_NO + gradient_accumulation_steps: int = 1 + logging_steps: int = 10 + deepspeed_config_file: str = "" + + @validator("device") + def check_device(cls, v: str): + # will 
convert to lower case + if v: + assert v.lower() in [DEVICE_CPU, DEVICE_GPU, DEVICE_HPU] + return v.lower() + + @validator("hpu_execution_mode") + def check_hpu_execution_mode(cls, v: str): + if v: + assert v in ["lazy", "eager", "eager.compile"] + return v + + @validator("accelerate_mode") + def check_accelerate_mode(cls, v: str): + if v: + assert v in [ + ACCELERATE_STRATEGY_DDP, + ACCELERATE_STRATEGY_FSDP, + ACCELERATE_STRATEGY_DEEPSPEED, + ] + return v + + @validator("mixed_precision") + def check_mixed_precision(cls, v: str): + if v: + assert v in [PRECISION_BF16, PRECISION_FP16, PRECISION_NO] + return v + + @validator("logging_steps") + def check_logging_steps(cls, v: int): + assert v > 0 + return v + + # @model_validator(mode='after') + # def check_device_and_accelerate_mode(self) -> "Training": + # dev = self.device + # res = self.resources_per_worker + # mode = self.accelerate_mode + # if dev == "CPU": + # if res.GPU is not None and res.GPU > 0: + # raise ValueError("Please not specified GPU resource when use CPU only in Ray.") + # if mode != "CPU_DDP": + # raise ValueError("Please specified CPU related accelerate mode when use CPU only in Ray.") + # elif dev == "GPU": + # if res.GPU is None or res.GPU == 0: + # raise ValueError("Please specified GPU resource when use GPU to fine tune in Ray.") + # if mode not in ["GPU_DDP", "GPU_FSDP"]: + # raise ValueError("Please speicifed GPU related accelerate mode when use GPU to fine tune in Ray.") + + # return self + + +class FinetuneConfig(BaseModel): + General: General + Dataset: Dataset + Training: Training diff --git a/comps/finetuning/models/llama-2-7b-chat-hf.yaml b/comps/finetuning/models/llama-2-7b-chat-hf.yaml new file mode 100644 index 0000000000..d6ae5f34d7 --- /dev/null +++ b/comps/finetuning/models/llama-2-7b-chat-hf.yaml @@ -0,0 +1,39 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +General: + base_model: meta-llama/Llama-2-7b-chat-hf + output_dir: "./tmp" + gpt_base_model: false + save_strategy: no + config: + trust_remote_code: false + use_auth_token: null + lora_config: + task_type: CAUSAL_LM + r: 8 + lora_alpha: 32 + lora_dropout: 0.1 + target_modules: + - q_proj + - v_proj + enable_gradient_checkpointing: false +Dataset: + train_file: examples/data/sample_finetune_data_small.jsonl + group: false + validation_file: null + validation_split_percentage: 5 +Training: + optimizer: adamw_torch + batch_size: 2 + epochs: 3 + learning_rate: 1.0e-05 + lr_scheduler: linear + weight_decay: 0.0 + mixed_precision: bf16 + device: cpu + num_training_workers: 1 + resources_per_worker: + CPU: 32 + gradient_accumulation_steps: 1 + logging_steps: 10 diff --git a/comps/finetuning/models/mistral-7b-v0.1.yaml b/comps/finetuning/models/mistral-7b-v0.1.yaml new file mode 100644 index 0000000000..4334fa37ea --- /dev/null +++ b/comps/finetuning/models/mistral-7b-v0.1.yaml @@ -0,0 +1,45 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +General: + base_model: mistralai/Mistral-7B-v0.1 + output_dir: "./tmp" + gpt_base_model: false + save_strategy: no + config: + trust_remote_code: false + use_auth_token: null + lora_config: + task_type: CAUSAL_LM + r: 8 + lora_alpha: 32 + lora_dropout: 0.1 + target_modules: + - q_proj + - k_proj + - v_proj + - o_proj + - gate_proj + - up_proj + - down_proj + - lm_head + enable_gradient_checkpointing: false +Dataset: + train_file: examples/data/sample_finetune_data_small.jsonl + validation_file: null + validation_split_percentage: 5 +Training: + 
optimizer: adamw_torch + batch_size: 2 + epochs: 3 + learning_rate: 1.0e-05 + lr_scheduler: linear + weight_decay: 0.0 + mixed_precision: bf16 + device: cpu + num_training_workers: 2 + resources_per_worker: + CPU: 32 + accelerate_mode: DDP + gradient_accumulation_steps: 1 + logging_steps: 10 diff --git a/comps/finetuning/requirements.txt b/comps/finetuning/requirements.txt new file mode 100644 index 0000000000..4255a37165 --- /dev/null +++ b/comps/finetuning/requirements.txt @@ -0,0 +1,19 @@ +aiohttp +datasets +docarray +fastapi +httpx +opentelemetry-api +opentelemetry-exporter-otlp +opentelemetry-sdk +peft +prometheus-fastapi-instrumentator +pydantic==2.8.2 +pydantic_yaml +python-multipart +pyyaml +ray[all] +requests +shortuuid +transformers +uvicorn