Skip to content

Commit

Permalink
Merge branch 'main' into v2.1
Browse files Browse the repository at this point in the history
  • Loading branch information
Jintao-Huang committed Jun 17, 2024
2 parents 77ad5b5 + f6c9e84 commit e994dac
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 31 deletions.
11 changes: 6 additions & 5 deletions swift/llm/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from transformers.utils import is_torch_npu_available

from swift.trainers.dpo_trainers import DPOTrainer
from swift.utils import (check_json_format, get_dist_setting, get_logger, get_main, get_model_info, is_ddp_plus_mp,
is_dist, is_master, plot_images, seed_everything, show_layers)
from swift.utils import (append_to_jsonl, check_json_format, get_dist_setting, get_logger, get_main, get_model_info,
is_ddp_plus_mp, is_dist, is_local_master, is_master, plot_images, seed_everything, show_layers)
from .tuner import prepare_model
from .utils import (DPOArguments, Template, get_dataset, get_model_tokenizer, get_template, get_time_info,
set_generation_config)
Expand Down Expand Up @@ -156,6 +156,7 @@ def llm_dpo(args: DPOArguments) -> str:

if val_dataset is None:
training_args.evaluation_strategy = IntervalStrategy.NO
training_args.eval_strategy = IntervalStrategy.NO
training_args.do_eval = False
logger.info(f'train_dataset: {train_dataset}')
logger.info(f'val_dataset: {val_dataset}')
Expand Down Expand Up @@ -227,9 +228,9 @@ def llm_dpo(args: DPOArguments) -> str:
'model_info': model_info,
'dataset_info': trainer.dataset_info,
}
jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
with open(jsonl_path, 'a', encoding='utf-8') as f:
f.write(json.dumps(run_info) + '\n')
if is_local_master():
jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
append_to_jsonl(jsonl_path, run_info)
return run_info


Expand Down
11 changes: 6 additions & 5 deletions swift/llm/orpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from transformers.utils import is_torch_npu_available

from swift.trainers.orpo_trainers import ORPOTrainer
from swift.utils import (check_json_format, get_dist_setting, get_logger, get_main, get_model_info, is_ddp_plus_mp,
is_dist, is_master, plot_images, seed_everything, show_layers)
from swift.utils import (append_to_jsonl, check_json_format, get_dist_setting, get_logger, get_main, get_model_info,
is_ddp_plus_mp, is_dist, is_local_master, is_master, plot_images, seed_everything, show_layers)
from .tuner import prepare_model
from .utils import (ORPOArguments, Template, get_dataset, get_model_tokenizer, get_template, get_time_info,
set_generation_config)
Expand Down Expand Up @@ -147,6 +147,7 @@ def llm_orpo(args: ORPOArguments) -> str:

if val_dataset is None:
training_args.evaluation_strategy = IntervalStrategy.NO
training_args.eval_strategy = IntervalStrategy.NO
training_args.do_eval = False
logger.info(f'train_dataset: {train_dataset}')
logger.info(f'val_dataset: {val_dataset}')
Expand Down Expand Up @@ -230,9 +231,9 @@ def llm_orpo(args: ORPOArguments) -> str:
'model_info': model_info,
'dataset_info': trainer.dataset_info,
}
jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
with open(jsonl_path, 'a', encoding='utf-8') as f:
f.write(json.dumps(run_info) + '\n')
if is_local_master():
jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
append_to_jsonl(jsonl_path, run_info)
return run_info


Expand Down
1 change: 1 addition & 0 deletions swift/llm/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
val_dataset = LazyLLMDataset(val_dataset, template)
if val_dataset is None:
training_args.evaluation_strategy = IntervalStrategy.NO
training_args.eval_strategy = IntervalStrategy.NO
training_args.do_eval = False

padding_to = args.max_length if args.sft_type == 'longlora' else None
Expand Down
11 changes: 6 additions & 5 deletions swift/llm/simpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from transformers.utils import is_torch_npu_available

from swift.trainers.simpo_trainers import SimPOTrainer
from swift.utils import (check_json_format, get_dist_setting, get_logger, get_main, get_model_info, is_ddp_plus_mp,
is_dist, is_master, plot_images, seed_everything, show_layers)
from swift.utils import (append_to_jsonl, check_json_format, get_dist_setting, get_logger, get_main, get_model_info,
is_ddp_plus_mp, is_dist, is_local_master, is_master, plot_images, seed_everything, show_layers)
from .tuner import prepare_model
from .utils import (SimPOArguments, Template, get_dataset, get_model_tokenizer, get_template, get_time_info,
set_generation_config)
Expand Down Expand Up @@ -145,6 +145,7 @@ def llm_simpo(args: SimPOArguments) -> str:

if val_dataset is None:
training_args.evaluation_strategy = IntervalStrategy.NO
training_args.eval_strategy = IntervalStrategy.NO
training_args.do_eval = False
logger.info(f'train_dataset: {train_dataset}')
logger.info(f'val_dataset: {val_dataset}')
Expand Down Expand Up @@ -215,9 +216,9 @@ def llm_simpo(args: SimPOArguments) -> str:
'model_info': model_info,
'dataset_info': trainer.dataset_info,
}
jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
with open(jsonl_path, 'a', encoding='utf-8') as f:
f.write(json.dumps(run_info) + '\n')
if is_local_master():
jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
append_to_jsonl(jsonl_path, run_info)
return run_info


