Feature/kalman filter #8
base: main
Conversation
Thank you for the interesting update! I believe this could be another publishable work by itself if the results are promising. I will hold this open and unmerged for now until there is experimental evidence of the benefits. Thanks!
No problem 😁 Also, I made a slight judgment call for this implementation. The original Kalman filter calculation uses covariance matrices for the process noise and measurement noise, which results in matrix operations during the prediction and update steps. The standard Kalman filter equations for the prediction and update steps are as follows:

Prediction step:
x_pred = x
P_pred = P + Q

Update step:
y = z - x_pred
S = P_pred + R
K = P_pred S^-1
x = x_pred + K y
P = (I - K) P_pred

where:
x is the state estimate (the filtered gradient),
P is the state covariance,
z is the measurement (the raw gradient),
Q is the process noise covariance matrix,
R is the measurement noise covariance matrix,
K is the Kalman gain, and
I is the identity matrix.

So the original calculation may look something like:

from typing import Dict, Optional

import torch
import torch.nn as nn


def gradfilter_kalman(
    m: nn.Module,
    grads: Optional[Dict[str, Dict[str, torch.Tensor]]] = None,
    process_noise: float = 1e-4,
    measurement_noise: float = 1e-2,
    lamb: float = 2.0,
) -> Dict[str, Dict[str, torch.Tensor]]:
    if grads is None:
        # Initialize the flattened state estimate x and the full covariance P per parameter.
        grads = {
            n: {
                "x": torch.zeros_like(p.grad.data).flatten(),
                "P": torch.eye(p.grad.data.numel()) * measurement_noise,
            }
            for n, p in m.named_parameters()
            if p.requires_grad and p.grad is not None
        }

    for n, p in m.named_parameters():
        if p.requires_grad and p.grad is not None:
            z = p.grad.data.flatten()
            eye = torch.eye(z.numel())

            # Prediction step
            x_pred = grads[n]["x"]
            P_pred = grads[n]["P"] + process_noise * eye

            # Update step
            y = z - x_pred
            S = P_pred + measurement_noise * eye
            K = P_pred @ torch.linalg.inv(S)
            x = x_pred + K @ y
            P = (eye - K) @ P_pred

            # Store updated state
            grads[n]["x"] = x
            grads[n]["P"] = P

            # Apply the filtered gradient
            p.grad.data = p.grad.data + x.view_as(p.grad.data) * lamb
    return grads

However, using covariance matrices in this context can lead to increased computational complexity and memory usage, especially for larger models with a high number of parameters. The matrix operations in the prediction and update steps would cost at least O(n^2) time and memory per parameter tensor, where n is the number of elements in that tensor.

To address this, I have opted to use scalar values for the process noise and measurement noise, treating them as constants across all parameters. This simplification reduces the cost to O(n) per parameter tensor and avoids the need for matrix operations, making the calculation more efficient as the model size scales up. (A minimal sketch of this scalar version appears after the plot code below.)

So while my simplified version may not capture the full covariance information as in the standard Kalman filter, I believe it provides a good balance between computational efficiency and the ability to filter gradients effectively. The scalar noise values still allow the filter to adapt to the characteristics of the gradients and provide smoothing. If you notice any discrepancies with the standard Kalman filter behavior, this simplification may be the reason why. However, I believe the benefits in terms of reduced time complexity and improved scalability outweigh the potential drawbacks, especially for machine learning models where efficiency is crucial.

Let me know if you have any further questions or if there's anything else I can clarify!

If you need a visual for how this compares to an EMA, I simulated a graph of what it should look like over some fake gradients:

import React from 'react';
import { LineChart, Line, XAxis, YAxis, CartesianGrid, Tooltip, Legend, ResponsiveContainer } from 'recharts';
const generateData = (count, min, max, noise) => {
const data = [];
let value = Math.random() * (max - min) + min;
for (let i = 0; i < count; i++) {
value += Math.random() * (max - min) * 0.1 - (max - min) * 0.05;
value = Math.max(min, Math.min(max, value));
data.push({
x: i,
y: value + Math.random() * noise - noise / 2,
});
}
return data;
};
const kalmanFilter = (data, processNoise, measurementNoise) => {
let x = data[0].y;
let p = measurementNoise;
return data.map((point) => {
const x_pred = x;
const p_pred = p + processNoise;
const y = point.y - x_pred;
const k = p_pred / (p_pred + measurementNoise);
x = x_pred + k * y;
p = (1 - k) * p_pred;
return { x: point.x, kalman: x };
});
};
const emaFilter = (data, alpha) => {
let ema = data[0].y;
return data.map((point) => {
ema = alpha * point.y + (1 - alpha) * ema;
return { x: point.x, ema };
});
};
const KalmanVsEmaPlot = () => {
const data = generateData(100, -5, 5, 2);
const kalmanData = kalmanFilter(data, 0.01, 0.1);
const emaData = emaFilter(data, 0.1);
const mergedData = data.map((point, i) => ({
x: point.x,
y: point.y,
kalman: kalmanData[i].kalman,
ema: emaData[i].ema,
}));
return (
<ResponsiveContainer width="100%" height={400}>
<LineChart data={mergedData} margin={{ top: 5, right: 30, left: 20, bottom: 5 }}>
<CartesianGrid strokeDasharray="3 3" />
<XAxis dataKey="x" />
<YAxis />
<Tooltip />
<Legend />
<Line type="monotone" dataKey="y" stroke="#8884d8" dot={false} name="Gradient" />
<Line type="monotone" dataKey="kalman" stroke="#82ca9d" name="Simplified Kalman" />
<Line type="monotone" dataKey="ema" stroke="#ff7300" name="EMA" />
</LineChart>
</ResponsiveContainer>
);
};
export default KalmanVsEmaPlot;

You should be able to see how much more quickly the Kalman filter fits to the underlying data; EMAs lag much more.
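To make the simplification concrete, here is a minimal sketch of the scalar version (naming and details here are illustrative; see the actual diff for the implementation). The state "P" becomes a per-element variance instead of a covariance matrix, so every operation is element-wise:

from typing import Dict, Optional

import torch
import torch.nn as nn


def gradfilter_kalman_scalar(
    m: nn.Module,
    grads: Optional[Dict[str, Dict[str, torch.Tensor]]] = None,
    process_noise: float = 1e-4,
    measurement_noise: float = 1e-2,
    lamb: float = 2.0,
) -> Dict[str, Dict[str, torch.Tensor]]:
    # Scalar-noise Kalman filter applied independently to every gradient element.
    if grads is None:
        grads = {
            n: {
                "x": torch.zeros_like(p.grad.data),                    # state estimate
                "P": torch.full_like(p.grad.data, measurement_noise),  # per-element variance
            }
            for n, p in m.named_parameters()
            if p.requires_grad and p.grad is not None
        }

    for n, p in m.named_parameters():
        if p.requires_grad and p.grad is not None:
            # Prediction step (random-walk model: the state carries over unchanged)
            x_pred = grads[n]["x"]
            P_pred = grads[n]["P"] + process_noise

            # Update step with scalar measurement noise
            y = p.grad.data - x_pred
            K = P_pred / (P_pred + measurement_noise)
            x = x_pred + K * y
            P = (1.0 - K) * P_pred

            # Store updated state
            grads[n]["x"] = x
            grads[n]["P"] = P

            # Amplify the gradient with its filtered (low-frequency) component
            p.grad.data = p.grad.data + x * lamb

    return grads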
Hi, do you have any results you can share now on whether the Kalman filter eliminates grokking better? @khari998
Unfortunately no. I stated earlier that this is untested. I submitted the feature because the presented results seemed to differ depending on the noise filter, so I wanted to offer an option that filters noise better than an exponential moving average. I currently don't have access to compute to run the same experiments myself, so it is open for any other researchers to evaluate.
Thanks for the response, I think it's an interesting idea and I have the compute to run it. Do you have suggestions/intuitions for choosing the parameters (lamb, process_noise, measurement_noise, etc.)? Or should I do a grid search over some range to start testing the filter?
Any update here? |
Added a Kalman Filter option for the gradfilter
Motivation:
If performance scales with better noise filters, a Kalman filter may offer better performance for the grokfast algorithm than the EMA implementation, since Kalman filters generally outperform EMAs at filtering noise in datasets.
Note:
Has not yet been tested. Also, sorry for all the annoying formatting changes; I'm using black as my auto-formatter.
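For anyone who wants to try it, here is a minimal usage sketch. It assumes the new filter is exposed alongside the existing filters (from grokfast import gradfilter_kalman) and follows the same calling convention as gradfilter_ema: apply it after loss.backward() and before optimizer.step(). The toy model, data, and hyperparameter values below are placeholders.

import torch
import torch.nn as nn

from grokfast import gradfilter_kalman

# Toy model and data, purely for illustration.
model = nn.Linear(10, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()
inputs = torch.randn(256, 10)
targets = torch.randn(256, 1)

grads = None  # filter state carried across optimization steps
for step in range(100):
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    # Filter (and amplify) the gradients in place, where gradfilter_ema would normally go.
    grads = gradfilter_kalman(
        model,
        grads=grads,
        process_noise=1e-4,
        measurement_noise=1e-2,
        lamb=2.0,
    )
    optimizer.step()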