Feature/kalman filter #8 (Open)

Wants to merge 5 commits into base: main

Changes from all commits
14 changes: 12 additions & 2 deletions README.md
@@ -68,6 +68,8 @@ loss.backward() # Calculate the gradients.
grads = gradfilter_ema(model, grads=grads, alpha=alpha, lamb=lamb)
### Option 2: Grokfast-MA (has arguments window_size, lamb)
# grads = gradfilter_ma(model, grads=grads, window_size=window_size, lamb=lamb)
### Option 3: Grokfast-Kalman (has arguments process_noise, measurement_noise, lamb)
# grads = gradfilter_kalman(model, grads=grads, process_noise=process_noise, measurement_noise=measurement_noise, lamb=lamb)

optimizer.step() # Call the optimizer.
# ... logging & other codes.
@@ -155,6 +157,14 @@ def gradfilter_ma(
- `warmup: bool = True`: If true, filter is not applied until the queue is filled.
- `trigger: bool = False`: For ablation study only. If true, the filter is simply not applied.

3. Grokfast-Kalman (`gradfilter_kalman`)

- `m: nn.Module`: Model that contains every trainable parameter.
- `grads: Optional[Dict[str, Dict[str, torch.Tensor]]] = None`: Running memory (Kalman filter state). Initialize by setting it to `None`, then feed the method's output back in on every subsequent call.
- `process_noise: float = 1e-4`: Process noise parameter for the Kalman filter.
- `measurement_noise: float = 1e-2`: Measurement noise parameter for the Kalman filter.
- `lamb: float = 2.0`: Amplifying factor hyperparameter of the filter. A minimal usage sketch follows this list.
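
A minimal usage sketch of the Kalman option (assumes your own `model`, `optimizer`, `loader`, and `loss_fn`; only `gradfilter_kalman` comes from this PR):

```python
from grokfast import gradfilter_kalman

grads = None  # Kalman filter state; created on the first call.
for inputs, targets in loader:
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()  # Populate p.grad for every trainable parameter.
    # Filter the raw gradients and amplify their slow-varying component.
    grads = gradfilter_kalman(
        model,
        grads=grads,
        process_noise=1e-4,
        measurement_noise=1e-2,
        lamb=2.0,
    )
    optimizer.step()
```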

---

## Reproduction
@@ -242,9 +252,9 @@ python main_qm9.py --label test --alpha 0.9 --lamb 1.0 --weight_decay 0.01
These recommendations are based on my experiences during the experiments shown in the main manuscript. They may not work perfectly for every other problem, and more intelligent techniques may do a better job than this procedure. So, please take these as one possible set of starting guidelines for designing your own filters.


1. **Cutoff parameters**: The work uses MA/EMA/Kalman filters to implement the filtering techniques. The cutoff frequency is determined by the _window size_ for the MA filter, the _momentum parameter_ for the EMA filter, and the _process noise_ and _measurement noise_ for the Kalman filter.
1. **Roughly figure out the amount of acceleration you want to achieve.** For example, in the main manuscript, the cutoff parameters are determined based on the original grokking report, where experiments show generalization happening ×100 slower than overfitting. Therefore, we want *N=100* times faster acceleration.
2. **Set the pivotal values for the cutoff parameter search.** For MA, I started by setting the window size to "w=N=100", and for EMA, I began with the momentum parameter alpha that satisfies "alpha^{N} = alpha^{100} = 0.1" (which is roughly alpha ~ 0.98). For the Kalman filter, start with process_noise=1e-4 and measurement_noise=1e-2. These are reasonable starting points, but you may need to adjust them based on your specific task; see the sketch after this list.
3. **Perform hyperparameter search near the pivot values.** I swept across hyperparameter values near the values set in (1.b).
3. **Weight decay**: The weight decay is set in the optimizer constructor as usual (e.g., `optimizer = optim.Adam(m.parameters(), weight_decay=wd)`).
1. **Start from the default weight decay of that task.** For example, the value chosen by the most widely used Github repository of that task.
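
As a concrete instance of step 1.b, a small sketch of the pivot-value arithmetic (`N`, `window_size`, and `alpha` are illustrative names, not identifiers from this repository):

```python
N = 100  # Target acceleration factor, taken from the grokking report.

window_size = N         # MA pivot: w = N.
alpha = 0.1 ** (1 / N)  # EMA pivot: solve alpha**N = 0.1 for alpha.

print(window_size, round(alpha, 3))  # 100 0.977, i.e. roughly alpha ~ 0.98
```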
60 changes: 54 additions & 6 deletions grokfast.py
@@ -9,16 +9,20 @@ def gradfilter_ma(
grads: Optional[Dict[str, deque]] = None,
window_size: int = 100,
lamb: float = 5.0,
    filter_type: Literal["mean", "sum"] = "mean",
warmup: bool = True,
    trigger: bool = False,  # For ablation study.
) -> Dict[str, deque]:
if grads is None:
        grads = {
            n: deque(maxlen=window_size)
            for n, p in m.named_parameters()
            if p.requires_grad and p.grad is not None
        }

for n, p in m.named_parameters():
if p.requires_grad and p.grad is not None:
            grads[n].append(p.grad.data.detach())  # .cpu())

# Modify the gradients.
if not warmup or len(grads[n]) == window_size and not trigger:
@@ -40,11 +44,55 @@ def gradfilter_ema(
lamb: float = 2.0,
) -> Dict[str, torch.Tensor]:
if grads is None:
        grads = {
            n: p.grad.data.detach()
            for n, p in m.named_parameters()
            if p.requires_grad and p.grad is not None
        }

for n, p in m.named_parameters():
if p.requires_grad and p.grad is not None:
grads[n] = grads[n] * alpha + p.grad.data.detach() * (1 - alpha)
p.grad.data = p.grad.data + grads[n] * lamb

    return grads


def gradfilter_kalman(
m: nn.Module,
grads: Optional[Dict[str, Dict[str, torch.Tensor]]] = None,
process_noise: float = 1e-4,
measurement_noise: float = 1e-2,
lamb: float = 2.0,
) -> Dict[str, Dict[str, torch.Tensor]]:
if grads is None:
grads = {
n: {
"x": torch.zeros_like(p.grad.data),
"P": torch.ones_like(p.grad.data) * measurement_noise,
}
for n, p in m.named_parameters()
if p.requires_grad and p.grad is not None
}

for n, p in m.named_parameters():
if p.requires_grad and p.grad is not None:
# Prediction step
x_pred = grads[n]["x"]
P_pred = grads[n]["P"] + process_noise

# Update step
y = p.grad.data - x_pred
S = P_pred + measurement_noise
K = P_pred / S
x = x_pred + K * y
P = (1 - K) * P_pred

# Store updated state
grads[n]["x"] = x
grads[n]["P"] = P

# Apply the filtered gradient
p.grad.data += x * lamb

return grads
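
For intuition, a self-contained sketch (not part of the PR) that runs the same per-element update on a noisy scalar "gradient" with true mean 1.0. With an identity state-transition model, the filter settles into a fixed-gain exponential smoother whose effective momentum is determined by the ratio of the two noise parameters:

```python
import torch

torch.manual_seed(0)
process_noise, measurement_noise = 1e-4, 1e-2
x = torch.tensor(0.0)                # Filtered estimate ("x" above).
P = torch.tensor(measurement_noise)  # Estimate variance ("P" above).

for _ in range(200):
    g = 1.0 + 0.1 * torch.randn(())  # Noisy measurement of the gradient.
    P_pred = P + process_noise       # Prediction step: state assumed constant.
    K = P_pred / (P_pred + measurement_noise)  # Kalman gain in (0, 1).
    x = x + K * (g - x)              # Update step: correct toward measurement.
    P = (1 - K) * P_pred

print(float(x), float(K))  # x approaches 1.0; K converges to roughly 0.095.
```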
96 changes: 62 additions & 34 deletions main.py
@@ -14,8 +14,7 @@


class Block(nn.Module):
"""Causal transformer block
"""
"""Causal transformer block"""

def __init__(self, dim, num_heads):
super().__init__()
@@ -33,7 +32,7 @@ def forward(self, x):
(len(x), len(x)), -float("Inf"), device=x.device, dtype=x.dtype
)
attn_mask = torch.triu(attn_mask, diagonal=1)
        attn_mask[torch.isnan(attn_mask)] = 0.0  # fixes all 'nan' on 'mps' device

x = self.ln_1(x)
a, _ = self.attn(x, x, x, attn_mask=attn_mask, need_weights=False)
@@ -44,8 +43,7 @@


class Decoder(nn.Module):
"""Causal Transformer decoder
"""
"""Causal Transformer decoder"""

def __init__(self, dim=128, num_layers=2, num_heads=4, num_tokens=97, seq_len=5):
super().__init__()
@@ -71,8 +69,7 @@


def multiplication_mod_p_data(p, eq_token, op_token):
"""x◦y = x/y (mod p) for 0 ≤ x < p, 0 < y < p
"""
"""x◦y = x/y (mod p) for 0 ≤ x < p, 0 < y < p"""
x = torch.arange(p)
y = torch.arange(1, p)
x, y = torch.cartesian_prod(x, y).T
@@ -107,7 +104,7 @@ def main(args):
).to(device)
nparams = sum([p.numel() for p in model.parameters() if p.requires_grad])
print(model)
    print(f"Total number of parameters: {nparams}")

data = multiplication_mod_p_data(args.p, eq_token, op_token)

@@ -170,11 +167,29 @@ def main(args):
if args.filter == "none":
pass
elif args.filter == "ma":
                        grads = gradfilter_ma(
                            model,
                            grads=grads,
                            window_size=args.window_size,
                            lamb=args.lamb,
                            trigger=trigger,
                        )
elif args.filter == "ema":
                        grads = gradfilter_ema(
                            model, grads=grads, alpha=args.alpha, lamb=args.lamb
                        )
elif args.filter == "kal":
grads = gradfilter_kalman(
model,
grads=grads,
process_noise=args.process_noise,
measurement_noise=args.measurement_noise,
lamb=args.lamb,
)
else:
raise ValueError(f"Invalid gradient filter type `{args.filter}`")
raise ValueError(
f"Invalid gradient filter type `{args.filter}`"
)

#######

@@ -194,7 +209,11 @@ def main(args):
val_loss.append(total_loss / valid_data.shape[-1])

if args.save_weights:
            do_save = (
                e <= 500
                or (e > 500 and (e + 1) % 100 == 0)
                or e == int(args.budget) // steps_per_epoch - 1
            )
else:
do_save = (e + 1) % 100 == 0
if do_save:
@@ -222,18 +241,18 @@ def main(args):
plt.close()

results = {
                "its": its,
                "train_acc": train_acc,
                "train_loss": train_loss,
                "val_acc": val_acc,
                "val_loss": val_loss,
}

if args.save_weights:
net_its.append(e)
nets.append(copy.deepcopy(model.state_dict()))
                results["net_its"] = net_its
                results["net"] = nets

torch.save(results, f"results/res_{args.label}.pt")

@@ -252,37 +271,46 @@
parser.add_argument("--optimizer", default="Adam")

# Grokfast
parser.add_argument("--filter", type=str, choices=["none", "ma", "ema", "fir"], default="none")
parser.add_argument(
"--filter", type=str, choices=["none", "ma", "ema", "kal"], default="none"
)
parser.add_argument("--alpha", type=float, default=0.99)
parser.add_argument("--window_size", type=int, default=100)
parser.add_argument("--lamb", type=float, default=5.0)
parser.add_argument("--process_noise", type=float, default=1e-4)
parser.add_argument("--measurement_noise", type=float, default=1e-2)

# Ablation studies
parser.add_argument("--two_stage", action='store_true')
parser.add_argument("--save_weights", action='store_true')
parser.add_argument("--two_stage", action="store_true")
parser.add_argument("--save_weights", action="store_true")
args = parser.parse_args()

    filter_str = ("_" if args.label != "" else "") + args.filter
    window_size_str = f"_w{args.window_size}"
    alpha_str = f"_a{args.alpha:.3f}".replace(".", "")
    lamb_str = f"_l{int(args.lamb)}"

    if args.filter == "none":
        filter_suffix = ""
    elif args.filter == "ma":
filter_suffix = window_size_str + lamb_str
    elif args.filter == "ema":
filter_suffix = alpha_str + lamb_str
elif args.filter == "kal":
filter_suffix = (
f"_p{args.process_noise:.1e}_m{args.measurement_noise:.1e}".replace(".", "")
+ lamb_str
)
else:
raise ValueError(f"Unrecognized filter type {args.filter}")

    optim_suffix = ""
if args.weight_decay != 0:
        optim_suffix = optim_suffix + f"_wd{args.weight_decay:.1e}".replace(".", "")
if args.lr != 1e-3:
        optim_suffix = optim_suffix + f"_lrx{int(args.lr / 1e-3)}"

args.label = args.label + filter_str + filter_suffix + optim_suffix
    print(f"Experiment results saved under name: {args.label}")

main(args)
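
For reference, one way to invoke the new filter with the flags defined above (the values shown are the argparse defaults):

python main.py --label test --filter kal --process_noise 1e-4 --measurement_noise 1e-2 --lamb 2.0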