diff --git a/swift/llm/dpo.py b/swift/llm/dpo.py
index 964276cec..8775cb290 100644
--- a/swift/llm/dpo.py
+++ b/swift/llm/dpo.py
@@ -10,8 +10,8 @@
 from transformers.utils import is_torch_npu_available
 
 from swift.trainers.dpo_trainers import DPOTrainer
-from swift.utils import (check_json_format, get_dist_setting, get_logger, get_main, get_model_info, is_ddp_plus_mp,
-                         is_dist, is_master, plot_images, seed_everything, show_layers)
+from swift.utils import (append_to_jsonl, check_json_format, get_dist_setting, get_logger, get_main, get_model_info,
+                         is_ddp_plus_mp, is_dist, is_local_master, is_master, plot_images, seed_everything, show_layers)
 from .tuner import prepare_model
 from .utils import (DPOArguments, Template, get_dataset, get_model_tokenizer, get_template, get_time_info,
                     set_generation_config)
@@ -156,6 +156,7 @@ def llm_dpo(args: DPOArguments) -> str:
 
     if val_dataset is None:
         training_args.evaluation_strategy = IntervalStrategy.NO
+        training_args.eval_strategy = IntervalStrategy.NO
         training_args.do_eval = False
     logger.info(f'train_dataset: {train_dataset}')
     logger.info(f'val_dataset: {val_dataset}')
@@ -227,9 +228,9 @@ def llm_dpo(args: DPOArguments) -> str:
         'model_info': model_info,
         'dataset_info': trainer.dataset_info,
     }
-    jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
-    with open(jsonl_path, 'a', encoding='utf-8') as f:
-        f.write(json.dumps(run_info) + '\n')
+    if is_local_master():
+        jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
+        append_to_jsonl(jsonl_path, run_info)
     return run_info
 
 
diff --git a/swift/llm/orpo.py b/swift/llm/orpo.py
index 6c47f1a0e..db2a15567 100644
--- a/swift/llm/orpo.py
+++ b/swift/llm/orpo.py
@@ -10,8 +10,8 @@
 from transformers.utils import is_torch_npu_available
 
 from swift.trainers.orpo_trainers import ORPOTrainer
-from swift.utils import (check_json_format, get_dist_setting, get_logger, get_main, get_model_info, is_ddp_plus_mp,
-                         is_dist, is_master, plot_images, seed_everything, show_layers)
+from swift.utils import (append_to_jsonl, check_json_format, get_dist_setting, get_logger, get_main, get_model_info,
+                         is_ddp_plus_mp, is_dist, is_local_master, is_master, plot_images, seed_everything, show_layers)
 from .tuner import prepare_model
 from .utils import (ORPOArguments, Template, get_dataset, get_model_tokenizer, get_template, get_time_info,
                     set_generation_config)
@@ -147,6 +147,7 @@ def llm_orpo(args: ORPOArguments) -> str:
 
     if val_dataset is None:
         training_args.evaluation_strategy = IntervalStrategy.NO
+        training_args.eval_strategy = IntervalStrategy.NO
         training_args.do_eval = False
     logger.info(f'train_dataset: {train_dataset}')
     logger.info(f'val_dataset: {val_dataset}')
@@ -230,9 +231,9 @@ def llm_orpo(args: ORPOArguments) -> str:
         'model_info': model_info,
         'dataset_info': trainer.dataset_info,
     }
-    jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
-    with open(jsonl_path, 'a', encoding='utf-8') as f:
-        f.write(json.dumps(run_info) + '\n')
+    if is_local_master():
+        jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
+        append_to_jsonl(jsonl_path, run_info)
     return run_info
 
 
diff --git a/swift/llm/sft.py b/swift/llm/sft.py
index 5bf529187..4b3fb4ebf 100644
--- a/swift/llm/sft.py
+++ b/swift/llm/sft.py
@@ -255,6 +255,7 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
             val_dataset = LazyLLMDataset(val_dataset, template)
     if val_dataset is None:
         training_args.evaluation_strategy = IntervalStrategy.NO
