From 8f5b4923c8caca1f352581eb6f2fda583517b1a6 Mon Sep 17 00:00:00 2001
From: zhangsibo1129 <134488188+zhangsibo1129@users.noreply.github.com>
Date: Sat, 23 Dec 2023 17:16:27 +0800
Subject: [PATCH] reformatted (#1128)

---
 examples/scripts/ppo_multi_adapter.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/examples/scripts/ppo_multi_adapter.py b/examples/scripts/ppo_multi_adapter.py
index 044a2419de..2bd489dfbd 100644
--- a/examples/scripts/ppo_multi_adapter.py
+++ b/examples/scripts/ppo_multi_adapter.py
@@ -21,8 +21,9 @@
 from tqdm import tqdm
 from transformers import AutoTokenizer, BitsAndBytesConfig, HfArgumentParser
 
-from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, is_xpu_available
+from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
 from trl.core import LengthSampler
+from trl.import_utils import is_npu_available, is_xpu_available
 
 
 input_min_text_length = 6
@@ -82,7 +83,7 @@ def tokenize(example):
 )
 model = AutoModelForCausalLMWithValueHead.from_pretrained(
     script_args.model_name,
-    device_map={"": "xpu:0"} if is_xpu_available() else {"": 0},
+    device_map={"": "xpu:0"} if is_xpu_available() else {"": "npu:0"} if is_npu_available() else {"": 0},
     peft_config=lora_config,
     quantization_config=nf4_config,
     reward_adapter=script_args.rm_adapter,
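
Note on the changed device_map line: the chained conditional prefers an Intel XPU,
then an Ascend NPU, then falls back to device 0, and both availability helpers must
be *called*. A bare `is_npu_available` is a function object and therefore always
truthy, so every non-XPU machine would be routed to "npu:0". A minimal standalone
sketch of the selection logic, assuming the `is_npu_available`/`is_xpu_available`
helpers exported by `trl.import_utils` as in the patch:

    from trl.import_utils import is_npu_available, is_xpu_available

    # Prefer XPU, then NPU, then fall back to device 0 (CUDA or CPU).
    # Both helpers are called: a bare function reference is always truthy
    # and would silently pick "npu:0" on machines without an NPU.
    device_map = (
        {"": "xpu:0"}
        if is_xpu_available()
        else {"": "npu:0"}
        if is_npu_available()
        else {"": 0}
    )
    print(device_map)  # e.g. {'': 0} on a CUDA-only machine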