# Copyright (c) 2023 Contextual AI, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Main script for training.
Sample usage:
accelerate launch --config_file fsdp_config.yaml --main_process_port 29501 launch.py loss=kto model=llama datasets=[ultrabin] exp_name=llama3-8B-kto-default mode=train ++cache_dir=/nlp/scr2/kawin/models ++model.name_or_path=meta-llama/Meta-Llama-3-8B
where
- loss should have a file under config/loss that specifies the trainer in trainers.py and dataloader in dataloader.py
- model should have a file under config/model
- datasets is a list of datasets, each of which has a get_{name} function in dataloader.py
- exp_name is the experiment name (on WANDB); the model will be saved to {cache_dir}/{exp_name}
- model.load_from should be used for aligning a model that has already been finetuned
Remember to allocate enough RAM before running this (you need around 800 GB for Llama-13B).
"""
import torch
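# Allow TF32 matmuls on Ampere and newer GPUs: faster matrix multiplies at slightly reduced precision.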
torch.backends.cuda.matmul.allow_tf32 = True
import torch.nn as nn
from train.utils import disable_dropout
from train.models import AutoModelForCausalLMWithValueHead, ReferenceModelWrapper, AutoModelForBradleyTerry
from train import trainers
from train import dataloader
from train import models
import os
import hydra
from omegaconf import OmegaConf, DictConfig
import wandb
import json
from typing import Optional, Set
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from accelerate import Accelerator, DistributedDataParallelKwargs
from peft import LoraConfig, TaskType, get_peft_model, PeftModel
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
def main(config: DictConfig):
"""Main entry point for training. Validates config, creates/initializes model(s), and starts training."""
# Resolve hydra references, e.g. so we don't re-compute the run directory
OmegaConf.resolve(config)
missing_keys: Set[str] = OmegaConf.missing_keys(config)
if missing_keys:
raise ValueError(f"Got missing keys in config:\n{missing_keys}")
set_seed(config.seed)
# Initialize Accelerator
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(
project_dir=config.local_run_dir,
gradient_accumulation_steps=config.model.gradient_accumulation_steps,
kwargs_handlers=[ddp_kwargs]
)
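# Under FSDP, Accelerate needs to know which transformer block class to wrap; this is read from the model config (block_name).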
if accelerator.state.fsdp_plugin is not None:
accelerator.state.fsdp_plugin.transformer_layer_cls_to_wrap = config.model.block_name
# Calculate microbatch sizes
if config.model.batch_size % accelerator.num_processes == 0:
config.model.microbatch_size = config.model.batch_size // accelerator.num_processes
else:
raise ValueError(f"{config.model.batch_size} needs to be divisible by the number of processes")
if config.model.eval_batch_size % accelerator.num_processes == 0:
config.model.eval_microbatch_size = config.model.eval_batch_size // accelerator.num_processes
else:
raise ValueError(f"{config.model.eval_batch_size} needs to be divisible by the number of processes")
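# eval_every must line up with batch boundaries, so round it down to the nearest multiple of batch_size.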
if config.eval_every % config.model.batch_size != 0:
accelerator.print('WARNING: eval_every is not divisible by batch_size; rounding down')
accelerator.print('Setting eval_every to', config.eval_every - config.eval_every % config.model.batch_size)
config.eval_every = config.eval_every - config.eval_every % config.model.batch_size
accelerator.print(OmegaConf.to_yaml(config))
if accelerator.is_main_process:
os.makedirs(config.local_run_dir, exist_ok=True)
accelerator.print("Making experiment directory", config.local_run_dir)
if config.wandb.enabled:
os.environ['WANDB_CACHE_DIR'] = config.cache_dir
wandb.init(
entity=config.wandb.entity,
project=config.wandb.project,
config=OmegaConf.to_container(config),
dir=config.cache_dir,
name=config.exp_name,
)
config_path = os.path.join(config.local_run_dir, 'config.yaml')
with open(config_path, 'w') as f:
OmegaConf.save(config, f)
accelerator.print('=' * 80)
accelerator.print(f'Writing to {config.local_run_dir}')
accelerator.print('=' * 80)
# Prepare tokenizer
tokenizer_name_or_path = config.model.tokenizer_name_or_path or config.model.name_or_path
accelerator.print(f'Loading tokenizer {tokenizer_name_or_path}')
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
special_tokens = []
# Check if the tokenizer has a chat template and set a default one if it doesn't
if not tokenizer.chat_template:
with open("config/template.jinja") as f:
tokenizer.chat_template = f.read()
accelerator.print("Default chat template set.")
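# Some losses define control tokens in their config (under config/loss); register them as additional special tokens so the embeddings can be resized to match.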
control_tokens = list(config.loss.get("control_tokens", {}).values())
special_tokens.extend(control_tokens)
num_tokens_added = tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})
# Create data loaders
accelerator.print('Loading data')
data_loader_class = getattr(dataloader, config.loss.dataloader)
data_iterator_kwargs = dict(
process_index=accelerator.process_index,
num_processes=accelerator.num_processes,
max_length=config.model.max_length,
max_prompt_length=config.model.max_prompt_length,
seed=config.seed,
frac_unique_desirable=config.frac_unique_desirable,
frac_unique_undesirable=config.frac_unique_undesirable,
control_tokens=config.loss.get("control_tokens", {}),
)
train_iterator = data_loader_class(
config.datasets,
tokenizer,
split='train',
microbatch_size=config.model.microbatch_size,
n_epochs=config.n_epochs,
n_examples=config.n_examples,
**data_iterator_kwargs
)
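# The eval split runs for exactly one epoch unless an explicit number of eval examples is requested.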
eval_iterator = data_loader_class(
config.datasets,
tokenizer,
split='test',
microbatch_size=config.model.eval_microbatch_size,
n_examples=config.n_eval_examples,
n_epochs=(1 if config.n_eval_examples is None else None),
**data_iterator_kwargs
)
TrainerClass = getattr(trainers, config.loss.trainer)
# Building reference
if TrainerClass.use_reference_model:
reference_cls = TrainerClass.reference_hf_model_class
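# Flash-attention-style kernels require fp16/bf16 weights; fall back to eager attention for other dtypes.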
reference_kwargs = {
'torch_dtype': getattr(torch, config.model.reference_dtype),
'attn_implementation': config.model.attn_implementation if config.model.reference_dtype in ["float16", "bfloat16"] else "eager",
}
reference_path = config.model.load_from or config.model.name_or_path
accelerator.print(f'Loading reference model from {reference_path}')
reference_model = reference_cls.from_pretrained(reference_path, **reference_kwargs)
if config.model.activation_checkpointing:
reference_model.gradient_checkpointing_enable()
if num_tokens_added:
reference_model.resize_token_embeddings(len(tokenizer))
reference_model.eval()
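# Optionally precompute reference log-probs for the train/eval iterators up front (with a separate Accelerator) and wrap the model so the cached values can be reused.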
if config.cache_reference_logprobs:
reference_accelerator = Accelerator(
project_dir=config.local_run_dir,
gradient_accumulation_steps=config.model.gradient_accumulation_steps,
kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)]
)
if reference_accelerator.state.fsdp_plugin is not None:
reference_accelerator.state.fsdp_plugin.transformer_layer_cls_to_wrap = config.model.block_name
reference_accelerator.print("precomputing logprobs ...")
reference_model = ReferenceModelWrapper(
reference_accelerator,
reference_model,
tokenizer,
config,
iterators=[train_iterator, eval_iterator],
)
else:
reference_model = None
# Building policy
policy_cls = TrainerClass.policy_hf_model_class
policy_kwargs = {
'torch_dtype': getattr(torch, config.model.policy_dtype),
'attn_implementation': config.model.attn_implementation if config.model.policy_dtype in ["float16", "bfloat16"] else "eager",
}
# Load priority: resume from a checkpoint, else a locally finetuned model (load_from), else the pretrained model (name_or_path)
policy_path = config.model.from_checkpoint or config.model.load_from or config.model.name_or_path
accelerator.print(f'Loading policy from {policy_path}')
policy = policy_cls.from_pretrained(policy_path, **policy_kwargs)
if num_tokens_added:
policy.resize_token_embeddings(len(tokenizer))
if config.model.use_peft:
# if there's a value head, then peft should only be applied to the base model
base_model = policy.pretrained_model if TrainerClass.policy_hf_model_class == AutoModelForCausalLMWithValueHead else policy
base_model.enable_input_require_grads()
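# Either resume from previously trained LoRA weights or create a fresh adapter from the peft settings in the model config.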
if config.model.load_lora_from:
peft_model = PeftModel.from_pretrained(
base_model,
config.model.load_lora_from,
torch_dtype=getattr(torch, config.model.policy_dtype)
)
else:
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=config.model.peft.lora_r,
lora_alpha=config.model.peft.lora_alpha,
lora_dropout=config.model.peft.lora_dropout,
bias="none",
target_modules=config.model.peft.target_modules,
inference_mode=False,
)
peft_model = get_peft_model(base_model, peft_config)
# Ensure LoRA layers are in the same dtype as the base model
for name, module in peft_model.named_modules():
if 'lora_' in name:
module.to(getattr(torch, config.model.policy_dtype))
if TrainerClass.policy_hf_model_class == AutoModelForCausalLMWithValueHead:
policy.pretrained_model = peft_model
else:
policy = peft_model
else:
peft_config = None
if config.model.activation_checkpointing:
policy.gradient_checkpointing_enable()
# Loading optimizer, scheduler
accelerator.print("Creating optimizer and scheduler")
optimizer = getattr(torch.optim, config.optimizer)(policy.parameters(), lr=config.lr)
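# LR schedule: linear warmup over warmup_steps, then cosine annealing down to zero over the remaining training steps.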
warmup_scheduler = LinearLR(optimizer, start_factor=0.1, end_factor=1.0, total_iters=config.warmup_steps)
main_scheduler = CosineAnnealingLR(optimizer, T_max=train_iterator.num_training_steps - config.warmup_steps, eta_min=0)
scheduler = SequentialLR(optimizer, schedulers=[warmup_scheduler, main_scheduler], milestones=[config.warmup_steps])
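# When resuming from a checkpoint, restore the optimizer/scheduler state and skip the batches already consumed (inferred from the saved example counter).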
if config.model.from_checkpoint:
optimizer_state = optimizer.state_dict()
optimizer_state.update(torch.load(os.path.join(config.model.from_checkpoint, "optimizer.pt"), map_location='cpu'))
optimizer.load_state_dict(optimizer_state)
scheduler_state = torch.load(os.path.join(config.model.from_checkpoint, "scheduler.pt"))
scheduler.load_state_dict(scheduler_state)
metrics = json.load(open(os.path.join(config.model.from_checkpoint, 'metrics.json')))
num_skip_batches = int(metrics.get('counter', 0) / config.model.batch_size)
else:
num_skip_batches = 0
# Load explicit reward model if necessary (e.g., for PPO)
if config.model.reward_model.path:
accelerator.print(f'Loading reward model from {config.model.reward_model.path}')
reward_tokenizer = AutoTokenizer.from_pretrained(config.model.reward_model.path)
if reward_tokenizer.pad_token_id is None:
reward_tokenizer.pad_token_id = reward_tokenizer.eos_token_id
reward_hf_model_class = getattr(models, config.model.reward_model.model_class)
reward_kwargs = {
'torch_dtype': getattr(torch, config.model.reward_model.dtype),
'attn_implementation': config.model.reward_model.attn_implementation if config.model.reward_model.dtype in ["float16", "bfloat16"] else "eager",
}
reward_model = reward_hf_model_class.from_pretrained(config.model.reward_model.path, **reward_kwargs)
else:
reward_model, reward_tokenizer = None, None
# Initialize trainer
trainer = TrainerClass(
tokenizer,
config,
train_iterator,
eval_iterator,
accelerator,
optimizer,
scheduler,
policy,
reference_model=reference_model,
num_skip_batches=num_skip_batches,
reward_model=reward_model,
reward_tokenizer=reward_tokenizer,
)
trainer.train()
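# Write the final model to <local_run_dir>/FINAL along with the running example counter.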
trainer.save(
os.path.join(config.local_run_dir, 'FINAL'),
metrics={'counter': trainer.example_counter}
)
accelerator.end_training()
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
@hydra.main(version_base=None, config_path="config", config_name="config")
def hydra_main(config: DictConfig):
main(config)
if __name__ == '__main__':
hydra_main()