diff --git a/comps/cores/mega/micro_service.py b/comps/cores/mega/micro_service.py
index e1276716c..689fff9dd 100644
--- a/comps/cores/mega/micro_service.py
+++ b/comps/cores/mega/micro_service.py
@@ -3,7 +3,7 @@
 
 import asyncio
 import multiprocessing
-from typing import Any, Optional, Type
+from typing import Any, List, Optional, Type
 
 from ..proto.docarray import TextDoc
 from .constants import ServiceRoleType, ServiceType
@@ -154,6 +154,7 @@ def register_microservice(
     output_datatype: Type[Any] = TextDoc,
     provider: Optional[str] = None,
     provider_endpoint: Optional[str] = None,
+    methods: Optional[List[str]] = None,  # HTTP methods for the route; None means ["POST"]
 ):
     def decorator(func):
         if name not in opea_microservices:
@@ -173,7 +174,9 @@ def decorator(func):
                 provider_endpoint=provider_endpoint,
             )
             opea_microservices[name] = micro_service
-        opea_microservices[name].app.router.add_api_route(endpoint, func, methods=["POST"])
+        # Build the default here so no list object is shared between registrations.
+        route_methods = ["POST"] if methods is None else methods
+        opea_microservices[name].app.router.add_api_route(endpoint, func, methods=route_methods)
         return func
 
     return decorator
diff --git a/comps/cores/proto/api_protocol.py b/comps/cores/proto/api_protocol.py
index 382982d27..1a1901d5d 100644
--- a/comps/cores/proto/api_protocol.py
+++ b/comps/cores/proto/api_protocol.py
@@ -463,3 +463,226 @@ def check_requests(request) -> Optional[JSONResponse]:
         )
 
     return None
