
Question regarding OOM issues #182

Open
MichaelMMeskhi opened this issue Jun 23, 2023 · 3 comments

MichaelMMeskhi commented Jun 23, 2023

I am trying to improve upon this paper https://arxiv.org/pdf/2011.00050.pdf, where they optimize some subset using the NTK. They optimize their loss in batches, using smaller batches for more complex architectures (e.g. Conv, Myrtle). In my case, I am unable to optimize in batches and have to load the entire dataset into memory. For instance, I am optimizing a subset of 100 parameters, where my kernels can be of size (100, 60000). Using a single FC NTK of width 1024, things run fine. But when I try to use a 2-layer Conv of width 64, OOM errors occur. Are there ways to reduce the memory footprint via NTK?

Their code is here: kip_open_source.ipynb (Colaboratory).

Thank you

@mohamad-amin

Hey Michael, could you please explain what the 60,000 is in "my kernels can be of size (100, 60000)"? And by kernels, do you mean NTKs? As far as I know, when computing infinite-width NTKs (as is the case in the notebook you shared), the width of the network doesn't affect the cost (neither computational nor memory) of computing the kernels. For instance, you can see that kernel_fn doesn't take the network's parameters as input.
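To illustrate that point, here's a minimal sketch with neural-tangents (the layer widths and input shapes are placeholders, not your actual setup):

```python
import jax.numpy as jnp
from neural_tangents import stax

# Infinite-width network: kernel_fn is derived from the architecture alone.
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(1024), stax.Relu(),
    stax.Dense(1),
)

x1 = jnp.ones((100, 784))   # e.g. the learned support set
x2 = jnp.ones((500, 784))   # a chunk of the training data

# No weights are passed in: the analytic NTK depends only on the inputs
# and the architecture, not on any finite-width parameters.
k_ntk = kernel_fn(x1, x2, 'ntk')   # shape (100, 500)
```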

As far as I know, there's no easy way to reduce the memory footprint of computing infinite-width NTKs, except maybe deriving the exact NTK formulas analytically and computing them directly (as they did in https://github.com/LeoYu/neural-tangent-kernel-UCI for convolutional NTKs), as opposed to this repo, which computes them compositionally (for the sake of generality).

However, I would suggest using empirical NTKs instead of infinite-width NTKs. Regarding the work you linked in particular, if I understand things correctly, they treat the network as the fixed object and the data as the trainable parameters, as opposed to the data as the fixed object and the network's weights as the trainable parameters. In this case, I highly suspect that using an empirical NTK with trained weights at the end of the training procedure would produce better results than using the infinite-width NTK, since the generalization of a finite-width network at the end of (proper) training is often better than that of the corresponding infinite-width network.
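As a rough sketch of what I mean (not from the KIP code; apply_fn and trained_params below stand in for your finite-width network and its weights at the end of training):

```python
import neural_tangents as nt

# Empirical (finite-width) NTK, evaluated at the trained parameters.
ntk_fn = nt.empirical_ntk_fn(apply_fn, vmap_axes=0)

# Kernel between the learned support set and a batch of training data.
k_emp = ntk_fn(x_support, x_train_batch, trained_params)
```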

If you decide to use empirical NTKs, I would again suggest using the pseudo-NTK (https://proceedings.mlr.press/v202/mohamadi23a.html), which approximates the empirical NTK almost perfectly at the end of training and is orders of magnitude cheaper in both computation and memory. The paper shows that you can use the pNTK to compute full 50,000 x 50,000 kernels on datasets like CIFAR-10 with a ResNet-18 network on a machine reasonably available in academia.
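Roughly, the idea (this is only a sketch of the definition, not the authors' code) is to replace the per-logit Jacobian with the gradient of the sum of the K logits scaled by 1/sqrt(K), so each example contributes a single gradient vector:

```python
import jax
import jax.numpy as jnp

def pntk(apply_fn, params, x1, x2, num_classes):
    # Scalar surrogate output: sum of the K logits, scaled by 1/sqrt(K).
    def scalar_out(p, xi):
        return jnp.sum(apply_fn(p, xi[None])) / jnp.sqrt(num_classes)

    def features(x):
        # One gradient pytree per example, flattened into a vector.
        grads = jax.vmap(lambda xi: jax.grad(scalar_out)(params, xi))(x)
        return jax.vmap(lambda g: jnp.concatenate(
            [jnp.ravel(leaf) for leaf in jax.tree_util.tree_leaves(g)]))(grads)

    # A single scalar per pair of examples, instead of a K x K block.
    return features(x1) @ features(x2).T   # shape (len(x1), len(x2))
```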

Let me know if it helps!

@MichaelMMeskhi (Author)

Hi @mohamad-amin, thank you for your feedback. I will definitely look into that, but at the moment I have to finalize the project as is.

Looking into the code more closely, I understand that the limitation isn't in computing k(x, x) but rather in doing backprop. If I understand correctly, nt.batch is mainly for the kernel computation (forward pass). Is there anything to break up the gradient calculation within neural-tangents? If not, I assume that is something to be done via JAX.

@mohamad-amin

Hey Michael,

Unfortunately I'm not an expert on autograd, and I don't know many tricks in this regard. I just skimmed the code, and it seems like in loss_acc_fn they use sp.linalg.solve to compute the kernel regression predictions. I'm not exactly sure how the gradient for this step is computed, but if it's taking the gradient of iterative LU operations, that could require a lot of memory (also see jax-ml/jax#1747).
I'd suggest replacing the sp.linalg.solve call in that function with the Cholesky solve alternative (see https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.cho_solve.html and https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.cho_factor.html) for possible improvements in both memory and speed.
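Something along these lines (a sketch; k_train_train, y_train and the regularizer are placeholders for whatever the notebook actually uses):

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

def kernel_regression_solve(k_train_train, y_train, reg=1e-6):
    # Factor the (regularized) kernel once via Cholesky, then back-substitute,
    # instead of using a general-purpose solve.
    n = k_train_train.shape[0]
    c, low = cho_factor(k_train_train + reg * jnp.eye(n))
    return cho_solve((c, low), y_train)
```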

And yes, nt.batch is for computing the NTK kernel in batches (see https://neural-tangents.readthedocs.io/en/latest/batching.html).
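For reference, a minimal usage sketch (batch_size is just an example value you'd tune to your memory budget):

```python
import neural_tangents as nt

# Compute the NTK in chunks; store_on_device=False keeps the assembled
# kernel in host (CPU) memory rather than on the accelerator.
batched_kernel_fn = nt.batch(kernel_fn, batch_size=256, store_on_device=False)
k = batched_kernel_fn(x_support, x_train, 'ntk')
```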
