
Commit 4aa0840 (parent 6f18486)

fix

Signed-off-by: Wu, Xiaochang <[email protected]>

xwu99 committed Jun 24, 2024
Showing 4 changed files with 20 additions and 38 deletions.
comps/finetuning/llm_on_ray/common/common.py (1 addition, 1 deletion)

@@ -14,9 +14,9 @@
# limitations under the License.
#

import os
import glob
import importlib
import os

from llm_on_ray.common.logging import logger

comps/finetuning/llm_on_ray/common/logging.py (1 addition, 1 deletion)

@@ -14,10 +14,10 @@
# limitations under the License.
#

import functools
import logging
import logging.config
import traceback
import functools

__all__ = ["logger", "get_logger"]

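The two file diffs above only reorder imports into the alphabetized, stdlib-first layout that isort produces. Assuming an isort-compatible tool generated the change (the tool itself is not named in the commit), a small sketch of reproducing the same ordering; the sample source string simply mirrors the old common.py header:

```python
# Sketch only: assumes isort (or an isort-compatible formatter) is what
# produced the reordering shown above.
import isort

# Hypothetical snippet mirroring the previous import order in common.py.
unsorted = """import os
import glob
import importlib

from llm_on_ray.common.logging import logger
"""

# isort.code() returns the source with each import group alphabetized,
# which matches the new ordering in the diff (glob, importlib, os).
print(isort.code(unsorted))
```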
comps/finetuning/llm_on_ray/finetune/finetune.py (16 additions, 34 deletions)

@@ -16,33 +16,29 @@

#!/usr/bin/env python

import os
import argparse
import os
import sys
from typing import Any, Dict, Union, Optional
from importlib import util
from typing import Any, Dict, Optional, Union

import ray
import torch

import transformers

import ray
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig
from ray.air import RunConfig, FailureConfig

from pydantic_yaml import parse_yaml_raw_as

from llm_on_ray import common
from llm_on_ray.finetune.finetune_config import FinetuneConfig
from importlib import util
from pydantic_yaml import parse_yaml_raw_as
from ray.air import FailureConfig, RunConfig
from ray.air.config import ScalingConfig
from ray.train.torch import TorchTrainer

use_habana = False
if util.find_spec("habana_frameworks") is not None:
from optimum.habana.utils import set_seed

use_habana = True
else:
from accelerate.utils import set_seed, is_xpu_available
from accelerate.utils import is_xpu_available, set_seed

use_habana = False
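
This hunk keeps the existing hardware probe while tidying the imports: importlib.util.find_spec checks for the Habana stack before any Habana-specific import is attempted. A minimal, stdlib-only sketch of that probe (the printed message is illustrative):

```python
# find_spec returns None when a package is absent, instead of raising
# ImportError, so it can gate optional accelerator-specific imports.
from importlib import util

use_habana = util.find_spec("habana_frameworks") is not None
print("Habana frameworks detected:", use_habana)

# In finetune.py this flag selects between optimum.habana.utils.set_seed
# and accelerate.utils.set_seed, as shown in the hunk above.
```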

@@ -240,9 +236,7 @@ def train_func(config: Dict[str, Any]):
"dtype": convert_dtype(config["Training"].get("mixed_precision", "no")),
"device": torch.device(device),
"config": config["General"]["config"],
"enable_gradient_checkpointing": config["General"].get(
"enable_gradient_checkpointing", False
),
"enable_gradient_checkpointing": config["General"].get("enable_gradient_checkpointing", False),
"lora_config": config["General"].get("lora_config", None),
}
)
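
The reflowed enable_gradient_checkpointing line sits in a dict that also passes a torch dtype derived from Training.mixed_precision via convert_dtype. That helper is not shown in this diff; a hedged sketch of the mapping it presumably performs:

```python
import torch

# Assumed behavior of convert_dtype (its definition is outside this diff):
# map the mixed_precision setting to a torch dtype.
def convert_dtype(precision: str) -> torch.dtype:
    mapping = {"no": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
    return mapping[precision]

print(convert_dtype("bf16"))  # torch.bfloat16
```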
@@ -251,7 +245,6 @@ def train_func(config: Dict[str, Any]):
tokenizer=tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8
)


if device in ["cpu", "gpu"]:
from transformers import Trainer, TrainingArguments

@@ -260,9 +253,7 @@ def train_func(config: Dict[str, Any]):
model=model,
args=training_args,
train_dataset=tokenized_datasets["train"],
eval_dataset=tokenized_datasets["validation"]
if tokenized_datasets.get("validation") is not None
else None,
eval_dataset=tokenized_datasets["validation"] if tokenized_datasets.get("validation") is not None else None,
tokenizer=tokenizer,
data_collator=data_collator,
)
@@ -272,9 +263,8 @@ def train_func(config: Dict[str, Any]):
trainer.save_model()
common.logger.info("train finish")
elif device in ["hpu"]:
from optimum.habana.transformers import GaudiTrainer
from optimum.habana.transformers import GaudiTrainingArguments
from optimum.habana import GaudiConfig
from optimum.habana.transformers import GaudiTrainer, GaudiTrainingArguments

# If gaudi_config_name is provided, load gaudi_config from huggingface model hub(https://huggingface.co/Habana), otherwise use default gaudi_config
if config["general"].get("gaudi_config_name") is not None:
@@ -291,9 +281,7 @@ def train_func(config: Dict[str, Any]):
args=training_args,
gaudi_config=gaudi_config,
train_dataset=tokenized_datasets["train"],
eval_dataset=tokenized_datasets["validation"]
if tokenized_datasets.get("validation") is not None
else None,
eval_dataset=tokenized_datasets["validation"] if tokenized_datasets.get("validation") is not None else None,
tokenizer=tokenizer,
data_collator=data_collator,
)
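
Per the comment in the previous hunk, gaudi_config comes from the Habana organization on the Hugging Face Hub when gaudi_config_name is set, and falls back to a default GaudiConfig otherwise. The exact lines are collapsed in this view; a rough sketch of that selection, with an illustrative config dict and config name:

```python
# Illustrative only: the collapsed code above is assumed to choose the
# Gaudi configuration roughly like this.
from optimum.habana import GaudiConfig

config = {"General": {"gaudi_config_name": "Habana/gpt2"}}  # hypothetical values

name = config["General"].get("gaudi_config_name")
if name is not None:
    # Load a published Gaudi config from https://huggingface.co/Habana.
    gaudi_config = GaudiConfig.from_pretrained(name)
else:
    # Use library defaults when no config name is supplied.
    gaudi_config = GaudiConfig()
```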
@@ -305,9 +293,7 @@ def train_func(config: Dict[str, Any]):


def get_finetune_config():
parser = argparse.ArgumentParser(
description="Finetune a transformers model on a causal language modeling task"
)
parser = argparse.ArgumentParser(description="Finetune a transformers model on a causal language modeling task")
parser.add_argument(
"--config_file",
type=str,
@@ -341,9 +327,7 @@ def main(external_config=None):
resources_per_worker = config["Training"].get("resources_per_worker")

if config["Training"].get("accelerate_mode", None) is None:
config["Training"][
"accelerate_mode"
] = "DDP" # will use DDP to accelerate if no method specified
config["Training"]["accelerate_mode"] = "DDP" # will use DDP to accelerate if no method specified

ccl_worker_count = 1
device = config["Training"]["device"]
@@ -366,9 +350,7 @@ def main(external_config=None):
runtime_env["pip"] = ["transformers==4.26.0"]

if device == "gpu":
num_cpus = (
resources_per_worker["CPU"] * num_training_workers + 1
) # additional 1 for head worker
num_cpus = resources_per_worker["CPU"] * num_training_workers + 1 # additional 1 for head worker
ray.init(num_cpus=num_cpus, runtime_env=runtime_env)
else:
ray.init(runtime_env=runtime_env)
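
The reflowed num_cpus line encodes the Ray CPU budget on the GPU path: the per-worker CPU reservation times the number of training workers, plus one CPU for the head process. A quick check of that arithmetic with hypothetical numbers:

```python
# Worked example of the num_cpus calculation above; the resource figures
# are hypothetical, not taken from any shipped config.
resources_per_worker = {"CPU": 8}
num_training_workers = 4

num_cpus = resources_per_worker["CPU"] * num_training_workers + 1  # +1 for the head
print(num_cpus)  # 33

# The source then calls ray.init(num_cpus=num_cpus, runtime_env=runtime_env)
# on the GPU path and plain ray.init(runtime_env=runtime_env) otherwise.
```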
comps/finetuning/llm_on_ray/finetune/finetune_config.py (2 additions, 2 deletions)

@@ -14,9 +14,9 @@
# limitations under the License.
#

from pydantic import BaseModel, validator
from typing import Optional, List
from typing import List, Optional

from pydantic import BaseModel, validator

PRECISION_BF16 = "bf16"
PRECISION_FP16 = "fp16"
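
finetune_config.py holds the pydantic models behind FinetuneConfig, and finetune.py loads them from YAML through pydantic_yaml's parse_yaml_raw_as (both imports appear in the hunks above). A self-contained miniature of that pattern; the TrainingSettings model and its fields are hypothetical, not the project's actual schema:

```python
# Hypothetical miniature of the config-loading pattern: a pydantic model
# with a v1-style validator, parsed from raw YAML via pydantic-yaml.
from pydantic import BaseModel, validator
from pydantic_yaml import parse_yaml_raw_as

PRECISION_BF16 = "bf16"
PRECISION_FP16 = "fp16"
PRECISION_NO = "no"


class TrainingSettings(BaseModel):
    device: str = "cpu"
    mixed_precision: str = PRECISION_NO

    @validator("mixed_precision")
    def check_precision(cls, value):
        # Reject anything outside the supported precision modes.
        assert value in (PRECISION_BF16, PRECISION_FP16, PRECISION_NO)
        return value


raw = """
device: hpu
mixed_precision: bf16
"""
settings = parse_yaml_raw_as(TrainingSettings, raw)
print(settings.mixed_precision)  # bf16
```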