+
+
+class Hyperparameters(BaseModel):
+    batch_size: Optional[Union[Literal["auto"], int]] = "auto"
+    """Number of examples in each batch.
+
+    A larger batch size means that model parameters are updated less frequently, but with lower variance.
+    """
+
+    learning_rate_multiplier: Optional[Union[Literal["auto"], float]] = "auto"
+    """Scaling factor for the learning rate.
+
+    A smaller learning rate may be useful to avoid overfitting.
+    """
+
+    n_epochs: Optional[Union[Literal["auto"], int]] = "auto"
+    """The number of epochs to train the model for.
+
+    An epoch refers to one full cycle through the training dataset. "auto" decides
+    the optimal number of epochs based on the size of the dataset. If setting the
+    number manually, we support any number between 1 and 50 epochs.
+    """
+
+
+class FineTuningJobWandbIntegration(BaseModel):
+    project: str
+    """The name of the project that the new run will be created under."""
+
+    entity: Optional[str] = None
+    """The entity to use for the run.
+
+    This allows you to set the team or username of the WandB user that you would
+    like associated with the run. If not set, the default entity for the registered
+    WandB API key is used.
+    """
+
+    name: Optional[str] = None
+    """A display name to set for the run.
+
+    If not set, we will use the Job ID as the name.
+    """
+
+    tags: Optional[List[str]] = None
+    """A list of tags to be attached to the newly created run.
+
+    These tags are passed through directly to WandB. Some default tags are generated
+    by OpenAI: "openai/finetune", "openai/{base-model}", "openai/{ftjob-abcdef}".
+    """
+
+
+class FineTuningJobWandbIntegrationObject(BaseModel):
+    type: Literal["wandb"]
+    """The type of the integration being enabled for the fine-tuning job."""
+
+    wandb: FineTuningJobWandbIntegration
+    """The settings for your integration with Weights and Biases.
+
+    This payload specifies the project that metrics will be sent to. Optionally, you
+    can set an explicit display name for your run, add tags to your run, and set a
+    default entity (team, username, etc) to be associated with your run.
+    """
+
+
+class FineTuningJobsRequest(BaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/fine-tuning/create
+    model: str
+    """The name of the model to fine-tune."""
+
+    training_file: str
+    """The ID of an uploaded file that contains training data."""
+
+    hyperparameters: Optional[Hyperparameters] = None
+    """The hyperparameters used for the fine-tuning job."""
+
+    suffix: Optional[str] = None
+    """A string of up to 64 characters that will be added to your fine-tuned model name."""
+
+    validation_file: Optional[str] = None
+    """The ID of an uploaded file that contains validation data."""
+
+    integrations: Optional[List[FineTuningJobWandbIntegrationObject]] = None
+    """A list of integrations to enable for your fine-tuning job."""
+
+    seed: Optional[int] = None
+    """The seed for the fine-tuning job; an integer, matching `FineTuningJob.seed`."""
+
+
+class Error(BaseModel):
+    code: str
+    """A machine-readable error code."""
+
+    message: str
+    """A human-readable error message."""
+
+    param: Optional[str] = None
+    """The parameter that was invalid, usually `training_file` or `validation_file`.
+
+    This field will be null if the failure was not parameter-specific.
+    """
+
+
+class FineTuningJob(BaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/fine-tuning/object
+    id: str
+    """The object identifier, which can be referenced in the API endpoints."""
+
+    created_at: int
+    """The Unix timestamp (in seconds) for when the fine-tuning job was created."""
+
+    error: Optional[Error] = None
+    """For fine-tuning jobs that have `failed`, this will contain more information on
+    the cause of the failure."""
+
+    fine_tuned_model: Optional[str] = None
+    """The name of the fine-tuned model that is being created.
+
+    The value will be null if the fine-tuning job is still running.
+    """
+
+    finished_at: Optional[int] = None
+    """The Unix timestamp (in seconds) for when the fine-tuning job was finished.
+
+    The value will be null if the fine-tuning job is still running.
+    """
+
+    hyperparameters: Hyperparameters
+    """The hyperparameters used for the fine-tuning job.
+
+    See the [fine-tuning guide](https://platform.openai.com/docs/guides/fine-tuning)
+    for more details.
+    """
+
+    model: str
+    """The base model that is being fine-tuned."""
+
+    object: Literal["fine_tuning.job"] = "fine_tuning.job"
+    """The object type, which is always "fine_tuning.job"."""
+
+    organization_id: Optional[str] = None
+    """The organization that owns the fine-tuning job."""
+
+    result_files: Optional[List[str]] = None
+    """The compiled results file ID(s) for the fine-tuning job.
+
+    You can retrieve the results with the
+    [Files API](https://platform.openai.com/docs/api-reference/files/retrieve-contents).
+    """
+
+    status: Literal["validating_files", "queued", "running", "succeeded", "failed", "cancelled"]
+    """The current status of the fine-tuning job, which can be either
+    `validating_files`, `queued`, `running`, `succeeded`, `failed`, or `cancelled`."""
+
+    trained_tokens: Optional[int] = None
+    """The total number of billable tokens processed by this fine-tuning job.
+
+    The value will be null if the fine-tuning job is still running.
+    """
+
+    training_file: str
+    """The file ID used for training.
+
+    You can retrieve the training data with the
+    [Files API](https://platform.openai.com/docs/api-reference/files/retrieve-contents).
+    """
+
+    validation_file: Optional[str] = None
+    """The file ID used for validation.
+
+    You can retrieve the validation results with the
+    [Files API](https://platform.openai.com/docs/api-reference/files/retrieve-contents).
+    """
+
+    integrations: Optional[List[FineTuningJobWandbIntegrationObject]] = None
+    """A list of integrations to enable for this fine-tuning job."""
+
+    seed: Optional[int] = None
+    """The seed used for the fine-tuning job."""
+
+    estimated_finish: Optional[int] = None
+    """The Unix timestamp (in seconds) for when the fine-tuning job is estimated to
+    finish.
+
+    The value will be null if the fine-tuning job is not running.
+    """
+
+
+class FineTuningJobIDRequest(BaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/fine-tuning/retrieve
+    # https://platform.openai.com/docs/api-reference/fine-tuning/cancel
+    fine_tuning_job_id: str
+    """The ID of the fine-tuning job."""
+
+
+class FineTuningJobListRequest(BaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/fine-tuning/list
+    after: Optional[str] = None
+    """Identifier for the last job from the previous pagination request."""
+
+    limit: Optional[int] = 20
+    """Number of fine-tuning jobs to retrieve."""
+
+
+class FineTuningJobList(BaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/fine-tuning/list
+    object: str = "list"
+    """The object type, which is always "list".
+
+    This indicates that the returned data is a list of fine-tuning jobs.
+    """
+
+    data: List[FineTuningJob]
+    """A list containing FineTuningJob objects."""
+
+    has_more: bool
+    """Indicates whether there are more fine-tuning jobs beyond the current list.
+
+    If true, additional requests can be made to retrieve more jobs.
+    """
diff --git a/comps/finetuning/envs.py b/comps/finetuning/envs.py
new file mode 100644
index 000000000..0d4d41cd5
--- /dev/null
+++ b/comps/finetuning/envs.py
@@ -0,0 +1,21 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+
+# Supported base models and the YAML config file used for each.
+MODEL_CONFIG_FILE_MAP = {
+    "meta-llama/Llama-2-7b-chat-hf": "./models/llama-2-7b-chat-hf.yaml",
+    "mistralai/Mistral-7B-v0.1": "./models/mistral-7b-v0.1.yaml",
+}
+
+DATASET_BASE_PATH = "datasets"
+JOBS_PATH = "jobs"
+# os.path has no mkdir(); makedirs(exist_ok=True) creates the directory only when it is missing.
+os.makedirs(DATASET_BASE_PATH, exist_ok=True)
+os.makedirs(JOBS_PATH, exist_ok=True)
+
+CHECK_JOB_STATUS_INTERVAL = 5  # Check every 5 secs
+
+# Placeholder for the Ray job-submission client; None until one is created.
+ray_client = None
diff --git a/comps/finetuning/finetune_config.py b/comps/finetuning/finetune_config.py
new file mode 100644
index 000000000..c53b36131
--- /dev/null
+++ b/comps/finetuning/finetune_config.py
@@ -0,0 +1,181 @@
+# Copyright 2023 The LLM-on-Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import List, Optional
+
+from pydantic import BaseModel, validator
+
+PRECISION_BF16 = "bf16"
+PRECISION_FP16 = "fp16"
+PRECISION_NO = "no"
+
+DEVICE_CPU = "cpu"
+DEVICE_HPU = "hpu"
+DEVICE_GPU = "gpu"
+
+ACCELERATE_STRATEGY_DDP = "DDP"
+ACCELERATE_STRATEGY_FSDP = "FSDP"
+ACCELERATE_STRATEGY_DEEPSPEED = "DEEPSPEED"
+
+
+def _check_choice(field: str, value: str, allowed: List[str]) -> str:
+    """Return ``value`` if it is empty or listed in ``allowed``; raise ValueError otherwise.
+
+    A ValueError is raised instead of using ``assert`` so the check still runs under ``python -O``.
+    """
+    if value and value not in allowed:
+        raise ValueError(f"{field} must be one of {allowed}, got {value!r}")
+    return value
+
+
+class GeneralConfig(BaseModel):
+    """Fields of the ``General.config`` section."""
+
+    trust_remote_code: bool
+    use_auth_token: Optional[str] = None
+
+
+class LoraConfig(BaseModel):
+    """Fields of the optional ``General.lora_config`` section."""
+
+    task_type: str
+    r: int
+    lora_alpha: int
+    lora_dropout: float
+    target_modules: Optional[List[str]] = None
+
+
+class DeltatunerConfig(BaseModel):
+    """Fields of the optional ``General.deltatuner_config`` section."""
+
+    algo: str
+    denas: bool
+    best_model_structure: str
+
+
+class General(BaseModel):
+    """The top-level ``General`` section of a fine-tuning config."""
+
+    base_model: str
+    tokenizer_name: Optional[str] = None
+    gaudi_config_name: Optional[str] = None
+    gpt_base_model: bool
+    output_dir: str
+    resume_from_checkpoint: Optional[str] = None
+    save_strategy: str = "no"
+    config: GeneralConfig
+    lora_config: Optional[LoraConfig] = None
+    deltatuner_config: Optional[DeltatunerConfig] = None
+    enable_gradient_checkpointing: bool = False
+
+
+class Dataset(BaseModel):
+    """The top-level ``Dataset`` section of a fine-tuning config."""
+
+    train_file: str
+    validation_file: Optional[str] = None
+    validation_split_percentage: int
+    max_length: int = 512
+    group: bool = True
+    block_size: int = 512
+    shuffle: bool = False
+
+
+class RayResourceConfig(BaseModel):
+    """Resource counts given in ``Training.resources_per_worker``."""
+
+    CPU: int
+    GPU: int = 0
+    HPU: int = 0
+
+
+class Training(BaseModel):
+    """The top-level ``Training`` section; validators restrict the choice-valued fields."""
+
+    optimizer: str
+    batch_size: int
+    epochs: int
+    max_train_steps: Optional[int] = None
+    learning_rate: float
+    lr_scheduler: str
+    weight_decay: float
+    device: str = DEVICE_CPU
+    hpu_execution_mode: str = "lazy"
+    num_training_workers: int
+    resources_per_worker: RayResourceConfig
+    accelerate_mode: str = ACCELERATE_STRATEGY_DDP
+    mixed_precision: str = PRECISION_NO
+    gradient_accumulation_steps: int = 1
+    logging_steps: int = 10
+    deepspeed_config_file: str = ""
+
+    @validator("device")
+    def check_device(cls, v: str):
+        """Lower-case the device name and require cpu, gpu or hpu (an empty value passes through)."""
+        return _check_choice("device", v.lower(), [DEVICE_CPU, DEVICE_GPU, DEVICE_HPU])
+
+    @validator("hpu_execution_mode")
+    def check_hpu_execution_mode(cls, v: str):
+        """Require lazy, eager or eager.compile (an empty value passes through)."""
+        return _check_choice("hpu_execution_mode", v, ["lazy", "eager", "eager.compile"])
+
+    @validator("accelerate_mode")
+    def check_accelerate_mode(cls, v: str):
+        """Require DDP, FSDP or DEEPSPEED (an empty value passes through)."""
+        return _check_choice(
+            "accelerate_mode",
+            v,
+            [ACCELERATE_STRATEGY_DDP, ACCELERATE_STRATEGY_FSDP, ACCELERATE_STRATEGY_DEEPSPEED],
+        )
+
+    @validator("mixed_precision")
+    def check_mixed_precision(cls, v: str):
+        """Require bf16, fp16 or no (an empty value passes through)."""
+        return _check_choice("mixed_precision", v, [PRECISION_BF16, PRECISION_FP16, PRECISION_NO])
+
+    @validator("logging_steps")
+    def check_logging_steps(cls, v: int):
+        """Require a positive number of logging steps."""
+        if v <= 0:
+            raise ValueError(f"logging_steps must be greater than 0, got {v}")
+        return v
+
+    # TODO: draft cross-field check kept for reference. It is written for pydantic v2
+    # (model_validator) and compares upper-case device names, while check_device lower-cases them.
+    # @model_validator(mode='after')
+    # def check_device_and_accelerate_mode(self) -> "Training":
+    #     dev = self.device
+    #     res = self.resources_per_worker
+    #     mode = self.accelerate_mode
+    #     if dev == "CPU":
+    #         if res.GPU is not None and res.GPU > 0:
+    #             raise ValueError("Please not specified GPU resource when use CPU only in Ray.")
+    #         if mode != "CPU_DDP":
+    #             raise ValueError("Please specified CPU related accelerate mode when use CPU only in Ray.")
+    #     elif dev == "GPU":
+    #         if res.GPU is None or res.GPU == 0:
+    #             raise ValueError("Please specified GPU resource when use GPU to fine tune in Ray.")
+    #         if mode not in ["GPU_DDP", "GPU_FSDP"]:
+    #             raise ValueError("Please speicifed GPU related accelerate mode when use GPU to fine tune in Ray.")
+
+    #     return self
+
+
+class FinetuneConfig(BaseModel):
+    """A complete fine-tuning configuration: General, Dataset and Training sections."""
+
+    General: General
+    Dataset: Dataset
+    Training: Training
diff --git a/comps/finetuning/finetuning.py b/comps/finetuning/finetuning.py
new file mode 100644
index 000000000..8e79e5642
--- /dev/null
+++ b/comps/finetuning/finetuning.py
@@ -0,0 +1,51 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from handlers import (
+    handle_cancel_finetuning_job,
+    handle_create_finetuning_jobs,
+    handle_list_finetuning_jobs,
+    handle_retrieve_finetuning_job,
+)
+
+from comps import 
opea_microservices, register_microservice
+from comps.cores.proto.api_protocol import FineTuningJobIDRequest, FineTuningJobListRequest, FineTuningJobsRequest
+
+
+@register_microservice(name="opea_service@finetuning", endpoint="/v1/fine_tuning/jobs", host="0.0.0.0", port=8001)
+def create_finetuning_jobs(request: FineTuningJobsRequest):
+    return handle_create_finetuning_jobs(request)
+
+
+@register_microservice(
+    name="opea_service@finetuning", endpoint="/v1/fine_tuning/jobs", host="0.0.0.0", port=8001, methods=["GET"]
+)
+def list_finetuning_jobs(request: FineTuningJobListRequest):
+    return handle_list_finetuning_jobs(request)
+
+
+@register_microservice(
+    name="opea_service@finetuning",
+    endpoint="/v1/fine_tuning/jobs/{fine_tuning_job_id}",
+    host="0.0.0.0",
+    port=8001,
+    methods=["GET"],
+)
+def retrieve_finetuning_job(request: FineTuningJobIDRequest):
+    """Return one job. NOTE(review): the ID comes from the request model, not the URL path segment; confirm."""
+    return handle_retrieve_finetuning_job(request)
+
+
+@register_microservice(
+    name="opea_service@finetuning",
+    endpoint="/v1/fine_tuning/jobs/{fine_tuning_job_id}/cancel",
+    host="0.0.0.0",
+    port=8001,
+)
+def cancel_finetuning_job(request: FineTuningJobIDRequest):
+    """Cancel one job. NOTE(review): the ID comes from the request model, not the URL path segment; confirm."""
+    return handle_cancel_finetuning_job(request)
+
+
+if __name__ == "__main__":
+    opea_microservices["opea_service@finetuning"].start()
diff --git a/comps/finetuning/handlers.py b/comps/finetuning/handlers.py
new file mode 100644
index 000000000..7721b52bc
--- /dev/null
+++ b/comps/finetuning/handlers.py
@@ -0,0 +1,168 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+import random
+import time
+import uuid
+from typing import Any, Dict, List, Optional, Set
+
+from envs import CHECK_JOB_STATUS_INTERVAL, DATASET_BASE_PATH, JOBS_PATH, MODEL_CONFIG_FILE_MAP, ray_client
+from fastapi import HTTPException
+from finetune_config import FinetuneConfig
+from pydantic_yaml import parse_yaml_raw_as, to_yaml_file
+from ray.job_submission import JobSubmissionClient
+
+from comps.cores.proto.api_protocol import (
+    FineTuningJob,
+    FineTuningJobIDRequest,
+    FineTuningJobList,
+    FineTuningJobListRequest,
+    FineTuningJobsRequest,
+)
+
+FineTuningJobID = str
+# In-memory registries; both are keyed by the fine-tuning job ID.
+running_finetuning_jobs: Dict[FineTuningJobID, FineTuningJob] = {}
+finetuning_job_to_ray_job: Dict[FineTuningJobID, str] = {}
+
+
+def _get_ray_client() -> JobSubmissionClient:
+    """Return the module-wide Ray job client, creating it on first use."""
+    global ray_client
+    if ray_client is None:
+        ray_client = JobSubmissionClient()
+    return ray_client
+
+
+def _resolve_training_file(train_file: str) -> str:
+    """Return the path of ``train_file`` under DATASET_BASE_PATH.
+
+    Raises:
+        HTTPException: 404 if the file is missing or the name escapes the dataset directory.
+    """
+    train_file_path = os.path.join(DATASET_BASE_PATH, train_file)
+    dataset_root = os.path.realpath(DATASET_BASE_PATH)
+    resolved = os.path.realpath(train_file_path)
+    # training_file comes from the request body, so reject names such as "../x" or "/etc/passwd".
+    inside_root = os.path.commonpath([dataset_root, resolved]) == dataset_root
+    if not inside_root or not os.path.exists(train_file_path):
+        raise HTTPException(status_code=404, detail=f"Training file '{train_file}' not found!")
+    return train_file_path
+
+
+def handle_create_finetuning_jobs(request: FineTuningJobsRequest):
+    """Create a fine-tuning job from ``request`` and submit it to Ray.
+
+    Returns:
+        The new FineTuningJob, registered in ``running_finetuning_jobs`` with status "running".
+
+    Raises:
+        HTTPException: 404 if the base model is not supported or the training file is not found.
+    """
+    base_model = request.model
+    train_file = request.training_file
+
+    model_config_file = MODEL_CONFIG_FILE_MAP.get(base_model)
+    if not model_config_file:
+        raise HTTPException(status_code=404, detail=f"Base model '{base_model}' not supported!")
+
+    train_file_path = _resolve_training_file(train_file)
+
+    with open(model_config_file) as f:
+        finetune_config = parse_yaml_raw_as(FinetuneConfig, f)
+
+    finetune_config.Dataset.train_file = train_file_path
+
+    hyperparameters = request.hyperparameters
+    if hyperparameters is not None:
+        # None and "auto" both mean "keep the value from the model's config file".
+        if hyperparameters.n_epochs not in (None, "auto"):
+            finetune_config.Training.epochs = hyperparameters.n_epochs
+
+        if hyperparameters.batch_size not in (None, "auto"):
+            finetune_config.Training.batch_size = hyperparameters.batch_size
+
+        if hyperparameters.learning_rate_multiplier not in (None, "auto"):
+            finetune_config.Training.learning_rate = hyperparameters.learning_rate_multiplier
+
+    job = FineTuningJob(
+        id=f"ft-job-{uuid.uuid4()}",
+        model=base_model,
+        created_at=int(time.time()),
+        training_file=train_file,
+        hyperparameters={
+            "n_epochs": finetune_config.Training.epochs,
+            "batch_size": finetune_config.Training.batch_size,
+            "learning_rate_multiplier": finetune_config.Training.learning_rate,
+        },
+        status="running",
+        # TODO: Add seed in finetune config
+        seed=random.randint(0, 1000),
+    )
+
+    finetune_config_file = os.path.join(JOBS_PATH, f"{job.id}.yaml")
+    to_yaml_file(finetune_config_file, finetune_config)
+
+    ray_job_id = _get_ray_client().submit_job(
+        # Entrypoint shell command to execute
+        entrypoint=f"python finetune_runner.py --config_file {finetune_config_file}",
+        # Path to the local directory that contains the script.py file
+        runtime_env={"working_dir": "./"},
+    )
+    print(f"Submitted Ray job: {ray_job_id} ...")
+
+    running_finetuning_jobs[job.id] = job
+    finetuning_job_to_ray_job[job.id] = ray_job_id
+
+    return job
+
+
+def handle_list_finetuning_jobs(request: Optional[FineTuningJobListRequest] = None):
+    """Return every job in ``running_finetuning_jobs`` as a FineTuningJobList.
+
+    ``request`` is accepted because the route handler passes it; its ``after`` and
+    ``limit`` fields are not applied, so ``has_more`` is always False.
+    """
+    return FineTuningJobList(data=list(running_finetuning_jobs.values()), has_more=False)
+
+
+def handle_retrieve_finetuning_job(request: FineTuningJobIDRequest):
+    """Return the tracked job named by ``request.fine_tuning_job_id``.
+
+    Raises:
+        HTTPException: 404 if no such job is tracked.
+    """
+    fine_tuning_job_id = request.fine_tuning_job_id
+
+    job = running_finetuning_jobs.get(fine_tuning_job_id)
+    if job is None:
+        raise HTTPException(status_code=404, detail=f"Fine-tuning job '{fine_tuning_job_id}' not found!")
+    return job
+
+
+def handle_cancel_finetuning_job(request: FineTuningJobIDRequest):
+    """Stop the Ray job behind a running fine-tuning job and mark it "cancelled".
+
+    Raises:
+        HTTPException: 404 if the job is unknown, 400 if it is not in the "running" state.
+    """
+    fine_tuning_job_id = request.fine_tuning_job_id
+
+    ray_job_id = finetuning_job_to_ray_job.get(fine_tuning_job_id)
+    if ray_job_id is None:
+        raise HTTPException(status_code=404, detail=f"Fine-tuning job '{fine_tuning_job_id}' not found!")
+
+    job = running_finetuning_jobs.get(fine_tuning_job_id)
+    if job is None:
+        raise HTTPException(status_code=404, detail=f"Job with ID '{fine_tuning_job_id}' not found in running jobs!")
+
+    # Validate the status first so Ray is only contacted for a job that can be cancelled.
+    if job.status != "running":
+        raise HTTPException(
+            status_code=400, detail=f"Job with ID '{fine_tuning_job_id}' is not running and cannot be cancelled."
+        )
+
+    _get_ray_client().stop_job(ray_job_id)
+    job.status = "cancelled"
+    return job
diff --git a/comps/finetuning/models/llama-2-7b-chat-hf.yaml b/comps/finetuning/models/llama-2-7b-chat-hf.yaml
new file mode 100644
index 000000000..ab62383d2
--- /dev/null
+++ b/comps/finetuning/models/llama-2-7b-chat-hf.yaml
@@ -0,0 +1,40 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+General:
+  base_model: meta-llama/Llama-2-7b-chat-hf
+  gpt_base_model: false
+  output_dir: /tmp/llm-ray/output
+  save_strategy: "no"  # quoted so a YAML 1.1 parser cannot read it as boolean false
+  config:
+    trust_remote_code: false
+    use_auth_token: null
+  lora_config:
+    task_type: CAUSAL_LM
+    r: 8
+    lora_alpha: 32
+    lora_dropout: 0.1
+    target_modules:
+      - q_proj
+      - v_proj
+  enable_gradient_checkpointing: false
+Dataset:
+  train_file: examples/data/sample_finetune_data_small.jsonl
+  group: false
+  validation_file: null
+  validation_split_percentage: 5
+Training:
+  optimizer: adamw_torch
+  batch_size: 2
+  epochs: 3
+  learning_rate: 1.0e-05
+  lr_scheduler: linear
+  weight_decay: 0.0
+  mixed_precision: bf16
+  device: cpu
+  num_training_workers: 2
+  resources_per_worker:
+    CPU: 32
+  accelerate_mode: DDP
+  gradient_accumulation_steps: 1
+  logging_steps: 10