+        training_args.eval_strategy = IntervalStrategy.NO
         training_args.do_eval = False
 
     padding_to = args.max_length if args.sft_type == 'longlora' else None
diff --git a/swift/llm/simpo.py b/swift/llm/simpo.py
index cdbb6dea4..fd788bdca 100644
--- a/swift/llm/simpo.py
+++ b/swift/llm/simpo.py
@@ -10,8 +10,8 @@
 from transformers.utils import is_torch_npu_available
 
 from swift.trainers.simpo_trainers import SimPOTrainer
-from swift.utils import (check_json_format, get_dist_setting, get_logger, get_main, get_model_info, is_ddp_plus_mp,
-                         is_dist, is_master, plot_images, seed_everything, show_layers)
+from swift.utils import (append_to_jsonl, check_json_format, get_dist_setting, get_logger, get_main, get_model_info,
+                         is_ddp_plus_mp, is_dist, is_local_master, is_master, plot_images, seed_everything, show_layers)
 from .tuner import prepare_model
 from .utils import (SimPOArguments, Template, get_dataset, get_model_tokenizer, get_template, get_time_info,
                     set_generation_config)
@@ -145,6 +145,7 @@ def llm_simpo(args: SimPOArguments) -> str:
 
     if val_dataset is None:
         training_args.evaluation_strategy = IntervalStrategy.NO
+        training_args.eval_strategy = IntervalStrategy.NO
         training_args.do_eval = False
     logger.info(f'train_dataset: {train_dataset}')
     logger.info(f'val_dataset: {val_dataset}')
@@ -215,9 +216,9 @@ def llm_simpo(args: SimPOArguments) -> str:
         'model_info': model_info,
         'dataset_info': trainer.dataset_info,
     }
-    jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
-    with open(jsonl_path, 'a', encoding='utf-8') as f:
-        f.write(json.dumps(run_info) + '\n')
+    if is_local_master():
+        jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
+        append_to_jsonl(jsonl_path, run_info)
     return run_info
 
 
diff --git a/swift/llm/utils/model.py b/swift/llm/utils/model.py
index abebce951..ea6773c05 100644
--- a/swift/llm/utils/model.py
+++ b/swift/llm/utils/model.py
@@ -451,6 +451,7 @@ class LoRATM(NamedTuple):
     glm4v = ['self_attention.query_key_value']
     phi = ['Wqkv']
    phi3 = ['qkv_proj']
+    phi3_small = ['query_key_value']  # phi3-small fuses Q/K/V into a single 'query_key_value' projection
     internlm2 = ['wqkv']
     mamba = ['in_proj', 'x_proj', 'embeddings', 'out_proj']
     telechat = ['key_value', 'query']
@@ -1605,16 +1606,6 @@ def _output_device_map_hook(module, input, output):
     support_vllm=True,
     tags=['general'],
     hf_model_id='microsoft/Phi-3-mini-128k-instruct')
