diff --git a/comps/finetuning/llm_on_ray/common/common.py b/comps/finetuning/llm_on_ray/common/common.py
new file mode 100644
index 000000000..ed31dc6cf
--- /dev/null
+++ b/comps/finetuning/llm_on_ray/common/common.py
@@ -0,0 +1,38 @@
+#
+# Copyright 2023 The LLM-on-Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import glob
+import importlib
+
+from llm_on_ray.common.logging import logger
+
+
+def import_all_modules(basedir, prefix=None):
+    all_py_files = glob.glob(basedir + "/*.py")
+    modules = [os.path.basename(f) for f in all_py_files]
+
+    for module in modules:
+        if not module.startswith("_"):
+            module = module[: -len(".py")]  # strip the ".py" suffix; rstrip(".py") would also eat trailing 'p'/'y' characters
+            if prefix is None:
+                module_name = module
+            else:
+                module_name = f"{prefix}.{module}"
+            try:
+                importlib.import_module(module_name)
+            except Exception:
+                logger.warning(f"import {module_name} error", exc_info=True)
diff --git a/comps/finetuning/llm_on_ray/common/logging.py b/comps/finetuning/llm_on_ray/common/logging.py
new file mode 100644
index 000000000..c4b33c3d7
--- /dev/null
+++ b/comps/finetuning/llm_on_ray/common/logging.py
@@ -0,0 +1,67 @@
+#
+# Copyright 2023 The LLM-on-Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import logging.config
+import traceback
+import functools
+
+__all__ = ["logger", "get_logger"]
+
+use_accelerate_log = False
+logger_name = "common"
+
+logging_config = {
+    "version": 1,
+    "loggers": {
+        "root": {"level": "INFO", "handlers": ["consoleHandler"]},
+        "common": {
+            "level": "INFO",
+            "handlers": ["consoleHandler"],
+            "qualname": "common",
+            "propagate": 0,
+        },
+    },
+    "handlers": {
+        "consoleHandler": {
+            "class": "logging.StreamHandler",
+            "level": "INFO",
+            "formatter": "standardFormatter",
+        },
+    },
+    "formatters": {
+        "standardFormatter": {
+            "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+            "datefmt": "",
+        }
+    },
+}
+
+if logging_config is not None:
+    try:
+        logging.config.dictConfig(logging_config)
+    except Exception:
+        traceback.print_exc()
+        exit(1)
+
+if use_accelerate_log:
+    import accelerate
+
+    get_logger = functools.partial(accelerate.logging.get_logger, name=logger_name)
+else:
+    get_logger = functools.partial(logging.getLogger, name=logger_name)
+
+logger = get_logger()
diff --git a/comps/finetuning/llm_on_ray/finetune/finetune.py b/comps/finetuning/llm_on_ray/finetune/finetune.py
new file mode 100644
index 000000000..f36d6b7ed
--- /dev/null
+++ b/comps/finetuning/llm_on_ray/finetune/finetune.py
@@ -0,0 +1,428 @@
+#
+# Copyright 2023 The LLM-on-Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+#!/usr/bin/env python
+
+import os
+import argparse
+import sys
+from typing import Any, Dict, Union, Optional
+
+import torch
+
+import transformers
+
+import ray
+from ray.train.torch import TorchTrainer
+from ray.air.config import ScalingConfig
+from ray.air import RunConfig, FailureConfig
+
+from pydantic_yaml import parse_yaml_raw_as
+
+from llm_on_ray import common
+from llm_on_ray.finetune.finetune_config import FinetuneConfig
+from importlib import util
+
+use_habana = False
+if util.find_spec("habana_frameworks") is not None:
+    from optimum.habana.utils import set_seed
+
+    use_habana = True
+else:
+    from accelerate.utils import set_seed, is_xpu_available
+
+    use_habana = False
+
+
+def get_accelerate_environment_variable(config: Dict[str, Any]) -> dict:
+    device = config["Training"]["device"]
+    accelerate_mode = config["Training"]["accelerate_mode"]
+    mixed_precision = config["Training"]["mixed_precision"]
+    mode_env_vars = {
+        "cpu": {
+            "DDP": {
+                "ACCELERATE_USE_CPU": "true",
+                "ACCELERATE_USE_IPEX": "true",
+                "ACCELERATE_MIXED_PRECISION": mixed_precision,
+            }
+        },
+        "gpu": {
+            "DDP": {
+                "ACCELERATE_USE_CPU": "false",
+                "ACCELERATE_USE_XPU": "true",
+                "ACCELERATE_USE_IPEX": "true",
+                "ACCELERATE_MIXED_PRECISION": mixed_precision,
+            },
+            "FSDP": {
+                "ACCELERATE_USE_CPU": "false",
+                "ACCELERATE_USE_XPU": "true",
+                "ACCELERATE_USE_IPEX": "true",
+                "ACCELERATE_USE_FSDP": "true",
+                "FSDP_SHARDING_STRATEGY": "1",
+                "FSDP_OFFLOAD_PARAMS": "false",
+                "FSDP_AUTO_WRAP_POLICY": "NO_WRAP",
+                "FSDP_BACKWARD_PREFETCH": "BACKWARD_PRE",
+                "FSDP_STATE_DICT_TYPE": "SHARDED_STATE_DICT",
+                "FSDP_FORWARD_PREFETCH": "false",
+                "FSDP_USE_ORIG_PARAMS": "false",
+                "FSDP_SYNC_MODULE_STATES": "true",
+                "ACCELERATE_MIXED_PRECISION": mixed_precision,
+            },
+            "DEEPSPEED": {
+                "ACCELERATE_USE_CPU": "false",
+                "ACCELERATE_USE_XPU": "true",
+                "ACCELERATE_USE_IPEX": "true",
+                "ACCELERATE_USE_DEEPSPEED": "true",
+                "ACCELERATE_MIXED_PRECISION": mixed_precision,
+            },
+        },
+        "hpu": {
+            "DDP": {
+                "ACCELERATE_USE_CPU": "false",
+                "ACCELERATE_USE_XPU": "false",
+                "ACCELERATE_USE_IPEX": "false",
+                "ACCELERATE_MIXED_PRECISION": mixed_precision,
+            },
+            "DEEPSPEED": {
+                "ACCELERATE_USE_CPU": "false",
+                "ACCELERATE_USE_XPU": "false",
+                "ACCELERATE_USE_IPEX": "false",
+                "ACCELERATE_USE_DEEPSPEED": "true",
+                "ACCELERATE_MIXED_PRECISION": mixed_precision,
+            },
+        },
+    }
+    if device not in mode_env_vars or accelerate_mode not in mode_env_vars[device]:
+        supported_mode_info = ""
+        for k in mode_env_vars.keys():
+            supported_mode_info += k + ":["
+            for m in mode_env_vars[k]:
+                supported_mode_info += m + ","
+            supported_mode_info = supported_mode_info[:-1]
+            supported_mode_info += "],"
+        supported_mode_info = supported_mode_info[:-1]
+
+        raise ValueError(
+            f"device {device} and accelerate mode {accelerate_mode} are not supported. "
+            f"Supported devices and accelerate modes are: {supported_mode_info}"
+        )
+    return mode_env_vars[device][accelerate_mode]
+
+
+def convert_to_training_args(cls, config):
+    device = config["Training"]["device"]
+    accelerate_mode = config["Training"]["accelerate_mode"]
+    save_strategy = config["General"]["save_strategy"]
+
+    args = {
+        "output_dir": config["General"]["output_dir"],
+        "resume_from_checkpoint": config["General"]["resume_from_checkpoint"],
+        "gradient_checkpointing": config["General"]["enable_gradient_checkpointing"],
+        "save_strategy": save_strategy if save_strategy != "False" else "no",
+        "bf16": config["Training"]["mixed_precision"] == "bf16",
+        "num_train_epochs": config["Training"]["epochs"],
+        "per_device_train_batch_size": config["Training"]["batch_size"],
+        "per_device_eval_batch_size": config["Training"]["batch_size"],
+        "optim": config["Training"]["optimizer"],
+        "learning_rate": config["Training"]["learning_rate"],
+        "logging_steps": config["Training"]["logging_steps"],
+        "lr_scheduler_type": config["Training"]["lr_scheduler"],
+        "weight_decay": config["Training"]["weight_decay"],
+        "gradient_accumulation_steps": config["Training"]["gradient_accumulation_steps"],
+    }
+
+    # set attr max_steps
+    if config["Training"]["max_train_steps"] is not None:
+        args.update({"max_steps": config["Training"]["max_train_steps"]})
+
+    # set attr for device cpu
+    if device == "cpu":
+        if hasattr(cls, "use_cpu"):
+            args.update({"use_cpu": True})
+        if hasattr(cls, "no_cuda"):
+            args.update({"no_cuda": True})
+        args.update({"use_ipex": True})
+
+    # set attr 'deepspeed'
+    if accelerate_mode == "DEEPSPEED":
+        args.update({"deepspeed": config["Training"]["deepspeed_config_file"]})
+
+    # set attr for FSDP
+    # if accelerate_mode == "FSDP":
+    #     args.update({})
+
+    # set attr for Intel Gaudi
+    if device == "hpu":
+        args.update({"use_habana": True})
+        args.update({"use_lazy_mode": config["Training"]["hpu_execution_mode"] == "lazy"})
+        args.update({"pipelining_fwd_bwd": True})
+
+    return cls(**args)
+
+
+def get_device_environment_variable(device):
+    if device == "hpu":
+        return {
+            "HABANA_VISIBLE_DEVICES": "all",
+            "RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES": "true",
+        }
+    return {}
+
+
+def convert_dtype(dtype: str) -> Optional[torch.dtype]:
+    supported_dtypes = {
+        "fp16": torch.float16,
+        "bf16": torch.bfloat16,
+        "no": None,
+    }
+    return supported_dtypes[dtype]
+
+
+def train_func(config: Dict[str, Any]):
+    os.chdir(config["cwd"])
+
+    device = config["Training"]["device"]
+
+    base_model = config["General"]["base_model"]
+    if config["General"].get("tokenizer_name") is not None:
+        tokenizer_name = config["General"].get("tokenizer_name")
+    else:
+        tokenizer_name = base_model
+    dataset_file = config["Dataset"]["train_file"]
+
+    seed = config["Training"].get("seed")
+    if seed is not None:
+        set_seed(seed)
+
+    tokenizer = common.tokenizer.Tokenizer.registory.get("HuggingFaceTokenizer")()(
+        config={
+            "name": tokenizer_name,
+            "config": config["General"]["config"],
+        }
+    )
+
+    datasets = common.dataset.Dataset.registory.get("HuggingfaceDataset")()(
+        config={
+            "name": dataset_file,
+            "validation_file": config["Dataset"]["validation_file"],
+            "validation_split_percentage": config["Dataset"]["validation_split_percentage"],
+        }
+    )
+
+    dataprocesser = common.dataprocesser.DataProcesser.registory.get("GeneralProcesser")(
+        config={
+            "per_device_train_batch_size": config["Training"]["batch_size"],
+            "per_device_eval_batch_size": config["Training"]["batch_size"],
+            "preprocessing_num_workers": config["Dataset"].get("preprocessing_num_workers", 1),
+            "max_length": config["Dataset"].get("max_length", 512),
+            "group": config["Dataset"].get("group", True),
+            "block_size": config["Dataset"].get("block_size", 512),
+            "shuffle": config["Dataset"].get("shuffle", False),
+        }
+    )
+    tokenized_datasets = dataprocesser.tokenize_dataset(tokenizer, datasets)
+
+    model = common.model.Model.registory.get("HuggingFaceModelForCausalLM")()(
+        config={
+            "name": base_model,
+            "dtype": convert_dtype(config["Training"].get("mixed_precision", "no")),
+            "device": torch.device(device),
+            "config": config["General"]["config"],
+            "enable_gradient_checkpointing": config["General"].get("enable_gradient_checkpointing", False),
+            "lora_config": config["General"].get("lora_config", None),
+        }
+    )
+
+    data_collator = common.dataprocesser.general_processer.DataCollatorForCompletionOnlyLM(
+        tokenizer=tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8
+    )
+
+    if device in ["cpu", "gpu"]:
+        from transformers import Trainer, TrainingArguments
+
+        training_args = convert_to_training_args(TrainingArguments, config)
+        trainer = Trainer(
+            model=model,
+            args=training_args,
+            train_dataset=tokenized_datasets["train"],
+            eval_dataset=tokenized_datasets["validation"]
+            if tokenized_datasets.get("validation") is not None
+            else None,
+            tokenizer=tokenizer,
+            data_collator=data_collator,
+        )
+
+        common.logger.info("train start")
+        trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+        trainer.save_model()
+        common.logger.info("train finish")
+    elif device in ["hpu"]:
+        from optimum.habana.transformers import GaudiTrainer
+        from optimum.habana.transformers import GaudiTrainingArguments
+        from optimum.habana import GaudiConfig
+
+        # If gaudi_config_name is provided, load the gaudi_config from the Hugging Face
+        # model hub (https://huggingface.co/Habana); otherwise use a default gaudi_config.
+        if config["General"].get("gaudi_config_name") is not None:
+            gaudi_config = GaudiConfig.from_pretrained(
+                config["General"].get("gaudi_config_name"),
+            )
+        else:
+            gaudi_config = GaudiConfig()
+            gaudi_config.use_fused_adam = True
+            gaudi_config.use_fused_clip_norm = True
+        training_args = convert_to_training_args(GaudiTrainingArguments, config)
+        trainer = GaudiTrainer(
+            model=model,
+            args=training_args,
+            gaudi_config=gaudi_config,
+            train_dataset=tokenized_datasets["train"],
+            eval_dataset=tokenized_datasets["validation"]
+            if tokenized_datasets.get("validation") is not None
+            else None,
+            tokenizer=tokenizer,
+            data_collator=data_collator,
+        )
+
+        common.logger.info("train start")
+        trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+        trainer.save_model()
+        common.logger.info("train finish")
+
+
+def get_finetune_config():
+    parser = argparse.ArgumentParser(
+        description="Finetune a transformers model on a causal language modeling task"
+    )
+    parser.add_argument(
+        "--config_file",
+        type=str,
+        required=True,
+        default=None,
+        help="Path to the YAML file that holds the finetuning configuration (FinetuneConfig).",
+    )
+
+    # Print help if no arguments were provided
+    if len(sys.argv) == 1:
+        parser.print_help(sys.stderr)
+        sys.exit(1)
+
+    args = parser.parse_args()
+    config_file = args.config_file
+
+    with open(config_file) as f:
+        finetune_config = parse_yaml_raw_as(FinetuneConfig, f)
+    return finetune_config.dict()
+
+
+def main(external_config=None):
+    if not external_config:
+        config = get_finetune_config()
+    else:
+        config = external_config
+
+    config["cwd"] = os.getcwd()
+
+    num_training_workers = config["Training"].get("num_training_workers")
+    resources_per_worker = config["Training"].get("resources_per_worker")
+
+    if config["Training"].get("accelerate_mode", None) is None:
+        config["Training"][
+            "accelerate_mode"
+        ] = "DDP"  # use DDP to accelerate if no method is specified
+
+    ccl_worker_count = 1
+    device = config["Training"]["device"]
+    if device != "cpu":
+        ccl_worker_count = num_training_workers
+
+    if not ray.is_initialized():
+        runtime_env = {
+            "env_vars": {
+                "OMP_NUM_THREADS": str(resources_per_worker["CPU"]),
+                "CCL_ZE_IPC_EXCHANGE": "sockets",
+                "CCL_WORKER_COUNT": str(ccl_worker_count),
+                "CCL_LOG_LEVEL": "info",
+                "WORLD_SIZE": str(num_training_workers),
+                "FI_PROVIDER": "tcp",
+            }
+        }
+
+        if config["General"]["gpt_base_model"] is True:
+            runtime_env["pip"] = ["transformers==4.26.0"]
+
+        if device == "gpu":
+            num_cpus = (
+                resources_per_worker["CPU"] * num_training_workers + 1
+            )  # additional 1 for the head worker
+            ray.init(num_cpus=num_cpus, runtime_env=runtime_env)
+        else:
+            ray.init(runtime_env=runtime_env)
+
+    common.logger.info(f"ray available resources = {ray.available_resources()}")
+    use_gpu = True if device == "gpu" else False
+    scaling_config = ScalingConfig(
+        num_workers=num_training_workers,
+        use_gpu=use_gpu,
+        resources_per_worker=resources_per_worker,
+        placement_strategy="SPREAD",
+    )
+
+    # If an Intel GPU is requested, convert the device name to 'xpu',
+    # because Accelerate uses 'xpu' internally to represent Intel GPUs.
+    if device == "gpu" and is_xpu_available():
+        device = "xpu"
+
+    if config.get("torch_config", None) is None:
+        backend = None
+        if device == "cpu" or device == "xpu" or device == "gpu":
+            backend = "ccl"
+        elif device == "hpu":
+            backend = "hccl"
+        torch_config = common.TorchConfig(backend=backend, device=device)
+    else:
+        customer_torch_config = config.get("torch_config")
+        torch_config = common.TorchConfig(**customer_torch_config, device=device)
+
+    if config.get("failure_config", None) is None:
+        failure_config = FailureConfig()
+    else:
+        customer_failure_config = config.get("failure_config")
+        failure_config = FailureConfig(**customer_failure_config)
+
+    if config.get("run_config", None) is None:
+        run_config = RunConfig(failure_config=failure_config)
+    else:
+        customer_run_config = config.get("run_config")
+        if customer_run_config.get("failure_config", None) is None:
+            customer_run_config["failure_config"] = failure_config
+        run_config = RunConfig(**customer_run_config)
+
+    trainer = TorchTrainer(
+        train_func,
+        train_loop_config=config,
+        scaling_config=scaling_config,
+        torch_config=torch_config,
+        run_config=run_config,
+    )
+    results = trainer.fit()
+    if external_config is not None:
+        return results
+
+
+if __name__ == "__main__":
+    main()
diff --git a/comps/finetuning/llm_on_ray/finetune/finetune_config.py b/comps/finetuning/llm_on_ray/finetune/finetune_config.py
new file mode 100644
index 000000000..e78600a6d
--- /dev/null
+++ b/comps/finetuning/llm_on_ray/finetune/finetune_config.py
@@ -0,0 +1,157 @@
+#
+# Copyright 2023 The LLM-on-Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pydantic import BaseModel, validator
+from typing import Optional, List
+
+
+PRECISION_BF16 = "bf16"
+PRECISION_FP16 = "fp16"
+PRECISION_NO = "no"
+
+DEVICE_CPU = "cpu"
+DEVICE_HPU = "hpu"
+DEVICE_GPU = "gpu"
+
+ACCELERATE_STRATEGY_DDP = "DDP"
+ACCELERATE_STRATEGY_FSDP = "FSDP"
+ACCELERATE_STRATEGY_DEEPSPEED = "DEEPSPEED"
+
+
+class GeneralConfig(BaseModel):
+    trust_remote_code: bool
+    use_auth_token: Optional[str]
+
+
+class LoraConfig(BaseModel):
+    task_type: str
+    r: int
+    lora_alpha: int
+    lora_dropout: float
+    target_modules: Optional[List[str]] = None
+
+
+class DeltatunerConfig(BaseModel):
+    algo: str
+    denas: bool
+    best_model_structure: str
+
+
+class General(BaseModel):
+    base_model: str
+    tokenizer_name: Optional[str] = None
+    gaudi_config_name: Optional[str] = None
+    gpt_base_model: bool
+    output_dir: str
+    resume_from_checkpoint: Optional[str] = None
+    save_strategy: str = "no"
+    config: GeneralConfig
+    lora_config: Optional[LoraConfig] = None
+    deltatuner_config: Optional[DeltatunerConfig] = None
+    enable_gradient_checkpointing: bool = False
+
+
+class Dataset(BaseModel):
+    train_file: str
+    validation_file: Optional[str]
+    validation_split_percentage: int
+    max_length: int = 512
+    group: bool = True
+    block_size: int = 512
+    shuffle: bool = False
+
+
+class RayResourceConfig(BaseModel):
+    CPU: int
+    GPU: int = 0
+    HPU: int = 0
+
+
+class Training(BaseModel):
+    optimizer: str
+    batch_size: int
+    epochs: int
+    max_train_steps: Optional[int] = None
+    learning_rate: float
+    lr_scheduler: str
+    weight_decay: float
+    device: str = DEVICE_CPU
+    hpu_execution_mode: str = "lazy"
+    num_training_workers: int
+    resources_per_worker: RayResourceConfig
+    accelerate_mode: str = ACCELERATE_STRATEGY_DDP
+    mixed_precision: str = PRECISION_NO
+    gradient_accumulation_steps: int = 1
+    logging_steps: int = 10
+    deepspeed_config_file: str = ""
+
+    @validator("device")
+    def check_device(cls, v: str):
+        # will convert to lower case
+        if v:
+            assert v.lower() in [DEVICE_CPU, DEVICE_GPU, DEVICE_HPU]
+        return v.lower()
+
+    @validator("hpu_execution_mode")
+    def check_hpu_execution_mode(cls, v: str):
+        if v:
+            assert v in ["lazy", "eager", "eager.compile"]
+        return v
+
+    @validator("accelerate_mode")
+    def check_accelerate_mode(cls, v: str):
+        if v:
+            assert v in [
+                ACCELERATE_STRATEGY_DDP,
+                ACCELERATE_STRATEGY_FSDP,
+                ACCELERATE_STRATEGY_DEEPSPEED,
+            ]
+        return v
+
+    @validator("mixed_precision")
+    def check_mixed_precision(cls, v: str):
+        if v:
+            assert v in [PRECISION_BF16, PRECISION_FP16, PRECISION_NO]
+        return v
+
+    @validator("logging_steps")
+    def check_logging_steps(cls, v: int):
+        assert v > 0
+        return v
+
+    # @model_validator(mode='after')
+    # def check_device_and_accelerate_mode(self) -> "Training":
+    #     dev = self.device
+    #     res = self.resources_per_worker
+    #     mode = self.accelerate_mode
+    #     if dev == "CPU":
+    #         if res.GPU is not None and res.GPU > 0:
+    #             raise ValueError("Please do not specify a GPU resource when fine-tuning on CPU only in Ray.")
+    #         if mode != "CPU_DDP":
+    #             raise ValueError("Please specify a CPU-related accelerate mode when fine-tuning on CPU only in Ray.")
+    #     elif dev == "GPU":
+    #         if res.GPU is None or res.GPU == 0:
+    #             raise ValueError("Please specify a GPU resource when fine-tuning on GPU in Ray.")
+    #         if mode not in ["GPU_DDP", "GPU_FSDP"]:
+    #             raise ValueError("Please specify a GPU-related accelerate mode when fine-tuning on GPU in Ray.")
+
+    #     return self
+
+
+class FinetuneConfig(BaseModel):
+    General: General
+    Dataset: Dataset
+    Training: Training
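
For reference, a minimal usage sketch (not part of the patch above): finetune.main() also accepts a pre-built config via external_config, so the FinetuneConfig schema can be validated from an in-memory YAML string instead of a file. The model name, data path, and hyperparameter values below are illustrative assumptions only; the authoritative schema is FinetuneConfig in finetune_config.py.

# Illustrative sketch: validate an example config against FinetuneConfig and
# drive the finetuning entrypoint programmatically. All values are placeholders.
from pydantic_yaml import parse_yaml_raw_as

from llm_on_ray.finetune.finetune import main
from llm_on_ray.finetune.finetune_config import FinetuneConfig

EXAMPLE_CONFIG_YAML = """
General:
  base_model: facebook/opt-125m        # placeholder model choice
  gpt_base_model: false
  output_dir: /tmp/llm_on_ray/output
  config:
    trust_remote_code: false
    use_auth_token: null
Dataset:
  train_file: /path/to/train_data.jsonl  # placeholder path
  validation_file: null
  validation_split_percentage: 5
Training:
  optimizer: adamw_torch
  batch_size: 2
  epochs: 1
  learning_rate: 1.0e-05
  lr_scheduler: linear
  weight_decay: 0.0
  device: cpu
  num_training_workers: 2
  resources_per_worker:
    CPU: 8
"""

if __name__ == "__main__":
    # Validation fills in the schema defaults (accelerate_mode=DDP, mixed_precision=no, ...).
    config = parse_yaml_raw_as(FinetuneConfig, EXAMPLE_CONFIG_YAML).dict()
    results = main(external_config=config)
    print(results)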