callback.py
import os.path
from typing import Any

import hivemind
import torch
import transformers
from transformers import TrainingArguments

from arguments import TrainingPeerArguments
from task import TrainingTask
from utils import LocalMetrics, logger


class CollaborativeCallback(transformers.TrainerCallback):
"""
This callback monitors and reports collaborative training progress,
In case of a catastrophic failure, it can also revert training to a backup
"""
def __init__(self, task: TrainingTask, args: TrainingPeerArguments):
super().__init__()
self.task = task
self.dht, self.collaborative_optimizer = task.dht, task.collaborative_optimizer
self.statistics_expiration = args.statistics_expiration
self.last_reported_collaboration_step = -1
self.samples = 0
self.steps = 0
self.loss = 0
self.total_samples_processed = 0
self.backup_every_steps = args.backup_every_steps
self.state_path = args.state_path

    def on_train_begin(
        self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
    ):
        # Restore from a local backup first (if any), then pull the latest state from peers.
        # If the local backup is at least as recent as the peers' state, prefer the backup.
        if os.path.isfile(self.state_path):
            self.restore_from_backup(self.state_path)
            logger.info("Loaded state")

        logger.info("Loading state from peers")
        self.collaborative_optimizer.load_state_from_peers()

        if os.path.isfile(self.state_path):
            self.restore_from_backup(self.state_path, check_step=True)

    def on_step_end(
        self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
    ):
        control.should_log = True
        if not self.params_are_finite():
            # NaN/inf parameters: reload the last known-good state instead of continuing.
            if not os.path.exists(self.state_path):
                raise RuntimeError("Encountered broken parameters, but there is no backup to fall back to.")
            logger.warning("Parameters are invalid, reloading model from earlier state")
            self.restore_from_backup(self.state_path)
            return control

        if state.log_history:
            self.loss += state.log_history[-1]["loss"]
            self.steps += 1

            if self.collaborative_optimizer.local_epoch != self.last_reported_collaboration_step:
                # A new collaborative epoch has begun: report local metrics and reset the accumulators.
                self.last_reported_collaboration_step = self.collaborative_optimizer.local_epoch
                self.total_samples_processed += self.samples
                samples_per_second = self.collaborative_optimizer.tracker.performance_ema.samples_per_second
                statistics = LocalMetrics(
                    step=self.collaborative_optimizer.local_epoch,
                    samples_per_second=samples_per_second,
                    samples_accumulated=self.samples,
                    loss=self.loss,
                    mini_steps=self.steps,
                )
                logger.info(f"Current epoch: {self.collaborative_optimizer.local_epoch}")
                logger.info(f"Your current contribution: {self.total_samples_processed} samples")
                logger.info(f"Performance: {samples_per_second} samples/sec")
                if self.steps:
                    logger.info(f"Local loss: {self.loss / self.steps}")

                self.loss = 0
                self.steps = 0
                if self.collaborative_optimizer.local_epoch == self.collaborative_optimizer.tracker.global_epoch:
                    # Only peers that are in sync with the global epoch publish their metrics to the DHT.
                    self.dht.store(
                        key=self.collaborative_optimizer.run_id + "_metrics",
                        subkey=self.task.local_public_key,
                        value=statistics.dict(),
                        expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
                        return_future=True,
                    )
                if self.backup_every_steps is not None and \
                        self.collaborative_optimizer.local_epoch % self.backup_every_steps == 0:
                    self.backup_state()

        self.samples = self.collaborative_optimizer.grad_averager.local_samples_accumulated

        return control

    @torch.no_grad()
def params_are_finite(self):
for param in self.task.model.parameters():
if not torch.all(torch.isfinite(param)):
return False
return True

    @torch.no_grad()
def backup_state(self) -> Any:
logger.info("Saving backup")
return torch.save(
{
"model": self.task.model.state_dict(),
"training": self.collaborative_optimizer.state_dict(),
"scheduler": self.collaborative_optimizer.state_averager.scheduler.state_dict(),
"local_epoch": self.collaborative_optimizer.local_epoch,
},
self.state_path,
)

    @torch.no_grad()
    def restore_from_backup(self, path, check_step=False):
        # With check_step=True, restore only if the backup is at least as recent as our local progress.
        state = torch.load(path)
        current_step = self.collaborative_optimizer.local_epoch
        backup_step = state["local_epoch"]

        if not check_step or backup_step >= current_step:
            self.task.model.load_state_dict(state["model"], strict=False)
            self.collaborative_optimizer.load_state_dict(state["training"])
            self.collaborative_optimizer.state_averager.scheduler.load_state_dict(state["scheduler"])
            self.collaborative_optimizer.state_averager.local_epoch = backup_step
            logger.info("Restored from a backup")
        else:
            logger.info("Bypassed restoring state from local backup: backup state is too old.")