Add resume for adapter_v2, enable continued finetuning for adapter #1354

Open · wants to merge 2 commits into base: main

Changes from all commits
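This PR adds a resume flag to the adapter_v2 finetuning script: when enabled, the newest step-*/*.pth.adapter_v2 checkpoint under out_dir is overlaid on the base model weights and training continues from the recovered step count; otherwise behaviour is unchanged. A minimal usage sketch, assuming setup keeps its usual checkpoint_dir/out_dir parameters (the paths below are placeholders):

from pathlib import Path

from litgpt.finetune.adapter_v2 import setup

# Continue a previous adapter_v2 finetuning run: with resume=True the newest
# step-*/*.pth.adapter_v2 file in out_dir is loaded on top of the base checkpoint.
setup(
    checkpoint_dir=Path("checkpoints/base-model"),  # placeholder: pretrained base checkpoint
    out_dir=Path("out/finetune/adapter-v2"),        # placeholder: out_dir of the earlier run
    resume=True,
)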
35 changes: 30 additions & 5 deletions litgpt/finetune/adapter_v2.py
@@ -24,6 +24,7 @@
from litgpt.utils import (
CLI,
CycleIterator,
load_checkpoint_update,
check_valid_checkpoint_dir,
choose_logger,
chunked_cross_entropy,
@@ -43,6 +44,7 @@ def setup(
precision: Optional[str] = None,
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None,
devices: Union[int, str] = 1,
resume: Optional[bool] = False,
data: Optional[DataModule] = None,
train: TrainArgs = TrainArgs(
save_interval=1000,
@@ -110,7 +112,7 @@ def setup(
strategy = "auto"

fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger, plugins=plugins)
fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval)
fabric.launch(main, devices, seed, config, data, resume, checkpoint_dir, out_dir, train, eval)


def main(
@@ -119,6 +121,7 @@ def main(
seed: int,
config: Config,
data: DataModule,
resume: bool,
checkpoint_dir: Path,
out_dir: Path,
train: TrainArgs,
@@ -149,7 +152,6 @@ def main(
trainable_params = [p for p in model.parameters() if p.requires_grad]
if isinstance(fabric.strategy.precision, BitsandbytesPrecision):
import bitsandbytes as bnb

optimizer_cls = bnb.optim.PagedAdamW
else:
optimizer_cls = torch.optim.AdamW
@@ -158,10 +160,23 @@
)
optimizer = fabric.setup_optimizers(optimizer)
scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps)

    if resume:
        # Find the most recent adapter checkpoint from a previous run
        try:
            adapter_path = max(out_dir.rglob("step-*/*.pth.adapter_v2"), key=lambda p: int(p.parent.name.split("-")[1]))
            fabric.print(f"Resuming training from {adapter_path}")
            load_checkpoint_update(fabric, adapter_path, model, checkpoint_path, strict=False)
            resume = True
        except ValueError:
            fabric.print("No previous adapter checkpoint found. Finetuning from the start.")
            resume = False
            load_checkpoint(fabric, model, checkpoint_path, strict=False)
    else:
        # strict=False because the adapter weights are not contained in the base checkpoint's state dict
        load_checkpoint(fabric, model, checkpoint_path, strict=False)


mark_only_adapter_v2_as_trainable(model)

train_time = time.perf_counter()
fit(
fabric,
@@ -171,6 +186,7 @@
train_dataloader,
val_dataloader,
devices,
resume,
checkpoint_dir,
out_dir,
train,
@@ -206,6 +222,7 @@ def fit(
train_dataloader: DataLoader,
val_dataloader: DataLoader,
devices: int,
resume: bool,
checkpoint_dir: Path,
out_dir: Path,
train: TrainArgs,
@@ -234,6 +251,14 @@ def fit(
total_t0 = time.perf_counter()
val_loss = "n/a"

    if resume:
        try:
            # Recover the step count from the most recent "step-N" adapter checkpoint directory
            iter_match = max(out_dir.rglob("step-*/*.pth.adapter_v2"), key=lambda p: int(p.parent.name.split("-")[1]))
            step_count = int(iter_match.parent.name.split("-")[1])
        except ValueError:
            step_count = 0

    fabric.print(f"Starting at step count {step_count}")
while step_count < max_steps and train_iterator.epoch < train.epochs:
iter_num += 1
iter_t0 = time.perf_counter()
10 changes: 10 additions & 0 deletions litgpt/utils.py
@@ -324,6 +324,16 @@ def load_checkpoint(fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, s
state_dict = state_dict.get("model", state_dict)
model.load_state_dict(state_dict, strict=strict)

def load_checkpoint_update(fabric: L.Fabric, adapter_path: Path, model: nn.Module, checkpoint_path: Path, strict: bool = True) -> None:
    """Load the base checkpoint and overlay previously saved adapter weights on top of it."""
    if isinstance(fabric.strategy, FSDPStrategy):
        # NOTE: under FSDP only the base checkpoint is loaded here; the adapter weights are not merged in.
        fabric.load_raw(checkpoint_path, model, strict=strict)
    else:
        state_dict = lazy_load(checkpoint_path)
        state_dict = state_dict.get("model", state_dict)
        # Adapter weights from the previous run take precedence over the matching base entries
        adapter_cp = lazy_load(adapter_path)
        state_dict.update(adapter_cp)
        model.load_state_dict(state_dict, strict=strict)


def flops_per_param(max_seq_length: int, n_layer: int, n_embd: int, n_params: int) -> int:
flops_per_token = 2 * n_params # each parameter is used for a MAC (2 FLOPS) per network operation
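For reference, the merge performed by load_checkpoint_update is a plain state-dict overlay: dict.update lets the saved adapter tensors win for the keys they share with the base checkpoint, while every other weight comes from the base state dict, and the same "step-N" directory-name parsing drives the step-count recovery in fit. A self-contained sketch of both pieces (the key names and directory layout below are made up for illustration):

from pathlib import Path

import torch


def overlay_adapter(base_sd: dict, adapter_sd: dict) -> dict:
    # Adapter entries replace the matching base entries; all other keys are kept as-is.
    merged = dict(base_sd)
    merged.update(adapter_sd)
    return merged


def latest_step(out_dir: Path) -> int:
    # Mirror the resume logic: pick the highest "step-N" directory containing an adapter file.
    try:
        newest = max(out_dir.rglob("step-*/*.pth.adapter_v2"),
                     key=lambda p: int(p.parent.name.split("-")[1]))
        return int(newest.parent.name.split("-")[1])
    except ValueError:  # no adapter checkpoints found -> start from step 0
        return 0


base = {"transformer.wte.weight": torch.zeros(2, 2), "adapter.bias": torch.zeros(2)}  # illustrative keys
adapter = {"adapter.bias": torch.ones(2)}                                             # saved adapter weight
merged = overlay_adapter(base, adapter)
assert torch.equal(merged["adapter.bias"], torch.ones(2))   # the resumed adapter value wins
assert latest_step(Path("nonexistent-out-dir")) == 0        # falls back to step 0 when nothing is found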