👯 Standardize model_args (#2442)
* `model_config` -> `model_args`

* sort
qgallouedec authored Dec 10, 2024
1 parent 7ba118a commit 460e780
Showing 20 changed files with 184 additions and 203 deletions.
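All 20 files follow the same mechanical rename: the parsed `ModelConfig` instance is now called `model_args`, matching the existing `script_args`/`training_args` naming. A minimal sketch of the standardized pattern after this commit (the SFT classes are used purely for illustration):

```python
from transformers import HfArgumentParser
from trl import ModelConfig, ScriptArguments, SFTConfig, get_peft_config, get_quantization_config

parser = HfArgumentParser((ScriptArguments, SFTConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_into_dataclasses()

# The ModelConfig instance is consistently named `model_args` and is the
# single argument passed to the TRL helper functions:
quantization_config = get_quantization_config(model_args)
peft_config = get_peft_config(model_args)
```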
22 changes: 11 additions & 11 deletions docs/source/sft_trainer.mdx
```diff
@@ -468,30 +468,30 @@ We included a utility function to create your model.

 from trl import ModelConfig, SFTTrainer, get_kbit_device_map, get_peft_config, get_quantization_config

-model_config = ModelConfig(
+model_args = ModelConfig(
     model_name_or_path="facebook/opt-350m",
     attn_implementation=None,  # or "flash_attention_2"
 )
 torch_dtype = (
-    model_config.torch_dtype
-    if model_config.torch_dtype in ["auto", None]
-    else getattr(torch, model_config.torch_dtype)
+    model_args.torch_dtype
+    if model_args.torch_dtype in ["auto", None]
+    else getattr(torch, model_args.torch_dtype)
 )
-quantization_config = get_quantization_config(model_config)
+quantization_config = get_quantization_config(model_args)
 model_kwargs = dict(
-    revision=model_config.model_revision,
-    trust_remote_code=model_config.trust_remote_code,
-    attn_implementation=model_config.attn_implementation,
+    revision=model_args.model_revision,
+    trust_remote_code=model_args.trust_remote_code,
+    attn_implementation=model_args.attn_implementation,
     torch_dtype=torch_dtype,
     use_cache=False if training_args.gradient_checkpointing else True,
     device_map=get_kbit_device_map() if quantization_config is not None else None,
     quantization_config=quantization_config,
 )
-model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path, **model_kwargs)
+model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)
 trainer = SFTTrainer(
     ...,
-    model=model_config.model_name_or_path,
-    peft_config=get_peft_config(model_config),
+    model=model_args.model_name_or_path,
+    peft_config=get_peft_config(model_args),
 )
```
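The `torch_dtype` expression in this snippet maps a config string such as `"bfloat16"` to the corresponding `torch.dtype` attribute, while passing `"auto"` or `None` through for `from_pretrained` to interpret. A standalone sketch of the same idiom (the function name is ours, for illustration):

```python
import torch

def resolve_torch_dtype(dtype_str):
    """Map a config string to a torch.dtype; leave "auto"/None untouched."""
    if dtype_str in ["auto", None]:
        return dtype_str  # from_pretrained handles "auto" and None itself
    return getattr(torch, dtype_str)  # e.g. "bfloat16" -> torch.bfloat16

assert resolve_torch_dtype("bfloat16") is torch.bfloat16
assert resolve_torch_dtype("auto") == "auto"
```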

8 changes: 4 additions & 4 deletions examples/scripts/cpo.py
```diff
@@ -64,16 +64,16 @@

 if __name__ == "__main__":
     parser = HfArgumentParser((ScriptArguments, CPOConfig, ModelConfig))
-    script_args, training_args, model_config = parser.parse_args_into_dataclasses()
+    script_args, training_args, model_args = parser.parse_args_into_dataclasses()

     ################
     # Model & Tokenizer
     ################
     model = AutoModelForCausalLM.from_pretrained(
-        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code
+        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
     )
     tokenizer = AutoTokenizer.from_pretrained(
-        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code
+        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
     )
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
```

```diff
@@ -94,7 +94,7 @@
         train_dataset=dataset[script_args.dataset_train_split],
         eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
         processing_class=tokenizer,
-        peft_config=get_peft_config(model_config),
+        peft_config=get_peft_config(model_args),
     )

     # train and save the model
```
20 changes: 9 additions & 11 deletions examples/scripts/dpo.py
```diff
@@ -66,37 +66,35 @@

 if __name__ == "__main__":
     parser = TrlParser((ScriptArguments, DPOConfig, ModelConfig))
-    script_args, training_args, model_config = parser.parse_args_and_config()
+    script_args, training_args, model_args = parser.parse_args_and_config()

     ################
     # Model & Tokenizer
     ###################
     torch_dtype = (
-        model_config.torch_dtype
-        if model_config.torch_dtype in ["auto", None]
-        else getattr(torch, model_config.torch_dtype)
+        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
     )
-    quantization_config = get_quantization_config(model_config)
+    quantization_config = get_quantization_config(model_args)
     model_kwargs = dict(
-        revision=model_config.model_revision,
-        attn_implementation=model_config.attn_implementation,
+        revision=model_args.model_revision,
+        attn_implementation=model_args.attn_implementation,
         torch_dtype=torch_dtype,
         use_cache=False if training_args.gradient_checkpointing else True,
         device_map=get_kbit_device_map() if quantization_config is not None else None,
         quantization_config=quantization_config,
     )
     model = AutoModelForCausalLM.from_pretrained(
-        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
+        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
     )
-    peft_config = get_peft_config(model_config)
+    peft_config = get_peft_config(model_args)
     if peft_config is None:
         ref_model = AutoModelForCausalLM.from_pretrained(
-            model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
+            model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
         )
     else:
         ref_model = None
     tokenizer = AutoTokenizer.from_pretrained(
-        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code
+        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
     )
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
```
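(Unchanged by this commit, but worth noting in the hunk above: when `get_peft_config(model_args)` returns a config, the script loads no explicit `ref_model` — `DPOTrainer` can recover the reference log-probs from the PEFT model by disabling the adapter, so only one copy of the base weights is kept in memory.)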
24 changes: 11 additions & 13 deletions examples/scripts/dpo_online.py
```diff
@@ -65,38 +65,36 @@

 if __name__ == "__main__":
     parser = TrlParser((ScriptArguments, OnlineDPOConfig, ModelConfig))
-    script_args, training_args, model_config = parser.parse_args_and_config()
+    script_args, training_args, model_args = parser.parse_args_and_config()
     training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}

     torch_dtype = (
-        model_config.torch_dtype
-        if model_config.torch_dtype in ["auto", None]
-        else getattr(torch, model_config.torch_dtype)
+        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
     )
-    quantization_config = get_quantization_config(model_config)
+    quantization_config = get_quantization_config(model_args)
     model_kwargs = dict(
-        revision=model_config.model_revision,
-        attn_implementation=model_config.attn_implementation,
+        revision=model_args.model_revision,
+        attn_implementation=model_args.attn_implementation,
         torch_dtype=torch_dtype,
         use_cache=False if training_args.gradient_checkpointing else True,
         device_map=get_kbit_device_map() if quantization_config is not None else None,
         quantization_config=quantization_config,
     )

     model = AutoModelForCausalLM.from_pretrained(
-        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
+        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
     )

     if training_args.reward_model_path is not None:
         reward_model = AutoModelForSequenceClassification.from_pretrained(
             training_args.reward_model_path,
             num_labels=1,
-            trust_remote_code=model_config.trust_remote_code,
+            trust_remote_code=model_args.trust_remote_code,
             **model_kwargs,
         )
         reward_tokenizer = AutoTokenizer.from_pretrained(
             training_args.reward_model_path,
-            trust_remote_code=model_config.trust_remote_code,
+            trust_remote_code=model_args.trust_remote_code,
             truncation=True,
             truncation_side="left",  # since we judge the completion, truncating left is more appropriate
         )
```

```diff
@@ -111,9 +109,9 @@
         judge = None

     tokenizer = AutoTokenizer.from_pretrained(
-        model_config.model_name_or_path,
+        model_args.model_name_or_path,
         padding_side="left",
-        trust_remote_code=model_config.trust_remote_code,
+        trust_remote_code=model_args.trust_remote_code,
         **model_kwargs,
     )
     if tokenizer.chat_template is None:
```

```diff
@@ -132,7 +130,7 @@
         eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
         processing_class=tokenizer,
         reward_processing_class=reward_tokenizer,
-        peft_config=get_peft_config(model_config),
+        peft_config=get_peft_config(model_args),
     )

     if training_args.eval_strategy != "no":
```
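(A detail the rename preserves: the policy tokenizer is loaded with `padding_side="left"` because online DPO generates completions during training, and decoder-only models must be left-padded so generation continues directly from the prompt; the reward tokenizer likewise truncates on the left to keep the completion, as the inline comment notes.)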
26 changes: 12 additions & 14 deletions examples/scripts/dpo_vlm.py
```diff
@@ -45,42 +45,40 @@

 if __name__ == "__main__":
     parser = TrlParser((ScriptArguments, DPOConfig, ModelConfig))
-    script_args, training_args, model_config = parser.parse_args_and_config()
+    script_args, training_args, model_args = parser.parse_args_and_config()

     ################
     # Model & Tokenizer
     ################
     torch_dtype = (
-        model_config.torch_dtype
-        if model_config.torch_dtype in ["auto", None]
-        else getattr(torch, model_config.torch_dtype)
+        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
     )
-    quantization_config = get_quantization_config(model_config)
+    quantization_config = get_quantization_config(model_args)

     model_kwargs = dict(
-        revision=model_config.model_revision,
-        attn_implementation=model_config.attn_implementation,
+        revision=model_args.model_revision,
+        attn_implementation=model_args.attn_implementation,
         torch_dtype=torch_dtype,
         device_map=get_kbit_device_map() if quantization_config is not None else None,
         quantization_config=quantization_config,
     )
     model = AutoModelForVision2Seq.from_pretrained(
-        model_config.model_name_or_path,
-        trust_remote_code=model_config.trust_remote_code,
+        model_args.model_name_or_path,
+        trust_remote_code=model_args.trust_remote_code,
         **model_kwargs,
     )
-    peft_config = get_peft_config(model_config)
+    peft_config = get_peft_config(model_args)
     if peft_config is None:
         ref_model = AutoModelForVision2Seq.from_pretrained(
-            model_config.model_name_or_path,
-            trust_remote_code=model_config.trust_remote_code,
+            model_args.model_name_or_path,
+            trust_remote_code=model_args.trust_remote_code,
             **model_kwargs,
         )
     else:
         ref_model = None
     processor = AutoProcessor.from_pretrained(
-        model_config.model_name_or_path,
-        trust_remote_code=model_config.trust_remote_code,
+        model_args.model_name_or_path,
+        trust_remote_code=model_args.trust_remote_code,
         do_image_splitting=False,
     )
     tokenizer = processor.tokenizer
```
30 changes: 15 additions & 15 deletions examples/scripts/gkd.py
```diff
@@ -64,38 +64,38 @@

 if __name__ == "__main__":
     parser = TrlParser((ScriptArguments, GKDConfig, ModelConfig))
-    script_args, training_args, model_config = parser.parse_args_and_config()
+    script_args, training_args, model_args = parser.parse_args_and_config()

     ################
     # Model & Tokenizer
     ################
-    quantization_config = get_quantization_config(model_config)
+    quantization_config = get_quantization_config(model_args)
     model_kwargs = dict(
-        revision=model_config.model_revision,
-        trust_remote_code=model_config.trust_remote_code,
-        attn_implementation=model_config.attn_implementation,
-        torch_dtype=model_config.torch_dtype,
+        revision=model_args.model_revision,
+        trust_remote_code=model_args.trust_remote_code,
+        attn_implementation=model_args.attn_implementation,
+        torch_dtype=model_args.torch_dtype,
         use_cache=False if training_args.gradient_checkpointing else True,
         device_map=get_kbit_device_map() if quantization_config is not None else None,
         quantization_config=quantization_config,
     )
     training_args.model_init_kwargs = model_kwargs

     teacher_model_kwargs = dict(
-        revision=model_config.model_revision,
-        trust_remote_code=model_config.trust_remote_code,
-        attn_implementation=model_config.attn_implementation,
-        torch_dtype=model_config.torch_dtype,
+        revision=model_args.model_revision,
+        trust_remote_code=model_args.trust_remote_code,
+        attn_implementation=model_args.attn_implementation,
+        torch_dtype=model_args.torch_dtype,
         use_cache=True,
         device_map=get_kbit_device_map() if quantization_config is not None else None,
         quantization_config=quantization_config,
     )
     training_args.teacher_model_init_kwargs = teacher_model_kwargs

     tokenizer = AutoTokenizer.from_pretrained(
-        model_config.model_name_or_path,
-        revision=model_config.model_revision,
-        trust_remote_code=model_config.trust_remote_code,
+        model_args.model_name_or_path,
+        revision=model_args.model_revision,
+        trust_remote_code=model_args.trust_remote_code,
         padding_side="left",
     )
     if tokenizer.pad_token is None:
```

```diff
@@ -118,13 +118,13 @@
     # Training
     ################
     trainer = GKDTrainer(
-        model=model_config.model_name_or_path,
+        model=model_args.model_name_or_path,
         teacher_model=training_args.teacher_model_name_or_path,
         args=training_args,
         train_dataset=dataset[script_args.dataset_train_split],
         eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
         processing_class=tokenizer,
-        peft_config=get_peft_config(model_config),
+        peft_config=get_peft_config(model_args),
    )

     if training_args.eval_strategy != "no":
```
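(Note the structure the rename leaves intact here: `GKDTrainer` receives the student and teacher as name-or-path strings, and the actual `from_pretrained` keyword arguments travel through `training_args.model_init_kwargs` and `training_args.teacher_model_init_kwargs`, so the trainer instantiates both models itself — with `use_cache=True` only for the frozen teacher, which is never trained.)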
22 changes: 10 additions & 12 deletions examples/scripts/nash_md.py
```diff
@@ -70,36 +70,34 @@

 if __name__ == "__main__":
     parser = TrlParser((ScriptArguments, NashMDConfig, ModelConfig))
-    script_args, training_args, model_config = parser.parse_args_and_config()
+    script_args, training_args, model_args = parser.parse_args_and_config()
     training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}

     torch_dtype = (
-        model_config.torch_dtype
-        if model_config.torch_dtype in ["auto", None]
-        else getattr(torch, model_config.torch_dtype)
+        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
     )
-    quantization_config = get_quantization_config(model_config)
+    quantization_config = get_quantization_config(model_args)
     model_kwargs = dict(
-        revision=model_config.model_revision,
-        attn_implementation=model_config.attn_implementation,
+        revision=model_args.model_revision,
+        attn_implementation=model_args.attn_implementation,
         torch_dtype=torch_dtype,
         use_cache=False if training_args.gradient_checkpointing else True,
         device_map=get_kbit_device_map() if quantization_config is not None else None,
         quantization_config=quantization_config,
     )

     model = AutoModelForCausalLM.from_pretrained(
-        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
+        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
     )
     ref_model = AutoModelForCausalLM.from_pretrained(
-        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
+        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
     )

     if training_args.reward_model_path is not None:
         reward_model = AutoModelForSequenceClassification.from_pretrained(
             training_args.reward_model_path,
             num_labels=1,
-            trust_remote_code=model_config.trust_remote_code,
+            trust_remote_code=model_args.trust_remote_code,
             **model_kwargs,
         )
     else:
```

```diff
@@ -112,9 +110,9 @@
         judge = None

     tokenizer = AutoTokenizer.from_pretrained(
-        model_config.model_name_or_path,
+        model_args.model_name_or_path,
         padding_side="left",
-        trust_remote_code=model_config.trust_remote_code,
+        trust_remote_code=model_args.trust_remote_code,
     )
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
```
8 changes: 4 additions & 4 deletions examples/scripts/orpo.py
```diff
@@ -64,16 +64,16 @@

 if __name__ == "__main__":
     parser = HfArgumentParser((ScriptArguments, ORPOConfig, ModelConfig))
-    script_args, training_args, model_config = parser.parse_args_into_dataclasses()
+    script_args, training_args, model_args = parser.parse_args_into_dataclasses()

     ################
     # Model & Tokenizer
     ################
     model = AutoModelForCausalLM.from_pretrained(
-        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code
+        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
     )
     tokenizer = AutoTokenizer.from_pretrained(
-        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code
+        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
     )
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
```

```diff
@@ -94,7 +94,7 @@
         train_dataset=dataset[script_args.dataset_train_split],
         eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
         processing_class=tokenizer,
-        peft_config=get_peft_config(model_config),
+        peft_config=get_peft_config(model_args),
     )

     # train and save the model
```
(The remaining 12 changed files are not rendered here; they apply the same `model_config` → `model_args` rename.)