-@register_model(
-    ModelType.phi3_small_128k_instruct,
-    'LLM-Research/Phi-3-small-128k-instruct',
-    LoRATM.phi3,
-    TemplateType.phi3,
-    requires=['transformers>=4.36'],
-    support_flash_attn=True,
-    support_vllm=True,
-    tags=['general'],
-    hf_model_id='microsoft/Phi-3-small-128k-instruct')
 @register_model(
     ModelType.phi3_medium_128k_instruct,
     'LLM-Research/Phi-3-medium-128k-instruct',
@@ -2361,6 +2352,49 @@ def get_model_tokenizer_with_flash_attn(model_dir: str,
         model_dir, torch_dtype, model_kwargs, load_model, model_config=model_config, **kwargs)
 
 
+@register_model(
+    ModelType.phi3_small_128k_instruct,
+    'LLM-Research/Phi-3-small-128k-instruct',
+    LoRATM.phi3_small,
+    TemplateType.phi3,
+    requires=['transformers>=4.36'],
+    support_flash_attn=True,
+    support_gradient_checkpointing=False,
+    support_vllm=True,
+    tags=['general'],
+    hf_model_id='microsoft/Phi-3-small-128k-instruct')
+def get_model_tokenizer_phi3_small(model_dir: str,
+                                   torch_dtype: Dtype,
+                                   model_kwargs: Dict[str, Any],
+                                   load_model: bool = True,
+                                   model_config=None,
+                                   **kwargs):
+    if model_config is None:
+        model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
+    use_flash_attn = kwargs.pop('use_flash_attn', False)
+    if version.parse(transformers.__version__) >= version.parse('4.36'):
+        if use_flash_attn:
+            model_config._attn_implementation = 'flash_attention_2'
+    else:
+        model_config._flash_attn_2_enabled = use_flash_attn
+    model, tokenizer = get_model_tokenizer_from_repo(
+        model_dir, torch_dtype, model_kwargs, load_model, model_config=model_config, **kwargs)
+
+    def rotary_emb(self, query_states, key_states, **kwargs):
+        q_type = query_states.dtype  # remember the input dtypes so they can be restored after the call
+        k_type = key_states.dtype
+        query_states, key_states = self.rotary_emb_origin(query_states, key_states, **kwargs)
+        query_states = query_states.to(q_type)
+        key_states = key_states.to(k_type)
+        return query_states, key_states
+
+    for i in range(32):  # patch the rotary embedding of each of the model's 32 decoder layers
+        re = model.model.layers[i].self_attn.rotary_emb
+        re.rotary_emb_origin = re.forward
+        re.forward = MethodType(rotary_emb, re)
+    return model, tokenizer
+
+
 @register_model(
     ModelType.qwen2_57b_a14b_instruct_int4,
     'qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4',
diff --git a/swift/trainers/callback.py b/swift/trainers/callback.py
index a787bf049..87dc50fb0 100644
--- a/swift/trainers/callback.py
+++ b/swift/trainers/callback.py
@@ -8,7 +8,7 @@
                                             TrainerState)
 from transformers.trainer_utils import IntervalStrategy, has_length, speed_metrics
 
-from swift.utils import is_pai_training_job, use_torchacc
+from swift.utils import append_to_jsonl, is_pai_training_job, use_torchacc
 
 from .arguments import TrainingArguments
 
@@ -55,8 +55,7 @@ def on_log(self, args: TrainingArguments, state: TrainerState, control, logs=Non
                 logs[k] = round(logs[k], 8)
         if not is_pai_training_job() and state.is_local_process_zero:
             jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
-            with open(jsonl_path, 'a', encoding='utf-8') as f:
-                f.write(json.dumps(logs) + '\n')
+            append_to_jsonl(jsonl_path, logs)
         super().on_log(args, state, control, logs, **kwargs)
         if state.is_local_process_zero and self.training_bar is not None:
             self.training_bar.refresh()
@@ -67,8 +66,9 @@ class DefaultFlowCallbackNew(DefaultFlowCallback):
     def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
         control = super().on_step_end(args, state, control, **kwargs)
         # save the last ckpt
+        evaluation_strategy = args.eval_strategy if hasattr(args, 'eval_strategy') else args.evaluation_strategy
         if state.global_step == state.max_steps:
-            if args.evaluation_strategy != IntervalStrategy.NO:
+            if evaluation_strategy != IntervalStrategy.NO:
                 control.should_evaluate = True
             if args.save_strategy != IntervalStrategy.NO:
                 control.should_save = True
@@ -84,8 +84,7 @@ def on_log(self, args, state, control, logs=None, **kwargs):
                 logs[k] = round(logs[k], 8)
         if not is_pai_training_job() and state.is_local_process_zero:
             jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
-            with open(jsonl_path, 'a', encoding='utf-8') as f:
-                f.write(json.dumps(logs) + '\n')
+            append_to_jsonl(jsonl_path, logs)
         _ = logs.pop('total_flos', None)
         if state.is_local_process_zero: