
Feature/kalman filter #8

Open · wants to merge 5 commits into main
Conversation

@khari998 commented Jun 28, 2024

Added a Kalman Filter option for the gradfilter

Motivation:
If grokfast's performance scales with the quality of the noise filter, then a Kalman filter, which outperforms an EMA at filtering noise from a signal, may also outperform the EMA implementation in the grokfast algorithm.

Note:
This has not yet been tested. Also, sorry for all the annoying formatting changes; I'm using Black as my auto-formatter.

khari998 added 4 commits June 28, 2024 11:56
- Added Kalman based gradfilter ✅
- Added Kalman filter integration into main files ✅
- Updated README with Kalman Filter instructions and information ✅
- Changed naming convention for consistency ✅
@ironjr (Owner) commented Jun 28, 2024

Thank you for the interesting update! I believe this could possibly be another publishable work by itself if the results are promising. I will hold this open and unmerged for now until there is experimental evidence of the benefits. Thanks!

khari998 added a commit:

- Initialize state covariance (P) with measurement noise for better initial estimates ✅
@khari998 (Author) commented Jun 28, 2024

No problem 😁

Also, I made a slight judgment call for this implementation. The original Kalman filter uses covariance matrices for the process noise and measurement noise, which requires matrix operations in the prediction and update steps. The standard Kalman filter equations for the prediction and update steps are as follows:

Prediction step:

```
x_pred = x
P_pred = P + Q
```

Update step:

```
y = z - x_pred
S = P_pred + R
K = P_pred * S^(-1)
x = x_pred + K * y
P = (I - K) * P_pred
```

where:

- x is the state estimate (the filtered gradient) and P is its covariance
- z is the measurement (the incoming raw gradient)
- Q is the process noise covariance matrix
- R is the measurement noise covariance matrix
- S is the innovation covariance matrix
- K is the Kalman gain matrix
- I is the identity matrix

So the original, full-covariance calculation might look something like this:

```python
import torch
import torch.nn as nn
from typing import Dict, Optional


def gradfilter_kalman(
    m: nn.Module,
    grads: Optional[Dict[str, Dict[str, torch.Tensor]]] = None,
    process_noise: float = 1e-4,
    measurement_noise: float = 1e-2,
    lamb: float = 2.0,
) -> Dict[str, Dict[str, torch.Tensor]]:
    # Lazily initialize the filter state per parameter: x is the (flattened)
    # state estimate, P is the state covariance seeded with measurement noise.
    if grads is None:
        grads = {
            n: {
                "x": torch.zeros(p.grad.data.numel(), device=p.grad.data.device),
                "P": torch.eye(p.grad.data.numel(), device=p.grad.data.device)
                * measurement_noise,
            }
            for n, p in m.named_parameters()
            if p.requires_grad and p.grad is not None
        }

    for n, p in m.named_parameters():
        if p.requires_grad and p.grad is not None:
            num = p.grad.data.numel()
            z = p.grad.data.view(-1)  # measurement: the raw gradient
            eye = torch.eye(num, device=z.device, dtype=z.dtype)

            # Prediction step
            x_pred = grads[n]["x"]
            P_pred = grads[n]["P"] + process_noise * eye

            # Update step
            y = z - x_pred                        # innovation
            S = P_pred + measurement_noise * eye  # innovation covariance
            K = P_pred @ torch.linalg.inv(S)      # Kalman gain
            x = x_pred + K @ y
            P = (eye - K) @ P_pred

            # Store updated state
            grads[n]["x"] = x
            grads[n]["P"] = P

            # Apply the filtered gradient
            p.grad.data = p.grad.data + x.view_as(p.grad.data) * lamb

    return grads
```

However, using full covariance matrices in this context leads to substantial computational and memory costs, especially for larger models. For a parameter tensor with n elements, P alone has n^2 entries (a single 1M-parameter layer would need roughly 4 TB of fp32 storage), and the update step involves an O(n^3) matrix inversion.

To address this, I opted to use scalar values for the process noise and measurement noise, treating them as constants shared across all elements of a parameter. This reduces the cost to O(n) per parameter and avoids matrix operations entirely, so the filter stays cheap as the model scales (see the sketch below). While the simplified version does not capture the full covariance information of the standard Kalman filter, I believe it strikes a good balance between computational efficiency and effective gradient filtering: the scalar noise values still let the filter adapt to the statistics of the gradients and smooth them.
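Concretely, the simplified update collapses to elementwise operations with a single scalar variance per parameter. A minimal sketch of that variant (the function name and state layout here are illustrative; the PR code may differ in its details):

```python
import torch
import torch.nn as nn
from typing import Dict, Optional


def gradfilter_kalman_scalar(
    m: nn.Module,
    grads: Optional[Dict[str, Dict]] = None,
    process_noise: float = 1e-4,
    measurement_noise: float = 1e-2,
    lamb: float = 2.0,
) -> Dict[str, Dict]:
    # Per parameter: an elementwise state estimate x and one scalar
    # variance P, seeded with the measurement noise (per the PR's commits).
    if grads is None:
        grads = {
            n: {"x": torch.zeros_like(p.grad.data), "P": measurement_noise}
            for n, p in m.named_parameters()
            if p.requires_grad and p.grad is not None
        }

    for n, p in m.named_parameters():
        if p.requires_grad and p.grad is not None:
            # Prediction step: the variance grows by the process noise.
            x_pred = grads[n]["x"]
            P_pred = grads[n]["P"] + process_noise

            # Update step: a single scalar gain, applied elementwise.
            y = p.grad.data - x_pred
            K = P_pred / (P_pred + measurement_noise)
            grads[n]["x"] = x_pred + K * y
            grads[n]["P"] = (1.0 - K) * P_pred

            # Apply the filtered gradient.
            p.grad.data = p.grad.data + grads[n]["x"] * lamb

    return grads
```

Usage follows the same pattern as the existing EMA filter: call it after `loss.backward()` and before `optimizer.step()`, passing the returned `grads` state back in on the next step.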

If you notice any discrepancies from standard Kalman filter behavior, this simplification is likely the reason. However, I believe the reduced time complexity and improved scalability outweigh the potential drawbacks, especially in the context of machine learning models where efficiency is crucial.

Let me know if you have any further questions or if there's anything else I can clarify!

If you want a visual of how this compares to an EMA, I simulated what it should look like over some fake gradients:

```jsx
import React from 'react';
import { LineChart, Line, XAxis, YAxis, CartesianGrid, Tooltip, Legend, ResponsiveContainer } from 'recharts';

// Random walk with added observation noise, standing in for a gradient trace.
const generateData = (count, min, max, noise) => {
  const data = [];
  let value = Math.random() * (max - min) + min;

  for (let i = 0; i < count; i++) {
    value += Math.random() * (max - min) * 0.1 - (max - min) * 0.05;
    value = Math.max(min, Math.min(max, value));
    data.push({
      x: i,
      y: value + Math.random() * noise - noise / 2,
    });
  }

  return data;
};

// Scalar Kalman filter, matching the simplified update described above.
const kalmanFilter = (data, processNoise, measurementNoise) => {
  let x = data[0].y;
  let p = measurementNoise;

  return data.map((point) => {
    // Prediction step
    const x_pred = x;
    const p_pred = p + processNoise;

    // Update step
    const y = point.y - x_pred;
    const k = p_pred / (p_pred + measurementNoise);
    x = x_pred + k * y;
    p = (1 - k) * p_pred;

    return { x: point.x, kalman: x };
  });
};

// Standard exponential moving average for comparison.
const emaFilter = (data, alpha) => {
  let ema = data[0].y;

  return data.map((point) => {
    ema = alpha * point.y + (1 - alpha) * ema;
    return { x: point.x, ema };
  });
};

const KalmanVsEmaPlot = () => {
  // Note: the data is regenerated on every render; fine for a one-off demo.
  const data = generateData(100, -5, 5, 2);
  const kalmanData = kalmanFilter(data, 0.01, 0.1);
  const emaData = emaFilter(data, 0.1);

  const mergedData = data.map((point, i) => ({
    x: point.x,
    y: point.y,
    kalman: kalmanData[i].kalman,
    ema: emaData[i].ema,
  }));

  return (
    <ResponsiveContainer width="100%" height={400}>
      <LineChart data={mergedData} margin={{ top: 5, right: 30, left: 20, bottom: 5 }}>
        <CartesianGrid strokeDasharray="3 3" />
        <XAxis dataKey="x" />
        <YAxis />
        <Tooltip />
        <Legend />
        <Line type="monotone" dataKey="y" stroke="#8884d8" dot={false} name="Gradient" />
        <Line type="monotone" dataKey="kalman" stroke="#82ca9d" name="Simplified Kalman" />
        <Line type="monotone" dataKey="ema" stroke="#ff7300" name="EMA" />
      </LineChart>
    </ResponsiveContainer>
  );
};

export default KalmanVsEmaPlot;
```

You should be able to see how much more quickly the Kalman filter locks onto the underlying data; the EMA lags much more.
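One way to quantify that: with constant noise values, the scalar Kalman gain converges to a steady state, so in the long run the filter behaves like an EMA whose coefficient equals that steady-state gain. A quick back-of-the-envelope check (the helper below is illustrative, not part of the PR):

```python
import math


def steady_state_gain(q: float, r: float) -> float:
    """Steady-state gain of the scalar constant-noise Kalman filter.

    At the fixed point, P_pred = (1 - K) * P_pred + q with
    K = P_pred / (P_pred + r), which gives P_pred^2 - q*P_pred - q*r = 0.
    """
    p_pred = (q + math.sqrt(q * q + 4.0 * q * r)) / 2.0
    return p_pred / (p_pred + r)


# With the simulation's parameters, the Kalman filter settles at an
# effective EMA coefficient of ~0.27, vs. the plotted EMA's alpha = 0.1,
# which is why it tracks the signal faster.
print(steady_state_gain(0.01, 0.1))  # ~0.270
```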

@Zhi0467 commented Jul 10, 2024

Hi, do you have any results you can share now on whether the Kalman filter eliminates grokking better? @khari998

@khari998 (Author)

> Hi, do you have any results you can share now on whether the Kalman filter eliminates grokking better? @khari998

Unfortunately, no. As I stated earlier, this is untested.

I submitted the feature because the presented results seemed to vary with the choice of noise filter, so I wanted to offer an option that filters noise better than an exponential moving average. I currently don't have access to the compute to run the same experiments myself, so it is open for other researchers to evaluate.

@Zhi0467 commented Jul 10, 2024

> Unfortunately, no. As I stated earlier, this is untested.
>
> I submitted the feature because the presented results seemed to vary with the choice of noise filter, so I wanted to offer an option that filters noise better than an exponential moving average. I currently don't have access to the compute to run the same experiments myself, so it is open for other researchers to evaluate.

Thanks for the response. I think it's an interesting idea, and I have the compute to run it. Do you have suggestions or intuitions for choosing the parameters (lamb, process_noise, measurement_noise, etc.)? Or should I start testing the filter with a grid search over some range, like the sketch below?
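For concreteness, here is the kind of coarse, log-spaced grid around the PR's defaults that I have in mind (the values are just my guesses, not recommendations):

```python
from itertools import product

# Hypothetical starting grid centered on the PR's defaults
# (process_noise=1e-4, measurement_noise=1e-2, lamb=2.0).
lambs = [0.5, 1.0, 2.0, 5.0]
process_noises = [1e-5, 1e-4, 1e-3]
measurement_noises = [1e-3, 1e-2, 1e-1]

# Each combination would be one training run of the grokking benchmark.
for lamb, q, r in product(lambs, process_noises, measurement_noises):
    print(f"run: lamb={lamb}, process_noise={q}, measurement_noise={r}")
```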

@HydrogenBombaklot

Any update here?