Expand Down
54 changes: 44 additions & 10 deletions swift/llm/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,7 @@ class LoRATM(NamedTuple):
glm4v = ['self_attention.query_key_value']
phi = ['Wqkv']
phi3 = ['qkv_proj']
phi3_small = ['query_key_value'] # what the hell???
internlm2 = ['wqkv']
mamba = ['in_proj', 'x_proj', 'embeddings', 'out_proj']
telechat = ['key_value', 'query']
Expand Down Expand Up @@ -1605,16 +1606,6 @@ def _output_device_map_hook(module, input, output):
support_vllm=True,
tags=['general'],
hf_model_id='microsoft/Phi-3-mini-128k-instruct')
@register_model(
ModelType.phi3_small_128k_instruct,
'LLM-Research/Phi-3-small-128k-instruct',
LoRATM.phi3,
TemplateType.phi3,
requires=['transformers>=4.36'],
support_flash_attn=True,
support_vllm=True,
tags=['general'],
hf_model_id='microsoft/Phi-3-small-128k-instruct')
@register_model(
ModelType.phi3_medium_128k_instruct,
'LLM-Research/Phi-3-medium-128k-instruct',
Expand Down Expand Up @@ -2361,6 +2352,49 @@ def get_model_tokenizer_with_flash_attn(model_dir: str,
model_dir, torch_dtype, model_kwargs, load_model, model_config=model_config, **kwargs)


@register_model(
ModelType.phi3_small_128k_instruct,
'LLM-Research/Phi-3-small-128k-instruct',
LoRATM.phi3_small,
TemplateType.phi3,
requires=['transformers>=4.36'],
support_flash_attn=True,
support_gradient_checkpointing=False,
support_vllm=True,
tags=['general'],
hf_model_id='microsoft/Phi-3-small-128k-instruct')
def get_model_tokenizer_phi3_small(model_dir: str,
torch_dtype: Dtype,
model_kwargs: Dict[str, Any],
load_model: bool = True,
model_config=None,
**kwargs):
if model_config is None:
model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
use_flash_attn = kwargs.pop('use_flash_attn', False)
if version.parse(transformers.__version__) >= version.parse('4.36'):
if use_flash_attn:
model_config._attn_implementation = 'flash_attention_2'
else:
model_config._flash_attn_2_enabled = use_flash_attn
model, tokenizer = get_model_tokenizer_from_repo(
model_dir, torch_dtype, model_kwargs, load_model, model_config=model_config, **kwargs)

def rotary_emb(self, query_states, key_states, **kwargs):
q_type = query_states.dtype
k_type = key_states.dtype
query_states, key_states = self.rotory_emb_origin(query_states, key_states, **kwargs)
query_states = query_states.to(q_type)
key_states = key_states.to(k_type)
return query_states, key_states

for i in range(32):
re = model.model.layers[i].self_attn.rotary_emb
re.rotory_emb_origin = re.forward
re.forward = MethodType(rotary_emb, re)
return model, tokenizer


@register_model(
ModelType.qwen2_57b_a14b_instruct_int4,
'qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4',
Expand Down
11 changes: 5 additions & 6 deletions swift/trainers/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
TrainerState)
from transformers.trainer_utils import IntervalStrategy, has_length, speed_metrics

from swift.utils import is_pai_training_job, use_torchacc
from swift.utils import append_to_jsonl, is_pai_training_job, use_torchacc
from .arguments import TrainingArguments


Expand Down Expand Up @@ -55,8 +55,7 @@ def on_log(self, args: TrainingArguments, state: TrainerState, control, logs=Non
logs[k] = round(logs[k], 8)
if not is_pai_training_job() and state.is_local_process_zero:
jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
with open(jsonl_path, 'a', encoding='utf-8') as f:
f.write(json.dumps(logs) + '\n')
append_to_jsonl(jsonl_path, logs)
super().on_log(args, state, control, logs, **kwargs)
if state.is_local_process_zero and self.training_bar is not None:
self.training_bar.refresh()
Expand All @@ -67,8 +66,9 @@ class DefaultFlowCallbackNew(DefaultFlowCallback):
def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
control = super().on_step_end(args, state, control, **kwargs)
# save the last ckpt
evaluation_strategy = args.eval_strategy if hasattr(args, 'eval_strategy') else args.evaluation_strategy
if state.global_step == state.max_steps:
if args.evaluation_strategy != IntervalStrategy.NO:
if evaluation_strategy != IntervalStrategy.NO:
control.should_evaluate = True
if args.save_strategy != IntervalStrategy.NO:
control.should_save = True
Expand All @@ -84,8 +84,7 @@ def on_log(self, args, state, control, logs=None, **kwargs):
logs[k] = round(logs[k], 8)
if not is_pai_training_job() and state.is_local_process_zero:
jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
with open(jsonl_path, 'a', encoding='utf-8') as f:
f.write(json.dumps(logs) + '\n')
append_to_jsonl(jsonl_path, logs)

_ = logs.pop('total_flos', None)
if state.is_local_process_zero:
Expand Down

0 comments on commit e994dac

Please sign in to comment.